Merge branch 'main' into 4-13-update-deps

This commit is contained in:
Stephen Zhou
2026-04-14 09:19:19 +08:00
committed by GitHub
58 changed files with 361 additions and 240 deletions

View File

@ -1,3 +1,4 @@
from collections.abc import Mapping
from typing import TypedDict
from flask import request
@ -13,6 +14,14 @@ from services.billing_service import BillingService
_FALLBACK_LANG = "en-US"
class NotificationLangContent(TypedDict, total=False):
lang: str
title: str
subtitle: str
body: str
titlePicUrl: str
class NotificationItemDict(TypedDict):
notification_id: str | None
frequency: str | None
@ -28,9 +37,11 @@ class NotificationResponseDict(TypedDict):
notifications: list[NotificationItemDict]
def _pick_lang_content(contents: dict, lang: str) -> dict:
def _pick_lang_content(contents: Mapping[str, NotificationLangContent], lang: str) -> NotificationLangContent:
"""Return the single LangContent for *lang*, falling back to English."""
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
return (
contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), NotificationLangContent())
)
class DismissNotificationPayload(BaseModel):
@ -71,7 +82,7 @@ class NotificationApi(Resource):
notifications: list[NotificationItemDict] = []
for notification in result.get("notifications") or []:
contents: dict = notification.get("contents") or {}
contents: Mapping[str, NotificationLangContent] = notification.get("contents") or {}
lang_content = _pick_lang_content(contents, lang)
item: NotificationItemDict = {
"notification_id": notification.get("notificationId"),

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from datetime import datetime
from typing import Literal
from typing import Any, Literal
import pytz
from flask import request
@ -174,7 +174,7 @@ reg(CheckEmailUniquePayload)
register_schema_models(console_ns, AccountResponse)
def _serialize_account(account) -> dict:
def _serialize_account(account) -> dict[str, Any]:
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")

View File

@ -2,7 +2,7 @@ from typing import Any, Union
from flask import Response
from flask_restx import Resource
from graphon.variables.input_entities import VariableEntity
from graphon.variables.input_entities import VariableEntity, VariableEntityType
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
@ -158,14 +158,20 @@ class MCPAppApi(Resource):
except ValidationError as e:
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]:
def _convert_user_input_form(self, raw_form: list[dict[str, Any]]) -> list[VariableEntity]:
"""Convert raw user input form to VariableEntity objects"""
return [self._create_variable_entity(item) for item in raw_form]
def _create_variable_entity(self, item: dict) -> VariableEntity:
def _create_variable_entity(self, item: dict[str, Any]) -> VariableEntity:
"""Create a single VariableEntity from raw form item"""
variable_type = item.get("type", "") or list(item.keys())[0]
variable = item[variable_type]
variable_type_raw: str = item.get("type", "") or list(item.keys())[0]
try:
variable_type = VariableEntityType(variable_type_raw)
except ValueError as e:
raise MCPRequestError(
mcp_types.INVALID_PARAMS, f"Invalid user_input_form variable type: {variable_type_raw}"
) from e
variable = item[variable_type_raw]
return VariableEntity(
type=variable_type,
@ -178,7 +184,7 @@ class MCPAppApi(Resource):
json_schema=variable.get("json_schema"),
)
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
def _parse_mcp_request(self, args: dict[str, Any]) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
"""Parse and validate MCP request"""
try:
return mcp_types.ClientRequest.model_validate(args)

View File

@ -33,25 +33,25 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from services.summary_index_service import SummaryIndexService
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict[str, Any]:
"""Marshal a single segment and enrich it with summary content."""
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict[str, Any]]:
"""Marshal multiple segments and enrich them with summary content (batch query)."""
segment_ids = [segment.id for segment in segments]
summaries: dict = {}
summaries: dict[str, str | None] = {}
if segment_ids:
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
result = []
result: list[dict[str, Any]] = []
for segment in segments:
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict["summary"] = summaries.get(segment.id)
result.append(segment_dict)
return result

View File

@ -5,6 +5,7 @@ Web App Human Input Form APIs.
import json
import logging
from datetime import datetime
from typing import Any, NotRequired, TypedDict
from flask import Response, request
from flask_restx import Resource
@ -58,10 +59,19 @@ def _to_timestamp(value: datetime) -> int:
return int(value.timestamp())
class FormDefinitionPayload(TypedDict):
form_content: Any
inputs: Any
resolved_default_values: dict[str, str]
user_actions: Any
expiration_time: int
site: NotRequired[dict]
def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
"""Return the form payload (optionally with site) as a JSON response."""
definition_payload = form.get_definition().model_dump()
payload = {
payload: FormDefinitionPayload = {
"form_content": definition_payload["rendered_content"],
"inputs": definition_payload["inputs"],
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),

View File

@ -1,5 +1,6 @@
import uuid
from datetime import UTC, datetime, timedelta
from typing import Any
from flask import make_response, request
from flask_restx import Resource
@ -103,21 +104,23 @@ class PassportResource(Resource):
return response
def decode_enterprise_webapp_user_id(jwt_token: str | None):
def decode_enterprise_webapp_user_id(jwt_token: str | None) -> dict[str, Any] | None:
"""
Decode the enterprise user session from the Authorization header.
"""
if not jwt_token:
return None
decoded = PassportService().verify(jwt_token)
decoded: dict[str, Any] = PassportService().verify(jwt_token)
source = decoded.get("token_source")
if not source or source != "webapp_login_token":
raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
return decoded
def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType):
def exchange_token_for_existing_web_user(
app_code: str, enterprise_user_decoded: dict[str, Any], auth_type: WebAppAuthType
):
"""
Exchange a token for an existing web user session.
"""

View File

@ -1,4 +1,4 @@
from typing import cast
from typing import Any, cast
from flask_restx import fields, marshal, marshal_with
from sqlalchemy import select
@ -113,12 +113,12 @@ class AppSiteInfo:
}
def serialize_site(site: Site) -> dict:
def serialize_site(site: Site) -> dict[str, Any]:
"""Serialize Site model using the same schema as AppSiteApi."""
return cast(dict, marshal(site, AppSiteApi.site_fields))
return cast(dict[str, Any], marshal(site, AppSiteApi.site_fields))
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict:
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict[str, Any]:
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
return cast(dict, marshal(app_site_info, AppSiteApi.app_fields))
return cast(dict[str, Any], marshal(app_site_info, AppSiteApi.app_fields))

View File

