Merge branch 'main' into jzh

JzoNg
2026-04-10 10:36:45 +08:00
20 changed files with 421 additions and 335 deletions

View File

@@ -346,89 +346,6 @@ class PublishedRagPipelineRunApi(Resource):
raise InvokeRateLimitHttpError(ex.description)
# class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource):
# @setup_required
# @login_required
# @account_initialization_required
# @get_rag_pipeline
# def post(self, pipeline: Pipeline, node_id: str):
# """
# Run rag pipeline datasource
# """
# # The role of the current user in the ta table must be admin, owner, or editor
# if not current_user.has_edit_permission:
# raise Forbidden()
#
# if not isinstance(current_user, Account):
# raise Forbidden()
#
# parser = (reqparse.RequestParser()
# .add_argument("job_id", type=str, required=True, nullable=False, location="json")
# .add_argument("datasource_type", type=str, required=True, location="json")
# )
# args = parser.parse_args()
#
# job_id = args.get("job_id")
# if job_id == None:
# raise ValueError("missing job_id")
# datasource_type = args.get("datasource_type")
# if datasource_type == None:
# raise ValueError("missing datasource_type")
#
# rag_pipeline_service = RagPipelineService()
# result = rag_pipeline_service.run_datasource_workflow_node_status(
# pipeline=pipeline,
# node_id=node_id,
# job_id=job_id,
# account=current_user,
# datasource_type=datasource_type,
# is_published=True
# )
#
# return result
# class RagPipelineDraftDatasourceNodeRunStatusApi(Resource):
# @setup_required
# @login_required
# @account_initialization_required
# @get_rag_pipeline
# def post(self, pipeline: Pipeline, node_id: str):
# """
# Run rag pipeline datasource
# """
# # The role of the current user in the ta table must be admin, owner, or editor
# if not current_user.has_edit_permission:
# raise Forbidden()
#
# if not isinstance(current_user, Account):
# raise Forbidden()
#
# parser = (reqparse.RequestParser()
# .add_argument("job_id", type=str, required=True, nullable=False, location="json")
# .add_argument("datasource_type", type=str, required=True, location="json")
# )
# args = parser.parse_args()
#
# job_id = args.get("job_id")
# if job_id == None:
# raise ValueError("missing job_id")
# datasource_type = args.get("datasource_type")
# if datasource_type == None:
# raise ValueError("missing datasource_type")
#
# rag_pipeline_service = RagPipelineService()
# result = rag_pipeline_service.run_datasource_workflow_node_status(
# pipeline=pipeline,
# node_id=node_id,
# job_id=job_id,
# account=current_user,
# datasource_type=datasource_type,
# is_published=False
# )
#
# return result
#
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])

View File

@@ -7,7 +7,8 @@ import logging
from collections.abc import Generator
from flask import Response, jsonify, request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
@@ -33,6 +34,11 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
logger = logging.getLogger(__name__)
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict
action: str
def _jsonify_form_definition(form: Form) -> Response:
payload = form.get_definition().model_dump()
payload["expiration_time"] = int(form.expiration_time.timestamp())
@@ -84,10 +90,7 @@ class ConsoleHumanInputFormApi(Resource):
"action": "Approve"
}
"""
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("action", type=str, required=True, location="json")
args = parser.parse_args()
payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
current_user, _ = current_account_with_tenant()
service = HumanInputService(db.engine)
@@ -107,8 +110,8 @@ class ConsoleHumanInputFormApi(Resource):
service.submit_form_by_token(
recipient_type=recipient_type,
form_token=form_token,
selected_action_id=args["action"],
form_data=args["inputs"],
selected_action_id=payload.action,
form_data=payload.inputs,
submission_user_id=current_user.id,
)

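Note: the hunk above swaps flask_restx's reqparse for a Pydantic model validated directly against the JSON body. A minimal sketch of the same pattern, detached from Dify's codebase (the Flask app and the error handling below are illustrative, not the project's actual wiring):

    from flask import Flask, jsonify, request
    from pydantic import BaseModel, ValidationError

    app = Flask(__name__)

    class SubmitPayload(BaseModel):
        inputs: dict
        action: str

    @app.post("/forms/submit")
    def submit_form():
        try:
            # Raises ValidationError on a missing or mistyped field,
            # replacing reqparse's per-argument declarations.
            payload = SubmitPayload.model_validate(request.get_json())
        except ValidationError as e:
            return jsonify({"error": str(e)}), 400
        return jsonify({"action": payload.action, "inputs": payload.inputs})

One behavioral difference to keep in mind: reqparse produced a 400 on its own, while model_validate raises ValidationError, so the surrounding error handlers presumably translate it (the diff does not show that part).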
View File

@@ -7,7 +7,8 @@ import logging
from datetime import datetime
from flask import Response, request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden
@@ -23,6 +24,12 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputService
logger = logging.getLogger(__name__)
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict
action: str
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
prefix="web_form_submit_rate_limit",
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
@@ -112,10 +119,7 @@ class HumanInputFormApi(Resource):
"action": "Approve"
}
"""
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("action", type=str, required=True, location="json")
args = parser.parse_args()
payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
ip_address = extract_remote_ip(request)
if _FORM_SUBMIT_RATE_LIMITER.is_rate_limited(ip_address):
@@ -135,8 +139,8 @@ class HumanInputFormApi(Resource):
service.submit_form_by_token(
recipient_type=recipient_type,
form_token=form_token,
selected_action_id=args["action"],
form_data=args["inputs"],
selected_action_id=payload.action,
form_data=payload.inputs,
submission_end_user_id=None,
# submission_end_user_id=_end_user.id,
)

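The web-side endpoint gets the same Pydantic treatment and sits behind the IP-based _FORM_SUBMIT_RATE_LIMITER configured above. As a rough sketch of what such a limiter does, here is an in-memory stand-in: is_rate_limited matches the call site above, while the constructor shape and increment_rate_limit are assumptions about the Redis-backed implementation, which would keep this state in Redis so it survives restarts and is shared across workers.

    import time
    from collections import defaultdict

    class InMemoryRateLimiter:
        # Sliding-window limiter; illustrative stand-in only.
        def __init__(self, prefix: str, max_attempts: int, time_window: int):
            self.prefix = prefix
            self.max_attempts = max_attempts
            self.time_window = time_window
            self._hits: dict[str, list[float]] = defaultdict(list)

        def is_rate_limited(self, key: str) -> bool:
            cutoff = time.time() - self.time_window
            self._hits[key] = [t for t in self._hits[key] if t > cutoff]
            return len(self._hits[key]) >= self.max_attempts

        def increment_rate_limit(self, key: str) -> None:
            self._hits[key].append(time.time())

    limiter = InMemoryRateLimiter("web_form_submit_rate_limit", max_attempts=5, time_window=60)
    if not limiter.is_rate_limited("203.0.113.7"):
        limiter.increment_rate_limit("203.0.113.7")  # record the attempt, then handle the form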
View File

@@ -1,10 +1,10 @@
from typing import Literal, Optional
from typing import Any, Literal, Optional, TypedDict
from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel, Field, field_validator
from core.datasource.entities.datasource_entities import DatasourceParameter
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.common_entities import I18nObject, I18nObjectDict
class DatasourceApiEntity(BaseModel):
@@ -20,6 +20,23 @@ class DatasourceApiEntity(BaseModel):
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
class DatasourceProviderApiEntityDict(TypedDict):
id: str
author: str
name: str
plugin_id: str | None
plugin_unique_identifier: str | None
description: I18nObjectDict
icon: str | dict
label: I18nObjectDict
type: str
team_credentials: dict | None
is_team_authorization: bool
allow_delete: bool
datasources: list[Any]
labels: list[str]
class DatasourceProviderApiEntity(BaseModel):
id: str
author: str
@@ -42,7 +59,7 @@ class DatasourceProviderApiEntity(BaseModel):
def convert_none_to_empty_list(cls, v):
return v if v is not None else []
def to_dict(self) -> dict:
def to_dict(self) -> DatasourceProviderApiEntityDict:
# -------------
# overwrite datasource parameter types for temp fix
datasources = jsonable_encoder(self.datasources)
@@ -53,7 +70,7 @@ class DatasourceProviderApiEntity(BaseModel):
parameter["type"] = "files"
# -------------
return {
result: DatasourceProviderApiEntityDict = {
"id": self.id,
"author": self.author,
"name": self.name,
@@ -69,3 +86,4 @@ class DatasourceProviderApiEntity(BaseModel):
"datasources": datasources,
"labels": self.labels,
}
return result

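Typing the to_dict() return as a TypedDict turns the dict literal into something a type checker can verify key by key. A self-contained sketch of the pattern (the entity fields here are invented for illustration):

    from typing import TypedDict

    class ProviderDict(TypedDict):
        id: str
        name: str
        allow_delete: bool

    def to_dict(provider_id: str, name: str) -> ProviderDict:
        # Annotating the intermediate variable, as the diff does, is what
        # makes mypy/pyright check the literal: a missing key, an extra key,
        # or a wrongly typed value is now an error instead of passing
        # silently under a plain `dict` return type.
        result: ProviderDict = {
            "id": provider_id,
            "name": name,
            "allow_delete": True,
        }
        return result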
View File

@@ -146,7 +146,7 @@ def discover_protected_resource_metadata(
return ProtectedResourceMetadata.model_validate(response.json())
elif response.status_code == 404:
continue # Try next URL
except (RequestError, ValidationError):
except (RequestError, ValidationError, json.JSONDecodeError):
continue # Try next URL
return None
@@ -166,7 +166,7 @@ def discover_oauth_authorization_server_metadata(
return OAuthMetadata.model_validate(response.json())
elif response.status_code == 404:
continue # Try next URL
except (RequestError, ValidationError):
except (RequestError, ValidationError, json.JSONDecodeError):
continue # Try next URL
return None
@@ -276,7 +276,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
else:
return False, ""
return False, ""
except RequestError:
except (RequestError, json.JSONDecodeError, IndexError):
# Not support resource discovery, fall back to well-known OAuth metadata
return False, ""

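This fix covers servers that answer 200 with a non-JSON body (an HTML error page, an empty response). Judging by the RequestError import, the client is httpx, whose response.json() parses with the standard json module and therefore raises json.JSONDecodeError, which the old except (RequestError, ValidationError) clause let escape instead of falling through to the next discovery URL. A minimal reproduction of the failure mode using json.loads directly:

    import json

    body = "<html>Service temporarily unavailable</html>"  # a 200 response that isn't JSON
    try:
        metadata = json.loads(body)
    except json.JSONDecodeError:
        metadata = None  # fall through to the next URL, as the diff now does

The third hunk additionally catches IndexError, guarding the empty authorization_servers case that the tests further down exercise.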
View File

@@ -61,27 +61,28 @@ class TokenBufferMemory:
:param is_user_message: whether this is a user message
:return: PromptMessage
"""
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
app = self.conversation.app
if not app:
raise ValueError("App not found for conversation")
match self.conversation.mode:
case AppMode.AGENT_CHAT | AppMode.COMPLETION | AppMode.CHAT:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW:
app = self.conversation.app
if not app:
raise ValueError("App not found for conversation")
if not message.workflow_run_id:
raise ValueError("Workflow run ID not found")
if not message.workflow_run_id:
raise ValueError("Workflow run ID not found")
workflow_run = self.workflow_run_repo.get_workflow_run_by_id(
tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id
)
if not workflow_run:
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
if not workflow:
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
else:
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
workflow_run = self.workflow_run_repo.get_workflow_run_by_id(
tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id
)
if not workflow_run:
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
if not workflow:
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
case _:
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
detail = ImagePromptMessageContent.DETAIL.HIGH
if file_extra_config and app_record:

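The refactor swaps an if/elif chain over AppMode values for structural pattern matching. Two details do the work: `|` combines alternative patterns in a single case, and dotted names like AppMode.CHAT are value patterns (a bare name would capture anything instead of comparing). A standalone sketch mirroring the shape of the diff (the enum here is a stand-in for the real AppMode):

    from enum import Enum

    class AppMode(Enum):
        CHAT = "chat"
        AGENT_CHAT = "agent-chat"
        WORKFLOW = "workflow"

    def describe(mode: AppMode) -> str:
        match mode:
            case AppMode.CHAT | AppMode.AGENT_CHAT:  # value patterns joined with |
                return "conversation app"
            case AppMode.WORKFLOW:
                return "workflow app"
            case _:  # mirrors the diff's AssertionError for unexpected modes
                raise AssertionError(f"Invalid app mode: {mode}")

    assert describe(AppMode.WORKFLOW) == "workflow app"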
View File

@@ -5,6 +5,7 @@ from typing import Any
from elasticsearch import Elasticsearch
from pydantic import BaseModel, model_validator
from typing_extensions import TypedDict
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@@ -19,6 +20,16 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__)
class HuaweiElasticsearchParamsDict(TypedDict, total=False):
hosts: list[str]
verify_certs: bool
ssl_show_warn: bool
request_timeout: int
retry_on_timeout: bool
max_retries: int
basic_auth: tuple[str, str]
def create_ssl_context() -> ssl.SSLContext:
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
@@ -38,15 +49,15 @@ class HuaweiCloudVectorConfig(BaseModel):
raise ValueError("config HOSTS is required")
return values
def to_elasticsearch_params(self) -> dict[str, Any]:
params = {
"hosts": self.hosts.split(","),
"verify_certs": False,
"ssl_show_warn": False,
"request_timeout": 30000,
"retry_on_timeout": True,
"max_retries": 10,
}
def to_elasticsearch_params(self) -> HuaweiElasticsearchParamsDict:
params = HuaweiElasticsearchParamsDict(
hosts=self.hosts.split(","),
verify_certs=False,
ssl_show_warn=False,
request_timeout=30000,
retry_on_timeout=True,
max_retries=10,
)
if self.username and self.password:
params["basic_auth"] = (self.username, self.password)
return params

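Unlike the DatasourceProviderApiEntityDict earlier, the connection-params TypedDicts are declared with total=False, because basic_auth only appears when credentials are configured. A small sketch of how totality interacts with the conditional key (names are illustrative):

    from typing import TypedDict

    class ESParams(TypedDict, total=False):
        hosts: list[str]
        verify_certs: bool
        basic_auth: tuple[str, str]  # only present when credentials exist

    def build_params(hosts: str, username: str | None, password: str | None) -> ESParams:
        params = ESParams(hosts=hosts.split(","), verify_certs=False)
        if username and password:
            params["basic_auth"] = (username, password)  # fine: the key is non-required
        return params

    print(build_params("es1:9200,es2:9200", None, None))
    # {'hosts': ['es1:9200', 'es2:9200'], 'verify_certs': False}

With the default total=True, constructing the dict without basic_auth would be rejected; total=False trades that strictness for per-key optionality (typing.NotRequired, available since Python 3.11, lets one class mix both).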
View File

@@ -7,6 +7,7 @@ from opensearchpy import OpenSearch, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
from tenacity import retry, stop_after_attempt, wait_exponential
from typing_extensions import TypedDict
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@@ -26,6 +27,14 @@ ROUTING_FIELD = "routing_field"
UGC_INDEX_PREFIX = "ugc_index"
class LindormOpenSearchParamsDict(TypedDict, total=False):
hosts: str | None
use_ssl: bool
pool_maxsize: int
timeout: int
http_auth: tuple[str, str]
class LindormVectorStoreConfig(BaseModel):
hosts: str | None
username: str | None = None
@@ -44,13 +53,13 @@ class LindormVectorStoreConfig(BaseModel):
raise ValueError("config PASSWORD is required")
return values
def to_opensearch_params(self) -> dict[str, Any]:
params: dict[str, Any] = {
"hosts": self.hosts,
"use_ssl": False,
"pool_maxsize": 128,
"timeout": 30,
}
def to_opensearch_params(self) -> LindormOpenSearchParamsDict:
params = LindormOpenSearchParamsDict(
hosts=self.hosts,
use_ssl=False,
pool_maxsize=128,
timeout=30,
)
if self.username and self.password:
params["http_auth"] = (self.username, self.password)
return params

View File

@@ -6,6 +6,7 @@ from uuid import uuid4
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
from typing_extensions import TypedDict
from configs import dify_config
from configs.middleware.vdb.opensearch_config import AuthMethod
@@ -21,6 +22,20 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__)
class _OpenSearchHostDict(TypedDict):
host: str
port: int
class OpenSearchParamsDict(TypedDict, total=False):
hosts: list[_OpenSearchHostDict]
use_ssl: bool
verify_certs: bool
connection_class: type
pool_maxsize: int
http_auth: tuple[str | None, str | None] | Urllib3AWSV4SignerAuth
class OpenSearchConfig(BaseModel):
host: str
port: int
@@ -57,14 +72,14 @@ class OpenSearchConfig(BaseModel):
service=self.aws_service, # type: ignore[arg-type]
)
def to_opensearch_params(self) -> dict[str, Any]:
params = {
"hosts": [{"host": self.host, "port": self.port}],
"use_ssl": self.secure,
"verify_certs": self.verify_certs,
"connection_class": Urllib3HttpConnection,
"pool_maxsize": 20,
}
def to_opensearch_params(self) -> OpenSearchParamsDict:
params = OpenSearchParamsDict(
hosts=[{"host": self.host, "port": self.port}],
use_ssl=self.secure,
verify_certs=self.verify_certs,
connection_class=Urllib3HttpConnection,
pool_maxsize=20,
)
if self.auth_method == "basic":
logger.info("Using basic authentication for OpenSearch Vector DB")

View File

@@ -5,12 +5,30 @@ from typing import Any
import pytz # type: ignore[import-untyped]
from celery import Celery, Task
from celery.schedules import crontab
from typing_extensions import TypedDict
from configs import dify_config
from dify_app import DifyApp
def get_celery_ssl_options() -> dict[str, Any] | None:
class _CelerySentinelKwargsDict(TypedDict):
socket_timeout: float | None
password: str | None
class CelerySentinelTransportDict(TypedDict):
master_name: str | None
sentinel_kwargs: _CelerySentinelKwargsDict
class CelerySSLOptionsDict(TypedDict):
ssl_cert_reqs: int
ssl_ca_certs: str | None
ssl_certfile: str | None
ssl_keyfile: str | None
def get_celery_ssl_options() -> CelerySSLOptionsDict | None:
"""Get SSL configuration for Celery broker/backend connections."""
# Only apply SSL if we're using Redis as broker/backend
if not dify_config.BROKER_USE_SSL:
@@ -33,26 +51,24 @@ def get_celery_ssl_options() -> dict[str, Any] | None:
ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE)
ssl_options = {
"ssl_cert_reqs": ssl_cert_reqs,
"ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS,
"ssl_certfile": dify_config.REDIS_SSL_CERTFILE,
"ssl_keyfile": dify_config.REDIS_SSL_KEYFILE,
}
return ssl_options
return CelerySSLOptionsDict(
ssl_cert_reqs=ssl_cert_reqs,
ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS,
ssl_certfile=dify_config.REDIS_SSL_CERTFILE,
ssl_keyfile=dify_config.REDIS_SSL_KEYFILE,
)
def get_celery_broker_transport_options() -> dict[str, Any]:
def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]:
"""Get broker transport options (e.g. Redis Sentinel) for Celery connections."""
if dify_config.CELERY_USE_SENTINEL:
return {
"master_name": dify_config.CELERY_SENTINEL_MASTER_NAME,
"sentinel_kwargs": {
"socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
"password": dify_config.CELERY_SENTINEL_PASSWORD,
},
}
return CelerySentinelTransportDict(
master_name=dify_config.CELERY_SENTINEL_MASTER_NAME,
sentinel_kwargs=_CelerySentinelKwargsDict(
socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
password=dify_config.CELERY_SENTINEL_PASSWORD,
),
)
return {}

View File

@@ -674,28 +674,24 @@ class AppModelConfig(TypeBase):
def suggested_questions_list(self) -> list[str]:
return json.loads(self.suggested_questions) if self.suggested_questions else []
def _get_enabled_config(self, value: str | None, *, default_enabled: bool = False) -> EnabledConfig:
return cast(EnabledConfig, json.loads(value) if value else {"enabled": default_enabled})
@property
def suggested_questions_after_answer_dict(self) -> EnabledConfig:
return cast(
EnabledConfig,
json.loads(self.suggested_questions_after_answer)
if self.suggested_questions_after_answer
else {"enabled": False},
)
return self._get_enabled_config(self.suggested_questions_after_answer)
@property
def speech_to_text_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False})
return self._get_enabled_config(self.speech_to_text)
@property
def text_to_speech_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False})
return self._get_enabled_config(self.text_to_speech)
@property
def retriever_resource_dict(self) -> EnabledConfig:
return cast(
EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
)
return self._get_enabled_config(self.retriever_resource, default_enabled=True)
@property
def annotation_reply_dict(self) -> AnnotationReplyConfig:
@@ -722,7 +718,7 @@ class AppModelConfig(TypeBase):
@property
def more_like_this_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False})
return self._get_enabled_config(self.more_like_this)
@property
def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig:
@@ -902,7 +898,7 @@ class InstalledApp(TypeBase):
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
class TrialApp(Base):
class TrialApp(TypeBase):
__tablename__ = "trial_apps"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="trial_app_pkey"),
@@ -911,18 +907,26 @@ class TrialApp(Base):
sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
)
id = mapped_column(StringUUID, default=gen_uuidv4_string)
app_id = mapped_column(StringUUID, nullable=False)
tenant_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
trial_limit = mapped_column(sa.Integer, nullable=False, default=3)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
insert_default=func.current_timestamp(),
server_default=func.current_timestamp(),
init=False,
)
trial_limit: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=3)
@property
def app(self) -> App | None:
return db.session.scalar(select(App).where(App.id == self.app_id))
class AccountTrialAppRecord(Base):
class AccountTrialAppRecord(TypeBase):
__tablename__ = "account_trial_app_records"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"),
@@ -930,11 +934,19 @@ class AccountTrialAppRecord(Base):
sa.Index("account_trial_app_record_app_id_idx", "app_id"),
sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
)
id = mapped_column(StringUUID, default=gen_uuidv4_string)
account_id = mapped_column(StringUUID, nullable=False)
app_id = mapped_column(StringUUID, nullable=False)
count = mapped_column(sa.Integer, nullable=False, default=0)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
id: Mapped[str] = mapped_column(
StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False
)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
insert_default=func.current_timestamp(),
server_default=func.current_timestamp(),
init=False,
)
@property
def app(self) -> App | None:

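TrialApp and AccountTrialAppRecord move from the plain Base to TypeBase with Mapped[...] annotations. Judging from the init=False and default_factory arguments, TypeBase appears to be a MappedAsDataclass-style declarative base, where every column takes part in the generated __init__ unless excluded. A minimal sketch of that pattern under that assumption (the base, table, and columns below are invented for illustration):

    import uuid
    from datetime import datetime

    import sqlalchemy as sa
    from sqlalchemy import func
    from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column

    class TypeBase(MappedAsDataclass, DeclarativeBase):
        pass

    class TrialRecord(TypeBase):
        __tablename__ = "trial_records"

        # init=False keeps the column out of __init__; default_factory fills it
        # at construction time, insert_default covers INSERTs that bypass the
        # ORM constructor.
        id: Mapped[str] = mapped_column(
            sa.String(36),
            primary_key=True,
            insert_default=lambda: str(uuid.uuid4()),
            default_factory=lambda: str(uuid.uuid4()),
            init=False,
        )
        app_id: Mapped[str] = mapped_column(sa.String(36), nullable=False)
        created_at: Mapped[datetime] = mapped_column(
            sa.DateTime,
            nullable=False,
            insert_default=func.current_timestamp(),  # client-side default on INSERT
            server_default=func.current_timestamp(),  # DDL-level DEFAULT as a backstop
            init=False,
        )
        trial_limit: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=3)

    record = TrialRecord(app_id="app-1")  # id and created_at need not be passed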
View File

@@ -66,12 +66,15 @@ def build_file_from_stored_mapping(
record_id = resolve_file_record_id(mapping)
transfer_method = FileTransferMethod.value_of(mapping["transfer_method"])
if transfer_method == FileTransferMethod.TOOL_FILE and record_id:
mapping["tool_file_id"] = record_id
elif transfer_method in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL] and record_id:
mapping["upload_file_id"] = record_id
elif transfer_method == FileTransferMethod.DATASOURCE_FILE and record_id:
mapping["datasource_file_id"] = record_id
match transfer_method:
case FileTransferMethod.TOOL_FILE if record_id:
mapping["tool_file_id"] = record_id
case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL if record_id:
mapping["upload_file_id"] = record_id
case FileTransferMethod.DATASOURCE_FILE if record_id:
mapping["datasource_file_id"] = record_id
case _:
pass
if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None:
remote_url = mapping.get("remote_url")

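This match adds guards: case FileTransferMethod.TOOL_FILE if record_id: matches only when both the pattern and the condition hold, folding the old `and record_id` checks into the case clauses. A failed guard does not end the match; evaluation moves on to the next case, which is how the final `case _: pass` is reached when record_id is None. A compact illustration (enum and return values are stand-ins):

    from enum import Enum

    class TransferMethod(Enum):
        TOOL_FILE = "tool_file"
        REMOTE_URL = "remote_url"

    def route(method: TransferMethod, record_id: str | None) -> str:
        match method:
            case TransferMethod.TOOL_FILE if record_id:  # pattern AND guard must hold
                return f"tool_file_id={record_id}"
            case TransferMethod.REMOTE_URL if record_id:
                return f"upload_file_id={record_id}"
            case _:
                return "no-op"

    assert route(TransferMethod.TOOL_FILE, None) == "no-op"  # guard fails, falls through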
View File

@@ -467,61 +467,67 @@ class AppDslService:
)
# Initialize app based on mode
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow_data = data.get("workflow")
if not workflow_data or not isinstance(workflow_data, dict):
raise ValueError("Missing workflow data for workflow/advanced chat app")
match app_mode:
case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW:
workflow_data = data.get("workflow")
if not workflow_data or not isinstance(workflow_data, dict):
raise ValueError("Missing workflow data for workflow/advanced chat app")
environment_variables_list = workflow_data.get("environment_variables", [])
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = workflow_data.get("conversation_variables", [])
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
environment_variables_list = workflow_data.get("environment_variables", [])
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = workflow_data.get("conversation_variables", [])
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj)
for obj in conversation_variables_list
]
workflow_service = WorkflowService()
current_draft_workflow = workflow_service.get_draft_workflow(app_model=app)
if current_draft_workflow:
unique_hash = current_draft_workflow.unique_hash
else:
unique_hash = None
graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
decrypted_id
for dataset_id in dataset_ids
if (decrypted_id := self.decrypt_dataset_id(encrypted_data=dataset_id, tenant_id=app.tenant_id))
]
workflow_service.sync_draft_workflow(
app_model=app,
graph=workflow_data.get("graph", {}),
features=workflow_data.get("features", {}),
unique_hash=unique_hash,
account=account,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
)
elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}:
# Initialize model config
model_config = data.get("model_config")
if not model_config or not isinstance(model_config, dict):
raise ValueError("Missing model_config for chat/agent-chat/completion app")
# Initialize or update model config
if not app.app_model_config:
app_model_config = AppModelConfig(
app_id=app.id, created_by=account.id, updated_by=account.id
).from_model_config_dict(cast(AppModelConfigDict, model_config))
app_model_config.id = str(uuid4())
app.app_model_config_id = app_model_config.id
workflow_service = WorkflowService()
current_draft_workflow = workflow_service.get_draft_workflow(app_model=app)
if current_draft_workflow:
unique_hash = current_draft_workflow.unique_hash
else:
unique_hash = None
graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
decrypted_id
for dataset_id in dataset_ids
if (
decrypted_id := self.decrypt_dataset_id(
encrypted_data=dataset_id, tenant_id=app.tenant_id
)
)
]
workflow_service.sync_draft_workflow(
app_model=app,
graph=workflow_data.get("graph", {}),
features=workflow_data.get("features", {}),
unique_hash=unique_hash,
account=account,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
)
case AppMode.CHAT | AppMode.AGENT_CHAT | AppMode.COMPLETION:
# Initialize model config
model_config = data.get("model_config")
if not model_config or not isinstance(model_config, dict):
raise ValueError("Missing model_config for chat/agent-chat/completion app")
# Initialize or update model config
if not app.app_model_config:
app_model_config = AppModelConfig(
app_id=app.id, created_by=account.id, updated_by=account.id
).from_model_config_dict(cast(AppModelConfigDict, model_config))
app_model_config.id = str(uuid4())
app.app_model_config_id = app_model_config.id
self._session.add(app_model_config)
app_model_config_was_updated.send(app, app_model_config=app_model_config)
else:
raise ValueError("Invalid app mode")
self._session.add(app_model_config)
app_model_config_was_updated.send(app, app_model_config=app_model_config)
case _:
raise ValueError("Invalid app mode")
return app
@classmethod

View File

@@ -132,8 +132,8 @@ class FileService:
return file_size <= file_size_limit
def get_file_base64(self, file_id: str) -> str:
upload_file = (
self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = self._session_maker(expire_on_commit=False).scalar(
select(UploadFile).where(UploadFile.id == file_id).limit(1)
)
if not upload_file:
raise NotFound("File not found")
@@ -178,7 +178,7 @@ class FileService:
Return a short text preview extracted from a document file.
"""
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
if not upload_file:
raise NotFound("File not found")
@@ -200,7 +200,7 @@ class FileService:
if not result:
raise NotFound("File not found or signature is invalid")
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
if not upload_file:
raise NotFound("File not found or signature is invalid")
@@ -220,7 +220,7 @@ class FileService:
raise NotFound("File not found or signature is invalid")
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
if not upload_file:
raise NotFound("File not found or signature is invalid")
@@ -231,7 +231,7 @@ class FileService:
def get_public_image_preview(self, file_id: str):
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
if not upload_file:
raise NotFound("File not found or signature is invalid")
@@ -247,7 +247,7 @@ class FileService:
def get_file_content(self, file_id: str) -> str:
with self._session_maker(expire_on_commit=False) as session:
upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file: UploadFile | None = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
if not upload_file:
raise NotFound("File not found")

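FileService drops the legacy Query API (session.query(...).first()) for the 2.0-style session.scalar(select(...).limit(1)). The two are equivalent for single-row lookups: scalar() returns the first column of the first row, or None, and the explicit .limit(1) reproduces the LIMIT that Query.first() used to emit. A sketch against an in-memory SQLite database (model and data are illustrative):

    from sqlalchemy import String, create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class UploadFile(Base):
        __tablename__ = "upload_files"
        id: Mapped[str] = mapped_column(String(36), primary_key=True)
        key: Mapped[str] = mapped_column(String(255))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(UploadFile(id="f1", key="objects/f1"))
        session.commit()

        # 2.0 style: the UploadFile instance, or None when nothing matches
        upload_file = session.scalar(select(UploadFile).where(UploadFile.id == "f1").limit(1))
        assert upload_file is not None and upload_file.key == "objects/f1"

        missing = session.scalar(select(UploadFile).where(UploadFile.id == "nope").limit(1))
        assert missing is None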
View File

@@ -1,3 +1,5 @@
from typing import Any, TypedDict
from sqlalchemy import select
from constants.languages import languages
@@ -8,16 +10,43 @@ from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase
from services.recommend_app.recommend_app_type import RecommendAppType
class RecommendedAppItemDict(TypedDict):
id: str
app: App | None
app_id: str
description: Any
copyright: Any
privacy_policy: Any
custom_disclaimer: str
category: str
position: int
is_listed: bool
class RecommendedAppsResultDict(TypedDict):
recommended_apps: list[RecommendedAppItemDict]
categories: list[str]
class RecommendedAppDetailDict(TypedDict):
id: str
name: str
icon: Any
icon_background: str | None
mode: str
export_data: str
class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
"""
Retrieval recommended app from database
"""
def get_recommended_apps_and_categories(self, language: str):
def get_recommended_apps_and_categories(self, language: str) -> RecommendedAppsResultDict:
result = self.fetch_recommended_apps_from_db(language)
return result
def get_recommend_app_detail(self, app_id: str):
def get_recommend_app_detail(self, app_id: str) -> RecommendedAppDetailDict | None:
result = self.fetch_recommended_app_detail_from_db(app_id)
return result
@@ -25,7 +54,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
return RecommendAppType.DATABASE
@classmethod
def fetch_recommended_apps_from_db(cls, language: str):
def fetch_recommended_apps_from_db(cls, language: str) -> RecommendedAppsResultDict:
"""
Fetch recommended apps from db.
:param language: language
@@ -41,7 +70,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
).all()
categories = set()
recommended_apps_result = []
recommended_apps_result: list[RecommendedAppItemDict] = []
for recommended_app in recommended_apps:
app = recommended_app.app
if not app or not app.is_public:
@@ -51,7 +80,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
if not site:
continue
recommended_app_result = {
recommended_app_result: RecommendedAppItemDict = {
"id": recommended_app.id,
"app": recommended_app.app,
"app_id": recommended_app.app_id,
@@ -67,10 +96,10 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
categories.add(recommended_app.category)
return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)}
return RecommendedAppsResultDict(recommended_apps=recommended_apps_result, categories=sorted(categories))
@classmethod
def fetch_recommended_app_detail_from_db(cls, app_id: str) -> dict | None:
def fetch_recommended_app_detail_from_db(cls, app_id: str) -> RecommendedAppDetailDict | None:
"""
Fetch recommended app detail from db.
:param app_id: App ID
@@ -89,11 +118,11 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
if not app_model or not app_model.is_public:
return None
return {
"id": app_model.id,
"name": app_model.name,
"icon": app_model.icon,
"icon_background": app_model.icon_background,
"mode": app_model.mode,
"export_data": AppDslService.export_dsl(app_model=app_model),
}
return RecommendedAppDetailDict(
id=app_model.id,
name=app_model.name,
icon=app_model.icon,
icon_background=app_model.icon_background,
mode=app_model.mode,
export_data=AppDslService.export_dsl(app_model=app_model),
)

View File

@@ -104,32 +104,32 @@ class WebhookService:
"""
with Session(db.engine) as session:
# Get webhook trigger
webhook_trigger = (
session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.webhook_id == webhook_id).first()
webhook_trigger = session.scalar(
select(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.webhook_id == webhook_id).limit(1)
)
if not webhook_trigger:
raise ValueError(f"Webhook not found: {webhook_id}")
if is_debug:
workflow = (
session.query(Workflow)
.filter(
workflow = session.scalar(
select(Workflow)
.where(
Workflow.app_id == webhook_trigger.app_id,
Workflow.version == Workflow.VERSION_DRAFT,
)
.order_by(Workflow.created_at.desc())
.first()
.limit(1)
)
else:
# Check if the corresponding AppTrigger exists
app_trigger = (
session.query(AppTrigger)
.filter(
app_trigger = session.scalar(
select(AppTrigger)
.where(
AppTrigger.app_id == webhook_trigger.app_id,
AppTrigger.node_id == webhook_trigger.node_id,
AppTrigger.trigger_type == AppTriggerType.TRIGGER_WEBHOOK,
)
.first()
.limit(1)
)
if not app_trigger:
@@ -146,14 +146,14 @@ class WebhookService:
raise ValueError(f"Webhook trigger is disabled for webhook {webhook_id}")
# Get workflow
workflow = (
session.query(Workflow)
.filter(
workflow = session.scalar(
select(Workflow)
.where(
Workflow.app_id == webhook_trigger.app_id,
Workflow.version != Workflow.VERSION_DRAFT,
)
.order_by(Workflow.created_at.desc())
.first()
.limit(1)
)
if not workflow:
raise ValueError(f"Workflow not found for app {webhook_trigger.app_id}")

View File

@@ -3,6 +3,7 @@ import time
import click
from celery import shared_task # type: ignore
from sqlalchemy import select, update
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
@@ -26,43 +27,42 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "upgrade":
dataset_documents = (
session.query(DatasetDocument)
.where(
dataset_documents = session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
).all()
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id.in_(dataset_documents_ids))
.values(indexing_status="indexing")
)
session.commit()
for dataset_document in dataset_documents:
try:
# add from vector index
segments = (
session.query(DocumentSegment)
segments = session.scalars(
select(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
.order_by(DocumentSegment.position.asc())
.all()
)
).all()
if segments:
documents = []
for segment in segments:
@@ -81,32 +81,36 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
# clean keywords
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
index_processor.load(dataset, documents, with_keywords=False)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="completed")
)
session.commit()
except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="error", error=str(e))
)
session.commit()
elif action == "update":
dataset_documents = (
session.query(DatasetDocument)
.where(
dataset_documents = session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
).all()
# add new index
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id.in_(dataset_documents_ids))
.values(indexing_status="indexing")
)
session.commit()
@@ -116,15 +120,14 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
for dataset_document in dataset_documents:
# update from vector index
try:
segments = (
session.query(DocumentSegment)
segments = session.scalars(
select(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
.order_by(DocumentSegment.position.asc())
.all()
)
).all()
if segments:
documents = []
multimodal_documents = []
@@ -173,13 +176,17 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
index_processor.load(
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="completed")
)
session.commit()
except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="error", error=str(e))
)
session.commit()
else:

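The status flips in this task move from Query.update({...}, synchronize_session=False) to a Core update() executed through the session, which issues one UPDATE for the whole id set without loading the rows. In 2.0-style ORM execution, synchronize_session defaults to "auto", so already-loaded objects are kept consistent where the old code opted out explicitly. A minimal sketch of the bulk-UPDATE form (toy model; names are illustrative):

    from sqlalchemy import String, create_engine, select, update
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class Doc(Base):
        __tablename__ = "docs"
        id: Mapped[str] = mapped_column(String(36), primary_key=True)
        indexing_status: Mapped[str] = mapped_column(String(32), default="waiting")

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([Doc(id="d1"), Doc(id="d2")])
        session.commit()

        # One UPDATE statement for the whole id set; no per-object round trips.
        session.execute(
            update(Doc).where(Doc.id.in_(["d1", "d2"])).values(indexing_status="indexing")
        )
        session.commit()

        assert session.scalars(select(Doc.indexing_status)).all() == ["indexing", "indexing"]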
View File

@@ -862,6 +862,15 @@ class TestAuthOrchestration:
result = discover_protected_resource_metadata(None, "https://api.example.com")
assert result is None
# JSONDecodeError (non-JSON 200 response)
mock_get.side_effect = None
bad_json_response = Mock()
bad_json_response.status_code = 200
bad_json_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
mock_get.return_value = bad_json_response
result = discover_protected_resource_metadata(None, "https://api.example.com")
assert result is None
@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_authorization_server_metadata(self, mock_get):
# Success
@@ -892,6 +901,14 @@ class TestAuthOrchestration:
result = discover_oauth_authorization_server_metadata(None, "https://api.example.com")
assert result is None
# JSONDecodeError (non-JSON 200 response)
bad_json_response = Mock()
bad_json_response.status_code = 200
bad_json_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
mock_get.return_value = bad_json_response
result = discover_oauth_authorization_server_metadata(None, "https://api.example.com")
assert result is None
def test_get_effective_scope(self):
prm = ProtectedResourceMetadata(
resource="https://api.example.com",
@@ -997,6 +1014,24 @@ class TestAuthOrchestration:
supported, url = check_support_resource_discovery("https://api")
assert supported is False
# Case 6: JSONDecodeError (non-JSON 200 response)
mock_get.side_effect = None
bad_json_res = Mock()
bad_json_res.status_code = 200
bad_json_res.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
mock_get.return_value = bad_json_res
supported, url = check_support_resource_discovery("https://api")
assert supported is False
assert url == ""
# Case 7: Empty authorization_servers array (IndexError)
empty_res = Mock()
empty_res.status_code = 200
empty_res.json.return_value = {"authorization_servers": []}
mock_get.return_value = empty_res
supported, url = check_support_resource_discovery("https://api")
assert supported is False
def test_discover_oauth_metadata(self):
with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:

View File

@@ -165,7 +165,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.key = "test_key"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file
with patch("services.file_service.storage") as mock_storage:
mock_storage.load_once.return_value = b"test content"
@@ -178,7 +178,7 @@ class TestFileService:
mock_storage.load_once.assert_called_once_with("test_key")
def test_get_file_base64_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(NotFound, match="File not found"):
file_service.get_file_base64("non_existent")
@@ -215,7 +215,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.extension = "pdf"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file
with patch("services.file_service.ExtractProcessor.load_from_upload_file") as mock_extract:
mock_extract.return_value = "Extracted text content"
@@ -227,7 +227,7 @@ class TestFileService:
assert result == "Extracted text content"
def test_get_file_preview_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(NotFound, match="File not found"):
file_service.get_file_preview("non_existent")
@@ -235,7 +235,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.extension = "exe"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file
with pytest.raises(UnsupportedFileTypeError):
file_service.get_file_preview("file_id")
@@ -246,7 +246,7 @@ class TestFileService:
upload_file.extension = "jpg"
upload_file.mime_type = "image/jpeg"
upload_file.key = "key"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file
with (
patch("services.file_service.file_helpers.verify_image_signature") as mock_verify,
@@ -269,7 +269,7 @@ class TestFileService:
file_service.get_image_preview("file_id", "ts", "nonce", "sign")
def test_get_image_preview_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None
mock_db_session.scalar.return_value = None
with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify:
mock_verify.return_value = True
with pytest.raises(NotFound, match="File not found or signature is invalid"):
@@ -279,7 +279,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.extension = "txt"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file
with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify:
mock_verify.return_value = True
with pytest.raises(UnsupportedFileTypeError):
@@ -289,7 +289,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.key = "key"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file
with (
patch("services.file_service.file_helpers.verify_file_signature") as mock_verify,
@@ -309,7 +309,7 @@ class TestFileService:
file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign")
def test_get_file_generator_by_file_id_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None
mock_db_session.scalar.return_value = None
with patch("services.file_service.file_helpers.verify_file_signature") as mock_verify:
mock_verify.return_value = True
with pytest.raises(NotFound, match="File not found or signature is invalid"):
@@ -321,7 +321,7 @@ class TestFileService:
upload_file.extension = "png"
upload_file.mime_type = "image/png"
upload_file.key = "key"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file
with patch("services.file_service.storage") as mock_storage:
mock_storage.load.return_value = b"image content"
@@ -330,7 +330,7 @@ class TestFileService:
assert mime == "image/png"
def test_get_public_image_preview_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(NotFound, match="File not found or signature is invalid"):
file_service.get_public_image_preview("file_id")
@@ -338,7 +338,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.extension = "txt"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file
with pytest.raises(UnsupportedFileTypeError):
file_service.get_public_image_preview("file_id")
@@ -346,7 +346,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.key = "key"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file
with patch("services.file_service.storage") as mock_storage:
mock_storage.load.return_value = b"hello world"
@@ -354,7 +354,7 @@ class TestFileService:
assert result == "hello world"
def test_get_file_content_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(NotFound, match="File not found"):
file_service.get_file_content("file_id")

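Because the services now make one session.scalar() call per lookup, the fakes change shape: instead of stubbing the query().where().first chain, the tests stub a single attribute, and MagicMock's side_effect with a list hands back one value per successive call. A compact illustration (the fake session is a stand-in, not the project's fixture):

    from unittest.mock import MagicMock

    fake_session = MagicMock()
    # Each scalar() call consumes the next item: first the trigger, then None.
    fake_session.scalar.side_effect = [{"app_id": "app-1"}, None]

    assert fake_session.scalar("first lookup") == {"app_id": "app-1"}
    assert fake_session.scalar("second lookup") is None

    # Setting side_effect to an exception makes the call raise instead, which is
    # how the auth tests above simulate response.json() failing to parse.
    response = MagicMock()
    response.json.side_effect = ValueError("not json")
    try:
        response.json()
    except ValueError:
        pass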
View File

@@ -657,7 +657,7 @@ def _app(**kwargs: Any) -> App:
def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
fake_session = MagicMock()
fake_session.query.return_value = _FakeQuery(None)
fake_session.scalar.return_value = None
_patch_session(monkeypatch, fake_session)
# Act / Assert
@@ -671,7 +671,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(None)]
fake_session.scalar.side_effect = [webhook_trigger, None]
_patch_session(monkeypatch, fake_session)
# Act / Assert
@@ -686,7 +686,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_limited(monkeypatch: pytest.MonkeyPatch) -> None:
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)]
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
_patch_session(monkeypatch, fake_session)
# Act / Assert
@@ -701,7 +701,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)]
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
_patch_session(monkeypatch, fake_session)
# Act / Assert
@@ -714,7 +714,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(None)]
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None]
_patch_session(monkeypatch, fake_session)
# Act / Assert
@@ -732,7 +732,7 @@ def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mode(monkeypatch: pytest.MonkeyPatch) -> None:
workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}
fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(workflow)]
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow]
_patch_session(monkeypatch, fake_session)
# Act
@@ -751,7 +751,7 @@ def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(monkeypatch: pytest.MonkeyPatch) -> None:
workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}
fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(workflow)]
fake_session.scalar.side_effect = [webhook_trigger, workflow]
_patch_session(monkeypatch, fake_session)
# Act