@ -138,7 +138,9 @@ class DatasetConfigManager:
)
@classmethod
def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(
cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]
) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for dataset feature
@ -172,7 +174,7 @@ class DatasetConfigManager:
return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]
@classmethod
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict):
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]):
"""
Extract dataset config for legacy compatibility

View File

@ -108,7 +108,7 @@ class ModelConfigManager:
return dict(config), ["model"]
@classmethod
def validate_model_completion_params(cls, cp: dict):
def validate_model_completion_params(cls, cp: dict[str, Any]):
# model.completion_params
if not isinstance(cp, dict):
raise ValueError("model.completion_params must be of object type")

View File

@ -65,7 +65,7 @@ class PromptTemplateConfigManager:
)
@classmethod
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate pre_prompt and set defaults for prompt feature
depending on the config['model']
@ -130,7 +130,7 @@ class PromptTemplateConfigManager:
return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]
@classmethod
def validate_post_prompt_and_set_defaults(cls, config: dict):
def validate_post_prompt_and_set_defaults(cls, config: dict[str, Any]):
"""
Validate post_prompt and set defaults for prompt feature

View File

@ -1,5 +1,5 @@
import re
from typing import cast
from typing import Any, cast
from graphon.variables.input_entities import VariableEntity, VariableEntityType
@ -82,7 +82,7 @@ class BasicVariablesConfigManager:
return variable_entities, external_data_variables
@classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, tenant_id: str, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for user input form
@ -99,7 +99,7 @@ class BasicVariablesConfigManager:
return config, related_config_keys
@classmethod
def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_variables_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for user input form
@ -164,7 +164,9 @@ class BasicVariablesConfigManager:
return config, ["user_input_form"]
@classmethod
def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
def validate_external_data_tools_and_set_defaults(
cls, tenant_id: str, config: dict[str, Any]
) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for external data fetch feature

View File

@ -30,7 +30,7 @@ class FileUploadConfigManager:
return FileUploadConfig.model_validate(file_upload_dict)
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for file upload feature

View File

@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, ConfigDict, Field, ValidationError
@ -13,7 +15,7 @@ class AppConfigModel(BaseModel):
class MoreLikeThisConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
def convert(cls, config: dict[str, Any]) -> bool:
"""
Convert model config to model config
@ -23,7 +25,7 @@ class MoreLikeThisConfigManager:
return AppConfigModel.model_validate(validated_config).more_like_this.enabled
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
try:
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
except ValidationError:

View File

@ -1,6 +1,9 @@
from typing import Any
class OpeningStatementConfigManager:
@classmethod
def convert(cls, config: dict) -> tuple[str, list]:
def convert(cls, config: dict[str, Any]) -> tuple[str, list[str]]:
"""
Convert model config to model config
@ -15,7 +18,7 @@ class OpeningStatementConfigManager:
return opening_statement, suggested_questions_list
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for opening statement feature

View File

@ -1,6 +1,9 @@
from typing import Any
class RetrievalResourceConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
def convert(cls, config: dict[str, Any]) -> bool:
show_retrieve_source = False
retriever_resource_dict = config.get("retriever_resource")
if retriever_resource_dict:
@ -10,7 +13,7 @@ class RetrievalResourceConfigManager:
return show_retrieve_source
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for retriever resource feature

View File

@ -1,6 +1,9 @@
from typing import Any
class SpeechToTextConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
def convert(cls, config: dict[str, Any]) -> bool:
"""
Convert model config to model config
@ -15,7 +18,7 @@ class SpeechToTextConfigManager:
return speech_to_text
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for speech to text feature

View File

@ -1,6 +1,9 @@
from typing import Any
class SuggestedQuestionsAfterAnswerConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
def convert(cls, config: dict[str, Any]) -> bool:
"""
Convert model config to model config
@ -15,7 +18,7 @@ class SuggestedQuestionsAfterAnswerConfigManager:
return suggested_questions_after_answer
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for suggested questions feature

View File

@ -1,9 +1,11 @@
from typing import Any
from core.app.app_config.entities import TextToSpeechEntity
class TextToSpeechConfigManager:
@classmethod
def convert(cls, config: dict):
def convert(cls, config: dict[str, Any]):
"""
Convert model config to model config
@ -22,7 +24,7 @@ class TextToSpeechConfigManager:
return text_to_speech
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for text to speech feature

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import cast
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@ -17,7 +17,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -26,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
return dict(blocking_response.model_dump())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response

View File

@ -1,3 +1,5 @@
from typing import Any
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
@ -34,7 +36,9 @@ class PipelineConfigManager(BaseAppConfigManager):
return pipeline_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
def config_validate(
cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False
) -> dict[str, Any]:
"""
Validate for pipeline config

View File

@ -782,7 +782,7 @@ class PipelineGenerator(BaseAppGenerator):
user_id: str,
all_files: list,
datasource_info: Mapping[str, Any],
next_page_parameters: dict | None = None,
next_page_parameters: dict[str, Any] | None = None,
):
"""
Get files in a folder.

View File

@ -521,7 +521,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
node_type: str
title: str
created_at: int
extras: dict = Field(default_factory=dict)
extras: dict[str, Any] = Field(default_factory=dict)
metadata: Mapping = {}
inputs: Mapping = {}
inputs_truncated: bool = False
@ -547,7 +547,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
title: str
index: int
created_at: int
extras: dict = Field(default_factory=dict)
extras: dict[str, Any] = Field(default_factory=dict)
event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str
@ -571,7 +571,7 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
outputs: Mapping | None = None
outputs_truncated: bool = False
created_at: int
extras: dict | None = None
extras: dict[str, Any] | None = None
inputs: Mapping | None = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus
@ -602,7 +602,7 @@ class LoopNodeStartStreamResponse(StreamResponse):
node_type: str
title: str
created_at: int
extras: dict = Field(default_factory=dict)
extras: dict[str, Any] = Field(default_factory=dict)
metadata: Mapping = {}
inputs: Mapping = {}
inputs_truncated: bool = False
@ -653,7 +653,7 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
outputs: Mapping | None = None
outputs_truncated: bool = False
created_at: int
extras: dict | None = None
extras: dict[str, Any] | None = None
inputs: Mapping | None = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus

View File

@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, Field, field_validator
@ -37,7 +39,7 @@ class PipelineDocument(BaseModel):
id: str
position: int
data_source_type: str
data_source_info: dict | None = None
data_source_info: dict[str, Any] | None = None
name: str
indexing_status: str
error: str | None = None

View File

@ -6,6 +6,7 @@ import re
from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from typing import Any
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from graphon.model_runtime.entities.provider_entities import (
@ -111,7 +112,7 @@ class ProviderConfiguration(BaseModel):
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None:
"""
Get current credentials.
@ -233,7 +234,7 @@ class ProviderConfiguration(BaseModel):
return session.execute(stmt).scalar_one_or_none()
def _get_specific_provider_credential(self, credential_id: str) -> dict | None:
def _get_specific_provider_credential(self, credential_id: str) -> dict[str, Any] | None:
"""
Get a specific provider credential by ID.
:param credential_id: Credential ID
@ -297,7 +298,7 @@ class ProviderConfiguration(BaseModel):
stmt = stmt.where(ProviderCredential.id != exclude_id)
return session.execute(stmt).scalar_one_or_none() is not None
def get_provider_credential(self, credential_id: str | None = None) -> dict | None:
def get_provider_credential(self, credential_id: str | None = None) -> dict[str, Any] | None:
"""
Get provider credentials.
@ -317,7 +318,9 @@ class ProviderConfiguration(BaseModel):
else [],
)
def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
def validate_provider_credentials(
self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
):
"""
Validate custom credentials.
:param credentials: provider credentials
@ -447,7 +450,7 @@ class ProviderConfiguration(BaseModel):
provider_names.append(model_provider_id.provider_name)
return provider_names
def create_provider_credential(self, credentials: dict, credential_name: str | None):
def create_provider_credential(self, credentials: dict[str, Any], credential_name: str | None):
"""
Add custom provider credentials.
:param credentials: provider credentials
@ -515,7 +518,7 @@ class ProviderConfiguration(BaseModel):
def update_provider_credential(
self,
credentials: dict,
credentials: dict[str, Any],
credential_id: str,
credential_name: str | None,
):
@ -760,7 +763,7 @@ class ProviderConfiguration(BaseModel):
def _get_specific_custom_model_credential(
self, model_type: ModelType, model: str, credential_id: str
) -> dict | None:
) -> dict[str, Any] | None:
"""
Get a specific provider credential by ID.
:param credential_id: Credential ID
@ -832,7 +835,9 @@ class ProviderConfiguration(BaseModel):
stmt = stmt.where(ProviderModelCredential.id != exclude_id)
return session.execute(stmt).scalar_one_or_none() is not None
def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None:
def get_custom_model_credential(
self, model_type: ModelType, model: str, credential_id: str | None
) -> dict[str, Any] | None:
"""
Get custom model credentials.
@ -872,7 +877,7 @@ class ProviderConfiguration(BaseModel):
self,
model_type: ModelType,
model: str,
credentials: dict,
credentials: dict[str, Any],
credential_id: str = "",
session: Session | None = None,
):
@ -939,7 +944,7 @@ class ProviderConfiguration(BaseModel):
return _validate(new_session)
def create_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
) -> None:
"""
Create a custom model credential.
@ -1002,7 +1007,12 @@ class ProviderConfiguration(BaseModel):
raise
def update_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
self,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
credential_name: str | None,
credential_id: str,
) -> None:
"""
Update a custom model credential.
@ -1412,7 +1422,9 @@ class ProviderConfiguration(BaseModel):
# Get model instance of LLM
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None:
def get_model_schema(
self, model_type: ModelType, model: str, credentials: dict[str, Any] | None
) -> AIModelEntity | None:
"""
Get model schema
"""
@ -1471,7 +1483,7 @@ class ProviderConfiguration(BaseModel):
return secret_input_form_variables
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
def obfuscated_credentials(self, credentials: dict[str, Any], credential_form_schemas: list[CredentialFormSchema]):
"""
Obfuscated credentials.

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from enum import StrEnum, auto
from typing import Union
from typing import Any, Union
from graphon.model_runtime.entities.model_entities import ModelType
from pydantic import BaseModel, ConfigDict, Field
@ -88,7 +88,7 @@ class SystemConfiguration(BaseModel):
enabled: bool
current_quota_type: ProviderQuotaType | None = None
quota_configurations: list[QuotaConfiguration] = []
credentials: dict | None = None
credentials: dict[str, Any] | None = None
class CustomProviderConfiguration(BaseModel):
@ -96,7 +96,7 @@ class CustomProviderConfiguration(BaseModel):
Model class for provider custom configuration.
"""
credentials: dict
credentials: dict[str, Any]
current_credential_id: str | None = None
current_credential_name: str | None = None
available_credentials: list[CredentialConfiguration] = []
@ -109,7 +109,7 @@ class CustomModelConfiguration(BaseModel):
model: str
model_type: ModelType
credentials: dict | None
credentials: dict[str, Any] | None
current_credential_id: str | None = None
current_credential_name: str | None = None
available_model_credentials: list[CredentialConfiguration] = []

View File

@ -115,7 +115,7 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict | None = None,
model_parameters: dict[str, Any] | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[True] = True,
@ -126,7 +126,7 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict | None = None,
model_parameters: dict[str, Any] | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[False] = False,
@ -137,7 +137,7 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict | None = None,
model_parameters: dict[str, Any] | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
@ -147,7 +147,7 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict | None = None,
model_parameters: dict[str, Any] | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
@ -528,7 +528,7 @@ class LBModelManager:
model_type: ModelType,
model: str,
load_balancing_configs: list[ModelLoadBalancingConfiguration],
managed_credentials: dict | None = None,
managed_credentials: dict[str, Any] | None = None,
):
"""
Load balancing model manager

View File

@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, Field
from sqlalchemy import select
@ -10,7 +12,7 @@ from models.api_based_extension import APIBasedExtension
class ModerationInputParams(BaseModel):
app_id: str = ""
inputs: dict = Field(default_factory=dict)
inputs: dict[str, Any] = Field(default_factory=dict)
query: str = ""
@ -23,7 +25,7 @@ class ApiModeration(Moderation):
name: str = "api"
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -41,7 +43,7 @@ class ApiModeration(Moderation):
if not extension:
raise ValueError("API-based Extension not found. Please check it again.")
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config is None:
@ -73,7 +75,7 @@ class ApiModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict):
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict[str, Any]):
if self.config is None:
raise ValueError("The config is not set.")
extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", ""))

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, Field
@ -15,7 +16,7 @@ class ModerationInputsResult(BaseModel):
flagged: bool = False
action: ModerationAction
preset_response: str = ""
inputs: dict = Field(default_factory=dict)
inputs: dict[str, Any] = Field(default_factory=dict)
query: str = ""
@ -33,13 +34,13 @@ class Moderation(Extensible, ABC):
module: ExtensionModule = ExtensionModule.MODERATION
def __init__(self, app_id: str, tenant_id: str, config: dict | None = None):
def __init__(self, app_id: str, tenant_id: str, config: dict[str, Any] | None = None):
super().__init__(tenant_id, config)
self.app_id = app_id
@classmethod
@abstractmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
def validate_config(cls, tenant_id: str, config: dict[str, Any]) -> None:
"""
Validate the incoming form config data.
@ -50,7 +51,7 @@ class Moderation(Extensible, ABC):
raise NotImplementedError
@abstractmethod
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
"""
Moderation for inputs.
After the user inputs, this method will be called to perform sensitive content review
@ -75,7 +76,7 @@ class Moderation(Extensible, ABC):
raise NotImplementedError
@classmethod
def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool):
def _validate_inputs_and_outputs_config(cls, config: dict[str, Any], is_preset_response_required: bool):
# inputs_config
inputs_config = config.get("inputs_config")
if not isinstance(inputs_config, dict):

View File

@ -1,3 +1,5 @@
from typing import Any
from core.extension.extensible import ExtensionModule
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult
from extensions.ext_code_based_extension import code_based_extension
@ -6,12 +8,12 @@ from extensions.ext_code_based_extension import code_based_extension
class ModerationFactory:
__extension_instance: Moderation
def __init__(self, name: str, app_id: str, tenant_id: str, config: dict):
def __init__(self, name: str, app_id: str, tenant_id: str, config: dict[str, Any]):
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
self.__extension_instance = extension_class(app_id, tenant_id, config)
@classmethod
def validate_config(cls, name: str, tenant_id: str, config: dict):
def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -24,7 +26,7 @@ class ModerationFactory:
# FIXME: mypy error, try to fix it instead of using type: ignore
extension_class.validate_config(tenant_id, config) # type: ignore
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
"""
Moderation for inputs.
After the user inputs, this method will be called to perform sensitive content review

View File

@ -8,7 +8,7 @@ class KeywordsModeration(Moderation):
name: str = "keywords"
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -28,7 +28,7 @@ class KeywordsModeration(Moderation):
if len(keywords_row_len) > 100:
raise ValueError("the number of rows for the keywords must be less than 100")
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config is None:
@ -66,7 +66,7 @@ class KeywordsModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
def _is_violated(self, inputs: dict[str, Any], keywords_list: list[str]) -> bool:
return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())
def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool:

View File

@ -1,3 +1,5 @@
from typing import Any
from graphon.model_runtime.entities.model_entities import ModelType
from core.model_manager import ModelManager
@ -8,7 +10,7 @@ class OpenAIModeration(Moderation):
name: str = "openai_moderation"
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -18,7 +20,7 @@ class OpenAIModeration(Moderation):
"""
cls._validate_inputs_and_outputs_config(config, True)
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config is None:
@ -49,7 +51,7 @@ class OpenAIModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
def _is_violated(self, inputs: dict):
def _is_violated(self, inputs: dict[str, Any]):
text = "\n".join(str(inputs.values()))
model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id)
model_instance = model_manager.get_model_instance(

View File

@ -778,7 +778,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
logger.info("[Arize/Phoenix] Failed to construct project URL: %s", str(e), exc_info=True)
raise ValueError(f"[Arize/Phoenix] Failed to construct project URL: {str(e)}")
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
def _construct_llm_attributes(self, prompts: dict[str, Any] | list[Any] | str | None) -> dict[str, str]:
"""Construct LLM attributes with passed prompts for Arize/Phoenix."""
attributes: dict[str, str] = {}
@ -797,7 +797,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
path = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{key}"
set_attribute(path, value)
def set_tool_call_attributes(message_index: int, tool_index: int, tool_call: dict | object | None) -> None:
def set_tool_call_attributes(
message_index: int, tool_index: int, tool_call: dict[str, Any] | object | None
) -> None:
"""Extract and assign tool call details safely."""
if not tool_call:
return

View File

@ -242,7 +242,7 @@ class MLflowDataTrace(BaseTraceInstance):
return inputs, attributes
def _parse_knowledge_retrieval_outputs(self, outputs: dict):
def _parse_knowledge_retrieval_outputs(self, outputs: dict[str, Any]):
"""Parse KR outputs and attributes from KR workflow node"""
retrieved = outputs.get("result", [])
@ -319,7 +319,7 @@ class MLflowDataTrace(BaseTraceInstance):
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
)
def _get_message_user_id(self, metadata: dict) -> str | None:
def _get_message_user_id(self, metadata: dict[str, Any]) -> str | None:
if (end_user_id := metadata.get("from_end_user_id")) and (
end_user_data := db.session.get(EndUser, end_user_id)
):
@ -468,7 +468,7 @@ class MLflowDataTrace(BaseTraceInstance):
}
return node_type_mapping.get(node_type, "CHAIN") # type: ignore[arg-type,call-overload]
def _set_trace_metadata(self, span: Span, metadata: dict):
def _set_trace_metadata(self, span: Span, metadata: dict[str, Any]):
token = None
try:
# NB: Set span in context such that we can use update_current_trace() API
@ -490,7 +490,7 @@ class MLflowDataTrace(BaseTraceInstance):
return messages
return prompts # Fallback to original format
def _parse_single_message(self, item: dict):
def _parse_single_message(self, item: dict[str, Any]):
"""Postprocess single message format to be standard chat message"""
role = item.get("role", "user")
msg = {"role": role, "content": item.get("text", "")}

View File

@ -3,7 +3,7 @@ import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import cast
from typing import Any, cast
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from opik import Opik, Trace
@ -436,7 +436,7 @@ class OpikDataTrace(BaseTraceInstance):
self.add_span(span_data)
def add_trace(self, opik_trace_data: dict) -> Trace:
def add_trace(self, opik_trace_data: dict[str, Any]) -> Trace:
try:
trace = self.opik_client.trace(**opik_trace_data)
logger.debug("Opik Trace created successfully")
@ -444,7 +444,7 @@ class OpikDataTrace(BaseTraceInstance):
except Exception as e:
raise ValueError(f"Opik Failed to create trace: {str(e)}")
def add_span(self, opik_span_data: dict):
def add_span(self, opik_span_data: dict[str, Any]):
try:
self.opik_client.span(**opik_span_data)
logger.debug("Opik Span created successfully")

View File

@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field, model_validator
@ -31,7 +32,7 @@ class EndpointEntity(BasePluginEntity):
entity of an endpoint
"""
settings: dict
settings: dict[str, Any]
tenant_id: str
plugin_id: str
expired_at: datetime

View File

@ -1,3 +1,5 @@
from typing import Any
from graphon.model_runtime.entities.provider_entities import ProviderEntity
from pydantic import BaseModel, Field, computed_field, model_validator
@ -40,7 +42,7 @@ class MarketplacePluginDeclaration(BaseModel):
@model_validator(mode="before")
@classmethod
def transform_declaration(cls, data: dict):
def transform_declaration(cls, data: dict[str, Any]) -> dict[str, Any]:
if "endpoint" in data and not data["endpoint"]:
del data["endpoint"]
if "model" in data and not data["model"]:

View File

@ -123,7 +123,7 @@ class PluginDeclaration(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_category(cls, values: dict):
def validate_category(cls, values: dict[str, Any]) -> dict[str, Any]:
# auto detect category
if values.get("tool"):
values["category"] = PluginCategory.Tool

View File

@ -73,7 +73,7 @@ class PluginBasicBooleanResponse(BaseModel):
"""
result: bool
credentials: dict | None = None
credentials: dict[str, Any] | None = None
class PluginModelSchemaEntity(BaseModel):

View File

@ -49,7 +49,7 @@ class RequestInvokeTool(BaseModel):
tool_type: Literal["builtin", "workflow", "api", "mcp"]
provider: str
tool: str
tool_parameters: dict
tool_parameters: dict[str, Any]
credential_id: str | None = None
@ -209,7 +209,7 @@ class RequestInvokeEncrypt(BaseModel):
opt: Literal["encrypt", "decrypt", "clear"]
namespace: Literal["endpoint"]
identity: str
data: dict = Field(default_factory=dict)
data: dict[str, Any] = Field(default_factory=dict)
config: list[BasicProviderConfig] = Field(default_factory=list)

View File

@ -26,7 +26,7 @@ class PluginDatasourceManager(BasePluginClient):
Fetch datasource providers for the given tenant.
"""
def transformer(json_response: dict[str, Any]) -> dict:
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
if json_response.get("data"):
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
@ -68,7 +68,7 @@ class PluginDatasourceManager(BasePluginClient):
Fetch datasource providers for the given tenant.
"""
def transformer(json_response: dict[str, Any]) -> dict:
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
if json_response.get("data"):
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
@ -110,7 +110,7 @@ class PluginDatasourceManager(BasePluginClient):
tool_provider_id = DatasourceProviderID(provider_id)
def transformer(json_response: dict[str, Any]) -> dict:
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
data = json_response.get("data")
if data:
for datasource in data.get("declaration", {}).get("datasources", []):

View File

@ -1,3 +1,5 @@
from typing import Any
from core.plugin.entities.endpoint import EndpointEntityWithInstance
from core.plugin.impl.base import BasePluginClient
from core.plugin.impl.exc import PluginDaemonInternalServerError
@ -5,7 +7,12 @@ from core.plugin.impl.exc import PluginDaemonInternalServerError
class PluginEndpointClient(BasePluginClient):
def create_endpoint(
self, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict
self,
tenant_id: str,
user_id: str,
plugin_unique_identifier: str,
name: str,
settings: dict[str, Any],
) -> bool:
"""
Create an endpoint for the given plugin.
@ -49,7 +56,9 @@ class PluginEndpointClient(BasePluginClient):
params={"plugin_id": plugin_id, "page": page, "page_size": page_size},
)
def update_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict):
def update_endpoint(
self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict[str, Any]
) -> bool:
"""
Update the settings of the given endpoint.
"""

View File

@ -50,7 +50,7 @@ class PluginModelClient(BasePluginClient):
provider: str,
model_type: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
) -> AIModelEntity | None:
"""
Get model schema
@ -80,7 +80,7 @@ class PluginModelClient(BasePluginClient):
return None
def validate_provider_credentials(
self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict
self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict[str, Any]
) -> bool:
"""
validate the credentials of the provider
@ -118,7 +118,7 @@ class PluginModelClient(BasePluginClient):
provider: str,
model_type: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
) -> bool:
"""
validate the credentials of the provider
@ -157,9 +157,9 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
prompt_messages: list[PromptMessage],
model_parameters: dict | None = None,
model_parameters: dict[str, Any] | None = None,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
@ -206,7 +206,7 @@ class PluginModelClient(BasePluginClient):
provider: str,
model_type: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
) -> int:
@ -248,7 +248,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
texts: list[str],
input_type: str,
) -> EmbeddingResult:
@ -290,7 +290,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
documents: list[dict],
input_type: str,
) -> EmbeddingResult:
@ -332,7 +332,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
texts: list[str],
) -> list[int]:
"""
@ -372,7 +372,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
query: str,
docs: list[str],
score_threshold: float | None = None,
@ -418,7 +418,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None = None,
@ -463,7 +463,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
content_text: str,
voice: str,
) -> Generator[bytes, None, None]:
@ -508,7 +508,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
language: str | None = None,
):
"""
@ -552,7 +552,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
file: IO[bytes],
) -> str:
"""
@ -592,7 +592,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
text: str,
) -> bool:
"""

View File

@ -1,4 +1,5 @@
from collections.abc import Sequence
from typing import Any
from requests import HTTPError
@ -263,7 +264,7 @@ class PluginInstaller(BasePluginClient):
original_plugin_unique_identifier: str,
new_plugin_unique_identifier: str,
source: PluginInstallationSource,
meta: dict,
meta: dict[str, Any],
) -> PluginInstallTaskStartResponse:
"""
Upgrade a plugin.

View File

@ -875,7 +875,11 @@ class DatasetRetrieval:
return retrieval_resource_list
def _on_retrieval_end(
self, flask_app: Flask, documents: list[Document], message_id: str | None = None, timer: dict | None = None
self,
flask_app: Flask,
documents: list[Document],
message_id: str | None = None,
timer: dict[str, Any] | None = None,
):
"""Handle retrieval end."""
with flask_app.app_context():
@ -980,7 +984,7 @@ class DatasetRetrieval:
self._send_trace_task(message_id, documents, timer)
def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict | None):
def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict[str, Any] | None):
"""Send trace task if trace manager is available."""
trace_manager: TraceQueueManager | None = (
self.application_generate_entity.trace_manager if self.application_generate_entity else None
@ -1142,7 +1146,7 @@ class DatasetRetrieval:
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler,
user_id: str,
inputs: dict,
inputs: dict[str, Any],
) -> list[DatasetRetrieverBaseTool] | None:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
@ -1337,7 +1341,7 @@ class DatasetRetrieval:
metadata_filtering_mode: str,
metadata_model_config: ModelConfig,
metadata_filtering_conditions: MetadataFilteringCondition | None,
inputs: dict,
inputs: dict[str, Any],
) -> tuple[dict[str, list[str]] | None, MetadataFilteringCondition | None]:
document_query = select(DatasetDocument).where(
DatasetDocument.dataset_id.in_(dataset_ids),
@ -1417,7 +1421,7 @@ class DatasetRetrieval:
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
return metadata_filter_document_ids, metadata_condition
def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
def _replace_metadata_filter_value(self, text: str, inputs: dict[str, Any]) -> str:
if not inputs:
return text

View File

@ -233,7 +233,7 @@ class DatasetService:
embedding_model_provider: str | None = None,
embedding_model_name: str | None = None,
retrieval_model: RetrievalModel | None = None,
summary_index_setting: dict | None = None,
summary_index_setting: dict[str, Any] | None = None,
):
# check if dataset name already exists
if db.session.scalar(select(Dataset).where(Dataset.name == name, Dataset.tenant_id == tenant_id).limit(1)):
@ -2493,7 +2493,7 @@ class DocumentService:
data_source_type: str,
document_form: str,
document_language: str,
data_source_info: dict,
data_source_info: dict[str, Any],
created_from: str,
position: int,
account: Account,
@ -2850,7 +2850,7 @@ class DocumentService:
raise ValueError("Process rule segmentation max_tokens is invalid")
@classmethod
def estimate_args_validate(cls, args: dict):
def estimate_args_validate(cls, args: dict[str, Any]):
if "info_list" not in args or not args["info_list"]:
raise ValueError("Data source info is required")
@ -3132,7 +3132,7 @@ class DocumentService:
class SegmentService:
@classmethod
def segment_create_args_validate(cls, args: dict, document: Document):
def segment_create_args_validate(cls, args: dict[str, Any], document: Document):
if document.doc_form == IndexStructureType.QA_INDEX:
if "answer" not in args or not args["answer"]:
raise ValueError("Answer is required")
@ -3149,7 +3149,7 @@ class SegmentService:
raise ValueError(f"Exceeded maximum attachment limit of {single_chunk_attachment_limit}")
@classmethod
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
def create_segment(cls, args: dict[str, Any], document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None

View File

@ -47,7 +47,7 @@ class ExternalDatasetService:
return external_knowledge_apis.items, external_knowledge_apis.total
@classmethod
def validate_api_list(cls, api_settings: dict):
def validate_api_list(cls, api_settings: dict[str, Any]):
if not api_settings:
raise ValueError("api list is empty")
if not api_settings.get("endpoint"):
@ -56,7 +56,7 @@ class ExternalDatasetService:
raise ValueError("api_key is required")
@staticmethod
def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict) -> ExternalKnowledgeApis:
def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict[str, Any]) -> ExternalKnowledgeApis:
settings = args.get("settings")
if settings is None:
raise ValueError("settings is required")
@ -75,7 +75,7 @@ class ExternalDatasetService:
return external_knowledge_api
@staticmethod
def check_endpoint_and_api_key(settings: dict):
def check_endpoint_and_api_key(settings: dict[str, Any]):
if "endpoint" not in settings or not settings["endpoint"]:
raise ValueError("endpoint is required")
if "api_key" not in settings or not settings["api_key"]:
@ -178,7 +178,9 @@ class ExternalDatasetService:
return external_knowledge_binding
@staticmethod
def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
def document_create_args_validate(
tenant_id: str, external_knowledge_api_id: str, process_parameter: dict[str, Any]
):
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
@ -222,7 +224,7 @@ class ExternalDatasetService:
return response
@staticmethod
def assembling_headers(authorization: Authorization, headers: dict | None = None) -> dict[str, Any]:
def assembling_headers(authorization: Authorization, headers: dict[str, Any] | None = None) -> dict[str, Any]:
authorization = deepcopy(authorization)
if headers:
headers = deepcopy(headers)
@ -248,11 +250,11 @@ class ExternalDatasetService:
return headers
@staticmethod
def get_external_knowledge_api_settings(settings: dict) -> ExternalKnowledgeApiSetting:
def get_external_knowledge_api_settings(settings: dict[str, Any]) -> ExternalKnowledgeApiSetting:
return ExternalKnowledgeApiSetting.model_validate(settings)
@staticmethod
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
def create_external_dataset(tenant_id: str, user_id: str, args: dict[str, Any]) -> Dataset:
# check if dataset name already exists
if db.session.scalar(
select(Dataset).where(Dataset.name == args.get("name"), Dataset.tenant_id == tenant_id).limit(1)
@ -304,7 +306,7 @@ class ExternalDatasetService:
tenant_id: str,
dataset_id: str,
query: str,
external_retrieval_parameters: dict,
external_retrieval_parameters: dict[str, Any],
metadata_condition: MetadataFilteringCondition | None = None,
):
external_knowledge_binding = db.session.scalar(

View File

@ -92,7 +92,7 @@ class ApiToolManageService:
@staticmethod
def convert_schema_to_tool_bundles(
schema: str, extra_info: dict | None = None
schema: str, extra_info: dict[str, Any] | None = None
) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
"""
convert schema to tool bundles
@ -109,8 +109,8 @@ class ApiToolManageService:
user_id: str,
tenant_id: str,
provider_name: str,
icon: dict,
credentials: dict,
icon: dict[str, Any],
credentials: dict[str, Any],
schema_type: ApiProviderSchemaType,
schema: str,
privacy_policy: str,
@ -244,8 +244,8 @@ class ApiToolManageService:
tenant_id: str,
provider_name: str,
original_provider: str,
icon: dict,
credentials: dict,
icon: dict[str, Any],
credentials: dict[str, Any],
_schema_type: ApiProviderSchemaType,
schema: str,
privacy_policy: str | None,
@ -356,8 +356,8 @@ class ApiToolManageService:
tenant_id: str,
provider_name: str,
tool_name: str,
credentials: dict,
parameters: dict,
credentials: dict[str, Any],
parameters: dict[str, Any],
schema_type: ApiProviderSchemaType,
schema: str,
):

View File

@ -147,7 +147,7 @@ class BuiltinToolManageService:
tenant_id: str,
provider: str,
credential_id: str,
credentials: dict | None = None,
credentials: dict[str, Any] | None = None,
name: str | None = None,
):
"""
@ -177,7 +177,7 @@ class BuiltinToolManageService:
)
original_credentials = encrypter.decrypt(db_provider.credentials)
new_credentials: dict = {
new_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
@ -216,7 +216,7 @@ class BuiltinToolManageService:
api_type: CredentialType,
tenant_id: str,
provider: str,
credentials: dict,
credentials: dict[str, Any],
expires_at: int = -1,
name: str | None = None,
):
@ -657,7 +657,7 @@ class BuiltinToolManageService:
def save_custom_oauth_client_params(
tenant_id: str,
provider: str,
client_params: dict | None = None,
client_params: dict[str, Any] | None = None,
enable_oauth_custom_client: bool | None = None,
):
"""

View File

@ -69,7 +69,9 @@ class ToolTransformService:
return ""
@staticmethod
def repack_provider(tenant_id: str, provider: dict | ToolProviderApiEntity | PluginDatasourceProviderEntity):
def repack_provider(
tenant_id: str, provider: dict[str, Any] | ToolProviderApiEntity | PluginDatasourceProviderEntity
):
"""
repack provider

View File

@ -1,6 +1,7 @@
import json
import logging
from datetime import datetime
from typing import Any
from graphon.model_runtime.utils.encoders import jsonable_encoder
from sqlalchemy import delete, or_, select
@ -35,7 +36,7 @@ class WorkflowToolManageService:
workflow_app_id: str,
name: str,
label: str,
icon: dict,
icon: dict[str, Any],
description: str,
parameters: list[WorkflowToolParameterConfiguration],
privacy_policy: str = "",
@ -117,7 +118,7 @@ class WorkflowToolManageService:
workflow_tool_id: str,
name: str,
label: str,
icon: dict,
icon: dict[str, Any],
description: str,
parameters: list[WorkflowToolParameterConfiguration],
privacy_policy: str = "",

View File

@ -91,7 +91,7 @@ class WebsiteCrawlApiRequest:
return CrawlRequest(url=self.url, provider=self.provider, options=options)
@classmethod
def from_args(cls, args: dict) -> WebsiteCrawlApiRequest:
def from_args(cls, args: dict[str, Any]) -> WebsiteCrawlApiRequest:
"""Create from Flask-RESTful parsed arguments."""
provider = args.get("provider")
url = args.get("url")
@ -115,7 +115,7 @@ class WebsiteCrawlStatusApiRequest:
job_id: str
@classmethod
def from_args(cls, args: dict, job_id: str) -> WebsiteCrawlStatusApiRequest:
def from_args(cls, args: dict[str, Any], job_id: str) -> WebsiteCrawlStatusApiRequest:
"""Create from Flask-RESTful parsed arguments."""
provider = args.get("provider")
if not provider:
@ -163,7 +163,7 @@ class WebsiteService:
raise ValueError("Invalid provider")
@classmethod
def _get_decrypted_api_key(cls, tenant_id: str, config: dict) -> str:
def _get_decrypted_api_key(cls, tenant_id: str, config: dict[str, Any]) -> str:
"""Decrypt and return the API key from config."""
api_key = config.get("api_key")
if not api_key:
@ -171,7 +171,7 @@ class WebsiteService:
return encrypter.decrypt_token(tenant_id=tenant_id, token=api_key)
@classmethod
def document_create_args_validate(cls, args: dict):
def document_create_args_validate(cls, args: dict[str, Any]):
"""Validate arguments for document creation."""
try:
WebsiteCrawlApiRequest.from_args(args)
@ -195,7 +195,7 @@ class WebsiteService:
raise ValueError("Invalid provider")
@classmethod
def _crawl_with_firecrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]:
def _crawl_with_firecrawl(cls, request: CrawlRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
params: dict[str, Any]
@ -225,7 +225,7 @@ class WebsiteService:
return {"status": "active", "job_id": job_id}
@classmethod
def _crawl_with_watercrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]:
def _crawl_with_watercrawl(cls, request: CrawlRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
# Convert CrawlOptions back to dict format for WaterCrawlProvider
options = {
"limit": request.options.limit,
@ -290,7 +290,7 @@ class WebsiteService:
raise ValueError("Invalid provider")
@classmethod
def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> CrawlStatusDict:
def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict[str, Any]) -> CrawlStatusDict:
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
result: CrawlStatusResponse = firecrawl_app.check_crawl_status(job_id)
crawl_status_data: CrawlStatusDict = {
@ -364,7 +364,9 @@ class WebsiteService:
raise ValueError("Invalid provider")
@classmethod
def _get_firecrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None:
def _get_firecrawl_url_data(
cls, job_id: str, url: str, api_key: str, config: dict[str, Any]
) -> dict[str, Any] | None:
crawl_data: list[FirecrawlDocumentData] | None = None
file_key = "website_files/" + job_id + ".txt"
if storage.exists(file_key):
@ -438,7 +440,7 @@ class WebsiteService:
raise ValueError("Invalid provider")
@classmethod
def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
params = {"onlyMainContent": request.only_main_content}
return dict(firecrawl_app.scrape_url(url=request.url, params=params))

View File

@ -11,6 +11,7 @@ from uuid import uuid4
import pytest
from faker import Faker
from sqlalchemy import delete
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.ext_redis import redis_client
@ -28,12 +29,12 @@ class TestCreateSegmentToIndexTask:
"""Clean up database and Redis before each test to ensure isolation."""
# Clear all test data using fixture session
db_session_with_containers.query(DocumentSegment).delete()
db_session_with_containers.query(Document).delete()
db_session_with_containers.query(Dataset).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.execute(delete(DocumentSegment))
db_session_with_containers.execute(delete(Document))
db_session_with_containers.execute(delete(Dataset))
db_session_with_containers.execute(delete(TenantAccountJoin))
db_session_with_containers.execute(delete(Tenant))
db_session_with_containers.execute(delete(Account))
db_session_with_containers.commit()
# Clear Redis cache

View File

@ -14,6 +14,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import delete
from libs.email_i18n import EmailType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -41,9 +42,9 @@ class TestSendEmailCodeLoginMailTask:
from extensions.ext_redis import redis_client
# Clear all test data
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.execute(delete(TenantAccountJoin))
db_session_with_containers.execute(delete(Tenant))
db_session_with_containers.execute(delete(Account))
db_session_with_containers.commit()
# Clear Redis cache

View File

@ -6,6 +6,7 @@ import pytest
from graphon.enums import WorkflowExecutionStatus
from graphon.nodes.human_input.entities import HumanInputNodeData
from graphon.runtime import GraphRuntimeState, VariablePool
from sqlalchemy import delete
from configs import dify_config
from core.app.app_config.entities import WorkflowUIBasedAppConfig
@ -30,14 +31,14 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
@pytest.fixture(autouse=True)
def cleanup_database(db_session_with_containers):
db_session_with_containers.query(HumanInputFormRecipient).delete()
db_session_with_containers.query(HumanInputDelivery).delete()
db_session_with_containers.query(HumanInputForm).delete()
db_session_with_containers.query(WorkflowPause).delete()
db_session_with_containers.query(WorkflowRun).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.execute(delete(HumanInputFormRecipient))
db_session_with_containers.execute(delete(HumanInputDelivery))
db_session_with_containers.execute(delete(HumanInputForm))
db_session_with_containers.execute(delete(WorkflowPause))
db_session_with_containers.execute(delete(WorkflowRun))
db_session_with_containers.execute(delete(TenantAccountJoin))
db_session_with_containers.execute(delete(Tenant))
db_session_with_containers.execute(delete(Account))
db_session_with_containers.commit()

View File

@ -17,6 +17,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import delete, select
from extensions.ext_redis import redis_client
from libs.email_i18n import EmailType
@ -44,9 +45,9 @@ class TestMailInviteMemberTask:
def cleanup_database(self, db_session_with_containers):
"""Clean up database before each test to ensure isolation."""
# Clear all test data
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.execute(delete(TenantAccountJoin))
db_session_with_containers.execute(delete(Tenant))
db_session_with_containers.execute(delete(Account))
db_session_with_containers.commit()
# Clear Redis cache
@ -491,10 +492,10 @@ class TestMailInviteMemberTask:
assert tenant.name is not None
# Verify tenant relationship exists
tenant_join = (
db_session_with_containers.query(TenantAccountJoin)
.filter_by(tenant_id=tenant.id, account_id=pending_account.id)
.first()
tenant_join = db_session_with_containers.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == pending_account.id)
.limit(1)
)
assert tenant_join is not None
assert tenant_join.role == TenantAccountRole.NORMAL

View File

@ -4,6 +4,7 @@ from unittest.mock import ANY, call, patch
import pytest
from graphon.variables.segments import StringSegment
from graphon.variables.types import SegmentType
from sqlalchemy import delete, func, select
from core.db.session_factory import session_factory
from extensions.storage.storage_type import StorageType
@ -20,11 +21,11 @@ from tasks.remove_app_and_related_data_task import (
@pytest.fixture(autouse=True)
def cleanup_database(db_session_with_containers):
db_session_with_containers.query(WorkflowDraftVariable).delete()
db_session_with_containers.query(WorkflowDraftVariableFile).delete()
db_session_with_containers.query(UploadFile).delete()
db_session_with_containers.query(App).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.execute(delete(WorkflowDraftVariable))
db_session_with_containers.execute(delete(WorkflowDraftVariableFile))
db_session_with_containers.execute(delete(UploadFile))
db_session_with_containers.execute(delete(App))
db_session_with_containers.execute(delete(Tenant))
db_session_with_containers.commit()
@ -127,21 +128,21 @@ class TestDeleteDraftVariablesBatch:
result = delete_draft_variables_batch(app1.id, batch_size=100)
assert result == 150
app1_remaining = db_session_with_containers.query(WorkflowDraftVariable).where(
WorkflowDraftVariable.app_id == app1.id
app1_remaining_count = db_session_with_containers.scalar(
select(func.count()).select_from(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app1.id)
)
app2_remaining = db_session_with_containers.query(WorkflowDraftVariable).where(
WorkflowDraftVariable.app_id == app2.id
app2_remaining_count = db_session_with_containers.scalar(
select(func.count()).select_from(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app2.id)
)
assert app1_remaining.count() == 0
assert app2_remaining.count() == 100
assert app1_remaining_count == 0
assert app2_remaining_count == 100
def test_delete_draft_variables_batch_empty_result(self, db_session_with_containers):
"""Test deletion when no draft variables exist for the app."""
result = delete_draft_variables_batch(str(uuid.uuid4()), 1000)
assert result == 0
assert db_session_with_containers.query(WorkflowDraftVariable).count() == 0
assert db_session_with_containers.scalar(select(func.count()).select_from(WorkflowDraftVariable)) == 0
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
@patch("tasks.remove_app_and_related_data_task.logger")
@ -190,12 +191,16 @@ class TestDeleteDraftVariableOffloadData:
expected_storage_calls = [call(storage_key) for storage_key in upload_file_keys]
mock_storage.delete.assert_has_calls(expected_storage_calls, any_order=True)
remaining_var_files = db_session_with_containers.query(WorkflowDraftVariableFile).where(
WorkflowDraftVariableFile.id.in_(file_ids)
remaining_var_files_count = db_session_with_containers.scalar(
select(func.count())
.select_from(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(file_ids))
)
remaining_upload_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids))
assert remaining_var_files.count() == 0
assert remaining_upload_files.count() == 0
remaining_upload_files_count = db_session_with_containers.scalar(
select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids))
)
assert remaining_var_files_count == 0
assert remaining_upload_files_count == 0
@patch("extensions.ext_storage.storage")
@patch("tasks.remove_app_and_related_data_task.logging")
@ -217,9 +222,13 @@ class TestDeleteDraftVariableOffloadData:
assert result == 1
mock_logging.exception.assert_called_once_with("Failed to delete storage object %s", storage_keys[0])
remaining_var_files = db_session_with_containers.query(WorkflowDraftVariableFile).where(
WorkflowDraftVariableFile.id.in_(file_ids)
remaining_var_files_count = db_session_with_containers.scalar(
select(func.count())
.select_from(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(file_ids))
)
remaining_upload_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids))
assert remaining_var_files.count() == 0
assert remaining_upload_files.count() == 0
remaining_upload_files_count = db_session_with_containers.scalar(
select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids))
)
assert remaining_var_files_count == 0
assert remaining_upload_files_count == 0

View File

@ -11,6 +11,7 @@ from collections.abc import Generator
from typing import Any
import pytest
from sqlalchemy import delete
from sqlalchemy.orm import Session
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -40,9 +41,9 @@ def tenant_and_account(db_session_with_containers: Session) -> Generator[tuple[T
yield tenant, account
# Cleanup
db_session_with_containers.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
db_session_with_containers.query(Account).filter_by(id=account.id).delete()
db_session_with_containers.query(Tenant).filter_by(id=tenant.id).delete()
db_session_with_containers.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == tenant.id))
db_session_with_containers.execute(delete(Account).where(Account.id == account.id))
db_session_with_containers.execute(delete(Tenant).where(Tenant.id == tenant.id))
db_session_with_containers.commit()
@ -93,14 +94,14 @@ def app_model(
)
from models.workflow import Workflow
db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app.id).delete()
db_session_with_containers.query(WorkflowSchedulePlan).filter_by(app_id=app.id).delete()
db_session_with_containers.query(WorkflowWebhookTrigger).filter_by(app_id=app.id).delete()
db_session_with_containers.query(WorkflowPluginTrigger).filter_by(app_id=app.id).delete()
db_session_with_containers.query(AppTrigger).filter_by(app_id=app.id).delete()
db_session_with_containers.query(TriggerSubscription).filter_by(tenant_id=tenant.id).delete()
db_session_with_containers.query(Workflow).filter_by(app_id=app.id).delete()
db_session_with_containers.query(App).filter_by(id=app.id).delete()
db_session_with_containers.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.app_id == app.id))
db_session_with_containers.execute(delete(WorkflowSchedulePlan).where(WorkflowSchedulePlan.app_id == app.id))
db_session_with_containers.execute(delete(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.app_id == app.id))
db_session_with_containers.execute(delete(WorkflowPluginTrigger).where(WorkflowPluginTrigger.app_id == app.id))
db_session_with_containers.execute(delete(AppTrigger).where(AppTrigger.app_id == app.id))
db_session_with_containers.execute(delete(TriggerSubscription).where(TriggerSubscription.tenant_id == tenant.id))
db_session_with_containers.execute(delete(Workflow).where(Workflow.app_id == app.id))
db_session_with_containers.execute(delete(App).where(App.id == app.id))
db_session_with_containers.commit()

View File

@ -11,6 +11,7 @@ import pytest
from flask import Flask, Response
from flask.testing import FlaskClient
from graphon.enums import BuiltinNodeTypes
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
@ -227,7 +228,9 @@ def test_webhook_trigger_creates_trigger_log(
assert response.status_code == 200
db_session_with_containers.expire_all()
logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all()
logs = db_session_with_containers.scalars(
select(WorkflowTriggerLog).where(WorkflowTriggerLog.app_id == app_model.id)
).all()
assert logs, "Webhook trigger should create trigger log"
@ -611,7 +614,9 @@ def test_schedule_trigger_creates_trigger_log(
# Verify WorkflowTriggerLog was created
db_session_with_containers.expire_all()
logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all()
logs = db_session_with_containers.scalars(
select(WorkflowTriggerLog).where(WorkflowTriggerLog.app_id == app_model.id)
).all()
assert logs, "Schedule trigger should create WorkflowTriggerLog"
assert logs[0].trigger_type == AppTriggerType.TRIGGER_SCHEDULE
assert logs[0].root_node_id == schedule_node_id
@ -786,11 +791,12 @@ def test_plugin_trigger_full_chain_with_db_verification(
# Verify database records exist
db_session_with_containers.expire_all()
plugin_triggers = (
db_session_with_containers.query(WorkflowPluginTrigger)
.filter_by(app_id=app_model.id, node_id=plugin_node_id)
.all()
)
plugin_triggers = db_session_with_containers.scalars(
select(WorkflowPluginTrigger).where(
WorkflowPluginTrigger.app_id == app_model.id,
WorkflowPluginTrigger.node_id == plugin_node_id,
)
).all()
assert plugin_triggers, "WorkflowPluginTrigger record should exist"
assert plugin_triggers[0].provider_id == provider_id
assert plugin_triggers[0].event_name == "test_event"