mirror of
https://github.com/langgenius/dify.git
synced 2026-04-22 03:37:44 +08:00
Merge branch 'main' into 4-13-update-deps
This commit is contained in:
@ -1,3 +1,4 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import TypedDict
|
||||
|
||||
from flask import request
|
||||
@ -13,6 +14,14 @@ from services.billing_service import BillingService
|
||||
_FALLBACK_LANG = "en-US"
|
||||
|
||||
|
||||
class NotificationLangContent(TypedDict, total=False):
|
||||
lang: str
|
||||
title: str
|
||||
subtitle: str
|
||||
body: str
|
||||
titlePicUrl: str
|
||||
|
||||
|
||||
class NotificationItemDict(TypedDict):
|
||||
notification_id: str | None
|
||||
frequency: str | None
|
||||
@ -28,9 +37,11 @@ class NotificationResponseDict(TypedDict):
|
||||
notifications: list[NotificationItemDict]
|
||||
|
||||
|
||||
def _pick_lang_content(contents: dict, lang: str) -> dict:
|
||||
def _pick_lang_content(contents: Mapping[str, NotificationLangContent], lang: str) -> NotificationLangContent:
|
||||
"""Return the single LangContent for *lang*, falling back to English."""
|
||||
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
|
||||
return (
|
||||
contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), NotificationLangContent())
|
||||
)
|
||||
|
||||
|
||||
class DismissNotificationPayload(BaseModel):
|
||||
@ -71,7 +82,7 @@ class NotificationApi(Resource):
|
||||
|
||||
notifications: list[NotificationItemDict] = []
|
||||
for notification in result.get("notifications") or []:
|
||||
contents: dict = notification.get("contents") or {}
|
||||
contents: Mapping[str, NotificationLangContent] = notification.get("contents") or {}
|
||||
lang_content = _pick_lang_content(contents, lang)
|
||||
item: NotificationItemDict = {
|
||||
"notification_id": notification.get("notificationId"),
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
import pytz
|
||||
from flask import request
|
||||
@ -174,7 +174,7 @@ reg(CheckEmailUniquePayload)
|
||||
register_schema_models(console_ns, AccountResponse)
|
||||
|
||||
|
||||
def _serialize_account(account) -> dict:
|
||||
def _serialize_account(account) -> dict[str, Any]:
|
||||
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ from typing import Any, Union
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource
|
||||
from graphon.variables.input_entities import VariableEntity
|
||||
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
@ -158,14 +158,20 @@ class MCPAppApi(Resource):
|
||||
except ValidationError as e:
|
||||
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
|
||||
|
||||
def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]:
|
||||
def _convert_user_input_form(self, raw_form: list[dict[str, Any]]) -> list[VariableEntity]:
|
||||
"""Convert raw user input form to VariableEntity objects"""
|
||||
return [self._create_variable_entity(item) for item in raw_form]
|
||||
|
||||
def _create_variable_entity(self, item: dict) -> VariableEntity:
|
||||
def _create_variable_entity(self, item: dict[str, Any]) -> VariableEntity:
|
||||
"""Create a single VariableEntity from raw form item"""
|
||||
variable_type = item.get("type", "") or list(item.keys())[0]
|
||||
variable = item[variable_type]
|
||||
variable_type_raw: str = item.get("type", "") or list(item.keys())[0]
|
||||
try:
|
||||
variable_type = VariableEntityType(variable_type_raw)
|
||||
except ValueError as e:
|
||||
raise MCPRequestError(
|
||||
mcp_types.INVALID_PARAMS, f"Invalid user_input_form variable type: {variable_type_raw}"
|
||||
) from e
|
||||
variable = item[variable_type_raw]
|
||||
|
||||
return VariableEntity(
|
||||
type=variable_type,
|
||||
@ -178,7 +184,7 @@ class MCPAppApi(Resource):
|
||||
json_schema=variable.get("json_schema"),
|
||||
)
|
||||
|
||||
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
|
||||
def _parse_mcp_request(self, args: dict[str, Any]) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
|
||||
"""Parse and validate MCP request"""
|
||||
try:
|
||||
return mcp_types.ClientRequest.model_validate(args)
|
||||
|
||||
@ -33,25 +33,25 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
|
||||
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
|
||||
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict[str, Any]:
|
||||
"""Marshal a single segment and enrich it with summary content."""
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
|
||||
segment_dict["summary"] = summary.summary_content if summary else None
|
||||
return segment_dict
|
||||
|
||||
|
||||
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
|
||||
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict[str, Any]]:
|
||||
"""Marshal multiple segments and enrich them with summary content (batch query)."""
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
summaries: dict = {}
|
||||
summaries: dict[str, str | None] = {}
|
||||
if segment_ids:
|
||||
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
|
||||
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
|
||||
|
||||
result = []
|
||||
result: list[dict[str, Any]] = []
|
||||
for segment in segments:
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
segment_dict["summary"] = summaries.get(segment.id)
|
||||
result.append(segment_dict)
|
||||
return result
|
||||
|
||||
@ -5,6 +5,7 @@ Web App Human Input Form APIs.
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
@ -58,10 +59,19 @@ def _to_timestamp(value: datetime) -> int:
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
class FormDefinitionPayload(TypedDict):
|
||||
form_content: Any
|
||||
inputs: Any
|
||||
resolved_default_values: dict[str, str]
|
||||
user_actions: Any
|
||||
expiration_time: int
|
||||
site: NotRequired[dict]
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
|
||||
"""Return the form payload (optionally with site) as a JSON response."""
|
||||
definition_payload = form.get_definition().model_dump()
|
||||
payload = {
|
||||
payload: FormDefinitionPayload = {
|
||||
"form_content": definition_payload["rendered_content"],
|
||||
"inputs": definition_payload["inputs"],
|
||||
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource
|
||||
@ -103,21 +104,23 @@ class PassportResource(Resource):
|
||||
return response
|
||||
|
||||
|
||||
def decode_enterprise_webapp_user_id(jwt_token: str | None):
|
||||
def decode_enterprise_webapp_user_id(jwt_token: str | None) -> dict[str, Any] | None:
|
||||
"""
|
||||
Decode the enterprise user session from the Authorization header.
|
||||
"""
|
||||
if not jwt_token:
|
||||
return None
|
||||
|
||||
decoded = PassportService().verify(jwt_token)
|
||||
decoded: dict[str, Any] = PassportService().verify(jwt_token)
|
||||
source = decoded.get("token_source")
|
||||
if not source or source != "webapp_login_token":
|
||||
raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
|
||||
return decoded
|
||||
|
||||
|
||||
def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType):
|
||||
def exchange_token_for_existing_web_user(
|
||||
app_code: str, enterprise_user_decoded: dict[str, Any], auth_type: WebAppAuthType
|
||||
):
|
||||
"""
|
||||
Exchange a token for an existing web user session.
|
||||
"""
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
from flask_restx import fields, marshal, marshal_with
|
||||
from sqlalchemy import select
|
||||
@ -113,12 +113,12 @@ class AppSiteInfo:
|
||||
}
|
||||
|
||||
|
||||
def serialize_site(site: Site) -> dict:
|
||||
def serialize_site(site: Site) -> dict[str, Any]:
|
||||
"""Serialize Site model using the same schema as AppSiteApi."""
|
||||
return cast(dict, marshal(site, AppSiteApi.site_fields))
|
||||
return cast(dict[str, Any], marshal(site, AppSiteApi.site_fields))
|
||||
|
||||
|
||||
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict:
|
||||
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict[str, Any]:
|
||||
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
|
||||
app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
|
||||
return cast(dict, marshal(app_site_info, AppSiteApi.app_fields))
|
||||
return cast(dict[str, Any], marshal(app_site_info, AppSiteApi.app_fields))
|
||||
|
||||
@ -138,7 +138,9 @@ class DatasetConfigManager:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(
|
||||
cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]
|
||||
) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for dataset feature
|
||||
|
||||
@ -172,7 +174,7 @@ class DatasetConfigManager:
|
||||
return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]
|
||||
|
||||
@classmethod
|
||||
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict):
|
||||
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]):
|
||||
"""
|
||||
Extract dataset config for legacy compatibility
|
||||
|
||||
|
||||
@ -108,7 +108,7 @@ class ModelConfigManager:
|
||||
return dict(config), ["model"]
|
||||
|
||||
@classmethod
|
||||
def validate_model_completion_params(cls, cp: dict):
|
||||
def validate_model_completion_params(cls, cp: dict[str, Any]):
|
||||
# model.completion_params
|
||||
if not isinstance(cp, dict):
|
||||
raise ValueError("model.completion_params must be of object type")
|
||||
|
||||
@ -65,7 +65,7 @@ class PromptTemplateConfigManager:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate pre_prompt and set defaults for prompt feature
|
||||
depending on the config['model']
|
||||
@ -130,7 +130,7 @@ class PromptTemplateConfigManager:
|
||||
return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]
|
||||
|
||||
@classmethod
|
||||
def validate_post_prompt_and_set_defaults(cls, config: dict):
|
||||
def validate_post_prompt_and_set_defaults(cls, config: dict[str, Any]):
|
||||
"""
|
||||
Validate post_prompt and set defaults for prompt feature
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import re
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||
|
||||
@ -82,7 +82,7 @@ class BasicVariablesConfigManager:
|
||||
return variable_entities, external_data_variables
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, tenant_id: str, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for user input form
|
||||
|
||||
@ -99,7 +99,7 @@ class BasicVariablesConfigManager:
|
||||
return config, related_config_keys
|
||||
|
||||
@classmethod
|
||||
def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_variables_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for user input form
|
||||
|
||||
@ -164,7 +164,9 @@ class BasicVariablesConfigManager:
|
||||
return config, ["user_input_form"]
|
||||
|
||||
@classmethod
|
||||
def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_external_data_tools_and_set_defaults(
|
||||
cls, tenant_id: str, config: dict[str, Any]
|
||||
) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for external data fetch feature
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ class FileUploadConfigManager:
|
||||
return FileUploadConfig.model_validate(file_upload_dict)
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for file upload feature
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
|
||||
|
||||
@ -13,7 +15,7 @@ class AppConfigModel(BaseModel):
|
||||
|
||||
class MoreLikeThisConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> bool:
|
||||
def convert(cls, config: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@ -23,7 +25,7 @@ class MoreLikeThisConfigManager:
|
||||
return AppConfigModel.model_validate(validated_config).more_like_this.enabled
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
try:
|
||||
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
|
||||
except ValidationError:
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class OpeningStatementConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> tuple[str, list]:
|
||||
def convert(cls, config: dict[str, Any]) -> tuple[str, list[str]]:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@ -15,7 +18,7 @@ class OpeningStatementConfigManager:
|
||||
return opening_statement, suggested_questions_list
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for opening statement feature
|
||||
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class RetrievalResourceConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> bool:
|
||||
def convert(cls, config: dict[str, Any]) -> bool:
|
||||
show_retrieve_source = False
|
||||
retriever_resource_dict = config.get("retriever_resource")
|
||||
if retriever_resource_dict:
|
||||
@ -10,7 +13,7 @@ class RetrievalResourceConfigManager:
|
||||
return show_retrieve_source
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for retriever resource feature
|
||||
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class SpeechToTextConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> bool:
|
||||
def convert(cls, config: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@ -15,7 +18,7 @@ class SpeechToTextConfigManager:
|
||||
return speech_to_text
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for speech to text feature
|
||||
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class SuggestedQuestionsAfterAnswerConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> bool:
|
||||
def convert(cls, config: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@ -15,7 +18,7 @@ class SuggestedQuestionsAfterAnswerConfigManager:
|
||||
return suggested_questions_after_answer
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for suggested questions feature
|
||||
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import TextToSpeechEntity
|
||||
|
||||
|
||||
class TextToSpeechConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict):
|
||||
def convert(cls, config: dict[str, Any]):
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@ -22,7 +24,7 @@ class TextToSpeechConfigManager:
|
||||
return text_to_speech
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for text to speech feature
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
@ -17,7 +17,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = WorkflowAppBlockingResponse
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
@ -26,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return dict(blocking_response.model_dump())
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
|
||||
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
|
||||
@ -34,7 +36,9 @@ class PipelineConfigManager(BaseAppConfigManager):
|
||||
return pipeline_config
|
||||
|
||||
@classmethod
|
||||
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
|
||||
def config_validate(
|
||||
cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Validate for pipeline config
|
||||
|
||||
|
||||
@ -782,7 +782,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
user_id: str,
|
||||
all_files: list,
|
||||
datasource_info: Mapping[str, Any],
|
||||
next_page_parameters: dict | None = None,
|
||||
next_page_parameters: dict[str, Any] | None = None,
|
||||
):
|
||||
"""
|
||||
Get files in a folder.
|
||||
|
||||
@ -521,7 +521,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
|
||||
node_type: str
|
||||
title: str
|
||||
created_at: int
|
||||
extras: dict = Field(default_factory=dict)
|
||||
extras: dict[str, Any] = Field(default_factory=dict)
|
||||
metadata: Mapping = {}
|
||||
inputs: Mapping = {}
|
||||
inputs_truncated: bool = False
|
||||
@ -547,7 +547,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
|
||||
title: str
|
||||
index: int
|
||||
created_at: int
|
||||
extras: dict = Field(default_factory=dict)
|
||||
extras: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_NEXT
|
||||
workflow_run_id: str
|
||||
@ -571,7 +571,7 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
|
||||
outputs: Mapping | None = None
|
||||
outputs_truncated: bool = False
|
||||
created_at: int
|
||||
extras: dict | None = None
|
||||
extras: dict[str, Any] | None = None
|
||||
inputs: Mapping | None = None
|
||||
inputs_truncated: bool = False
|
||||
status: WorkflowNodeExecutionStatus
|
||||
@ -602,7 +602,7 @@ class LoopNodeStartStreamResponse(StreamResponse):
|
||||
node_type: str
|
||||
title: str
|
||||
created_at: int
|
||||
extras: dict = Field(default_factory=dict)
|
||||
extras: dict[str, Any] = Field(default_factory=dict)
|
||||
metadata: Mapping = {}
|
||||
inputs: Mapping = {}
|
||||
inputs_truncated: bool = False
|
||||
@ -653,7 +653,7 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
|
||||
outputs: Mapping | None = None
|
||||
outputs_truncated: bool = False
|
||||
created_at: int
|
||||
extras: dict | None = None
|
||||
extras: dict[str, Any] | None = None
|
||||
inputs: Mapping | None = None
|
||||
inputs_truncated: bool = False
|
||||
status: WorkflowNodeExecutionStatus
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
@ -37,7 +39,7 @@ class PipelineDocument(BaseModel):
|
||||
id: str
|
||||
position: int
|
||||
data_source_type: str
|
||||
data_source_info: dict | None = None
|
||||
data_source_info: dict[str, Any] | None = None
|
||||
name: str
|
||||
indexing_status: str
|
||||
error: str | None = None
|
||||
|
||||
@ -6,6 +6,7 @@ import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator, Sequence
|
||||
from json import JSONDecodeError
|
||||
from typing import Any
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import (
|
||||
@ -111,7 +112,7 @@ class ProviderConfiguration(BaseModel):
|
||||
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
|
||||
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
|
||||
|
||||
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
|
||||
def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get current credentials.
|
||||
|
||||
@ -233,7 +234,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
return session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
def _get_specific_provider_credential(self, credential_id: str) -> dict | None:
|
||||
def _get_specific_provider_credential(self, credential_id: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get a specific provider credential by ID.
|
||||
:param credential_id: Credential ID
|
||||
@ -297,7 +298,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = stmt.where(ProviderCredential.id != exclude_id)
|
||||
return session.execute(stmt).scalar_one_or_none() is not None
|
||||
|
||||
def get_provider_credential(self, credential_id: str | None = None) -> dict | None:
|
||||
def get_provider_credential(self, credential_id: str | None = None) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get provider credentials.
|
||||
|
||||
@ -317,7 +318,9 @@ class ProviderConfiguration(BaseModel):
|
||||
else [],
|
||||
)
|
||||
|
||||
def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
|
||||
def validate_provider_credentials(
|
||||
self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
|
||||
):
|
||||
"""
|
||||
Validate custom credentials.
|
||||
:param credentials: provider credentials
|
||||
@ -447,7 +450,7 @@ class ProviderConfiguration(BaseModel):
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
return provider_names
|
||||
|
||||
def create_provider_credential(self, credentials: dict, credential_name: str | None):
|
||||
def create_provider_credential(self, credentials: dict[str, Any], credential_name: str | None):
|
||||
"""
|
||||
Add custom provider credentials.
|
||||
:param credentials: provider credentials
|
||||
@ -515,7 +518,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
def update_provider_credential(
|
||||
self,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
credential_id: str,
|
||||
credential_name: str | None,
|
||||
):
|
||||
@ -760,7 +763,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
def _get_specific_custom_model_credential(
|
||||
self, model_type: ModelType, model: str, credential_id: str
|
||||
) -> dict | None:
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get a specific provider credential by ID.
|
||||
:param credential_id: Credential ID
|
||||
@ -832,7 +835,9 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = stmt.where(ProviderModelCredential.id != exclude_id)
|
||||
return session.execute(stmt).scalar_one_or_none() is not None
|
||||
|
||||
def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None:
|
||||
def get_custom_model_credential(
|
||||
self, model_type: ModelType, model: str, credential_id: str | None
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get custom model credentials.
|
||||
|
||||
@ -872,7 +877,7 @@ class ProviderConfiguration(BaseModel):
|
||||
self,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
credential_id: str = "",
|
||||
session: Session | None = None,
|
||||
):
|
||||
@ -939,7 +944,7 @@ class ProviderConfiguration(BaseModel):
|
||||
return _validate(new_session)
|
||||
|
||||
def create_custom_model_credential(
|
||||
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
|
||||
self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
|
||||
) -> None:
|
||||
"""
|
||||
Create a custom model credential.
|
||||
@ -1002,7 +1007,12 @@ class ProviderConfiguration(BaseModel):
|
||||
raise
|
||||
|
||||
def update_custom_model_credential(
|
||||
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
|
||||
self,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
credential_name: str | None,
|
||||
credential_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Update a custom model credential.
|
||||
@ -1412,7 +1422,9 @@ class ProviderConfiguration(BaseModel):
|
||||
# Get model instance of LLM
|
||||
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
|
||||
|
||||
def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None:
|
||||
def get_model_schema(
|
||||
self, model_type: ModelType, model: str, credentials: dict[str, Any] | None
|
||||
) -> AIModelEntity | None:
|
||||
"""
|
||||
Get model schema
|
||||
"""
|
||||
@ -1471,7 +1483,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
return secret_input_form_variables
|
||||
|
||||
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
|
||||
def obfuscated_credentials(self, credentials: dict[str, Any], credential_form_schemas: list[CredentialFormSchema]):
|
||||
"""
|
||||
Obfuscated credentials.
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum, auto
|
||||
from typing import Union
|
||||
from typing import Any, Union
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
@ -88,7 +88,7 @@ class SystemConfiguration(BaseModel):
|
||||
enabled: bool
|
||||
current_quota_type: ProviderQuotaType | None = None
|
||||
quota_configurations: list[QuotaConfiguration] = []
|
||||
credentials: dict | None = None
|
||||
credentials: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class CustomProviderConfiguration(BaseModel):
|
||||
@ -96,7 +96,7 @@ class CustomProviderConfiguration(BaseModel):
|
||||
Model class for provider custom configuration.
|
||||
"""
|
||||
|
||||
credentials: dict
|
||||
credentials: dict[str, Any]
|
||||
current_credential_id: str | None = None
|
||||
current_credential_name: str | None = None
|
||||
available_credentials: list[CredentialConfiguration] = []
|
||||
@ -109,7 +109,7 @@ class CustomModelConfiguration(BaseModel):
|
||||
|
||||
model: str
|
||||
model_type: ModelType
|
||||
credentials: dict | None
|
||||
credentials: dict[str, Any] | None
|
||||
current_credential_id: str | None = None
|
||||
current_credential_name: str | None = None
|
||||
available_model_credentials: list[CredentialConfiguration] = []
|
||||
|
||||
@ -115,7 +115,7 @@ class ModelInstance:
|
||||
def invoke_llm(
|
||||
self,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict | None = None,
|
||||
model_parameters: dict[str, Any] | None = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: Literal[True] = True,
|
||||
@ -126,7 +126,7 @@ class ModelInstance:
|
||||
def invoke_llm(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict | None = None,
|
||||
model_parameters: dict[str, Any] | None = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: Literal[False] = False,
|
||||
@ -137,7 +137,7 @@ class ModelInstance:
|
||||
def invoke_llm(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict | None = None,
|
||||
model_parameters: dict[str, Any] | None = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
@ -147,7 +147,7 @@ class ModelInstance:
|
||||
def invoke_llm(
|
||||
self,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict | None = None,
|
||||
model_parameters: dict[str, Any] | None = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
@ -528,7 +528,7 @@ class LBModelManager:
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
load_balancing_configs: list[ModelLoadBalancingConfiguration],
|
||||
managed_credentials: dict | None = None,
|
||||
managed_credentials: dict[str, Any] | None = None,
|
||||
):
|
||||
"""
|
||||
Load balancing model manager
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
|
||||
@ -10,7 +12,7 @@ from models.api_based_extension import APIBasedExtension
|
||||
|
||||
class ModerationInputParams(BaseModel):
|
||||
app_id: str = ""
|
||||
inputs: dict = Field(default_factory=dict)
|
||||
inputs: dict[str, Any] = Field(default_factory=dict)
|
||||
query: str = ""
|
||||
|
||||
|
||||
@ -23,7 +25,7 @@ class ApiModeration(Moderation):
|
||||
name: str = "api"
|
||||
|
||||
@classmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict):
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
|
||||
@ -41,7 +43,7 @@ class ApiModeration(Moderation):
|
||||
if not extension:
|
||||
raise ValueError("API-based Extension not found. Please check it again.")
|
||||
|
||||
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
|
||||
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
if self.config is None:
|
||||
@ -73,7 +75,7 @@ class ApiModeration(Moderation):
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict):
|
||||
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict[str, Any]):
|
||||
if self.config is None:
|
||||
raise ValueError("The config is not set.")
|
||||
extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", ""))
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -15,7 +16,7 @@ class ModerationInputsResult(BaseModel):
|
||||
flagged: bool = False
|
||||
action: ModerationAction
|
||||
preset_response: str = ""
|
||||
inputs: dict = Field(default_factory=dict)
|
||||
inputs: dict[str, Any] = Field(default_factory=dict)
|
||||
query: str = ""
|
||||
|
||||
|
||||
@ -33,13 +34,13 @@ class Moderation(Extensible, ABC):
|
||||
|
||||
module: ExtensionModule = ExtensionModule.MODERATION
|
||||
|
||||
def __init__(self, app_id: str, tenant_id: str, config: dict | None = None):
|
||||
def __init__(self, app_id: str, tenant_id: str, config: dict[str, Any] | None = None):
|
||||
super().__init__(tenant_id, config)
|
||||
self.app_id = app_id
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict) -> None:
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]) -> None:
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
|
||||
@ -50,7 +51,7 @@ class Moderation(Extensible, ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
|
||||
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
|
||||
"""
|
||||
Moderation for inputs.
|
||||
After the user inputs, this method will be called to perform sensitive content review
|
||||
@ -75,7 +76,7 @@ class Moderation(Extensible, ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool):
|
||||
def _validate_inputs_and_outputs_config(cls, config: dict[str, Any], is_preset_response_required: bool):
|
||||
# inputs_config
|
||||
inputs_config = config.get("inputs_config")
|
||||
if not isinstance(inputs_config, dict):
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from core.extension.extensible import ExtensionModule
|
||||
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult
|
||||
from extensions.ext_code_based_extension import code_based_extension
|
||||
@ -6,12 +8,12 @@ from extensions.ext_code_based_extension import code_based_extension
|
||||
class ModerationFactory:
|
||||
__extension_instance: Moderation
|
||||
|
||||
def __init__(self, name: str, app_id: str, tenant_id: str, config: dict):
|
||||
def __init__(self, name: str, app_id: str, tenant_id: str, config: dict[str, Any]):
|
||||
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
|
||||
self.__extension_instance = extension_class(app_id, tenant_id, config)
|
||||
|
||||
@classmethod
|
||||
def validate_config(cls, name: str, tenant_id: str, config: dict):
|
||||
def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
|
||||
@ -24,7 +26,7 @@ class ModerationFactory:
|
||||
# FIXME: mypy error, try to fix it instead of using type: ignore
|
||||
extension_class.validate_config(tenant_id, config) # type: ignore
|
||||
|
||||
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
|
||||
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
|
||||
"""
|
||||
Moderation for inputs.
|
||||
After the user inputs, this method will be called to perform sensitive content review
|
||||
|
||||
@ -8,7 +8,7 @@ class KeywordsModeration(Moderation):
|
||||
name: str = "keywords"
|
||||
|
||||
@classmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict):
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
|
||||
@ -28,7 +28,7 @@ class KeywordsModeration(Moderation):
|
||||
if len(keywords_row_len) > 100:
|
||||
raise ValueError("the number of rows for the keywords must be less than 100")
|
||||
|
||||
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
|
||||
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
if self.config is None:
|
||||
@ -66,7 +66,7 @@ class KeywordsModeration(Moderation):
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
|
||||
def _is_violated(self, inputs: dict[str, Any], keywords_list: list[str]) -> bool:
|
||||
return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())
|
||||
|
||||
def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool:
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
@ -8,7 +10,7 @@ class OpenAIModeration(Moderation):
|
||||
name: str = "openai_moderation"
|
||||
|
||||
@classmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict):
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
|
||||
@ -18,7 +20,7 @@ class OpenAIModeration(Moderation):
|
||||
"""
|
||||
cls._validate_inputs_and_outputs_config(config, True)
|
||||
|
||||
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
|
||||
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
if self.config is None:
|
||||
@ -49,7 +51,7 @@ class OpenAIModeration(Moderation):
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
def _is_violated(self, inputs: dict):
|
||||
def _is_violated(self, inputs: dict[str, Any]):
|
||||
text = "\n".join(str(inputs.values()))
|
||||
model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id)
|
||||
model_instance = model_manager.get_model_instance(
|
||||
|
||||
@ -778,7 +778,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
logger.info("[Arize/Phoenix] Failed to construct project URL: %s", str(e), exc_info=True)
|
||||
raise ValueError(f"[Arize/Phoenix] Failed to construct project URL: {str(e)}")
|
||||
|
||||
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
|
||||
def _construct_llm_attributes(self, prompts: dict[str, Any] | list[Any] | str | None) -> dict[str, str]:
|
||||
"""Construct LLM attributes with passed prompts for Arize/Phoenix."""
|
||||
attributes: dict[str, str] = {}
|
||||
|
||||
@ -797,7 +797,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
path = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{key}"
|
||||
set_attribute(path, value)
|
||||
|
||||
def set_tool_call_attributes(message_index: int, tool_index: int, tool_call: dict | object | None) -> None:
|
||||
def set_tool_call_attributes(
|
||||
message_index: int, tool_index: int, tool_call: dict[str, Any] | object | None
|
||||
) -> None:
|
||||
"""Extract and assign tool call details safely."""
|
||||
if not tool_call:
|
||||
return
|
||||
|
||||
@ -242,7 +242,7 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
|
||||
return inputs, attributes
|
||||
|
||||
def _parse_knowledge_retrieval_outputs(self, outputs: dict):
|
||||
def _parse_knowledge_retrieval_outputs(self, outputs: dict[str, Any]):
|
||||
"""Parse KR outputs and attributes from KR workflow node"""
|
||||
retrieved = outputs.get("result", [])
|
||||
|
||||
@ -319,7 +319,7 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
|
||||
)
|
||||
|
||||
def _get_message_user_id(self, metadata: dict) -> str | None:
|
||||
def _get_message_user_id(self, metadata: dict[str, Any]) -> str | None:
|
||||
if (end_user_id := metadata.get("from_end_user_id")) and (
|
||||
end_user_data := db.session.get(EndUser, end_user_id)
|
||||
):
|
||||
@ -468,7 +468,7 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
}
|
||||
return node_type_mapping.get(node_type, "CHAIN") # type: ignore[arg-type,call-overload]
|
||||
|
||||
def _set_trace_metadata(self, span: Span, metadata: dict):
|
||||
def _set_trace_metadata(self, span: Span, metadata: dict[str, Any]):
|
||||
token = None
|
||||
try:
|
||||
# NB: Set span in context such that we can use update_current_trace() API
|
||||
@ -490,7 +490,7 @@ class MLflowDataTrace(BaseTraceInstance):
|
||||
return messages
|
||||
return prompts # Fallback to original format
|
||||
|
||||
def _parse_single_message(self, item: dict):
|
||||
def _parse_single_message(self, item: dict[str, Any]):
|
||||
"""Postprocess single message format to be standard chat message"""
|
||||
role = item.get("role", "user")
|
||||
msg = {"role": role, "content": item.get("text", "")}
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
|
||||
from opik import Opik, Trace
|
||||
@ -436,7 +436,7 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
|
||||
self.add_span(span_data)
|
||||
|
||||
def add_trace(self, opik_trace_data: dict) -> Trace:
|
||||
def add_trace(self, opik_trace_data: dict[str, Any]) -> Trace:
|
||||
try:
|
||||
trace = self.opik_client.trace(**opik_trace_data)
|
||||
logger.debug("Opik Trace created successfully")
|
||||
@ -444,7 +444,7 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
except Exception as e:
|
||||
raise ValueError(f"Opik Failed to create trace: {str(e)}")
|
||||
|
||||
def add_span(self, opik_span_data: dict):
|
||||
def add_span(self, opik_span_data: dict[str, Any]):
|
||||
try:
|
||||
self.opik_client.span(**opik_span_data)
|
||||
logger.debug("Opik Span created successfully")
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
@ -31,7 +32,7 @@ class EndpointEntity(BasePluginEntity):
|
||||
entity of an endpoint
|
||||
"""
|
||||
|
||||
settings: dict
|
||||
settings: dict[str, Any]
|
||||
tenant_id: str
|
||||
plugin_id: str
|
||||
expired_at: datetime
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from graphon.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from pydantic import BaseModel, Field, computed_field, model_validator
|
||||
|
||||
@ -40,7 +42,7 @@ class MarketplacePluginDeclaration(BaseModel):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def transform_declaration(cls, data: dict):
|
||||
def transform_declaration(cls, data: dict[str, Any]) -> dict[str, Any]:
|
||||
if "endpoint" in data and not data["endpoint"]:
|
||||
del data["endpoint"]
|
||||
if "model" in data and not data["model"]:
|
||||
|
||||
@ -123,7 +123,7 @@ class PluginDeclaration(BaseModel):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_category(cls, values: dict):
|
||||
def validate_category(cls, values: dict[str, Any]) -> dict[str, Any]:
|
||||
# auto detect category
|
||||
if values.get("tool"):
|
||||
values["category"] = PluginCategory.Tool
|
||||
|
||||
@ -73,7 +73,7 @@ class PluginBasicBooleanResponse(BaseModel):
|
||||
"""
|
||||
|
||||
result: bool
|
||||
credentials: dict | None = None
|
||||
credentials: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class PluginModelSchemaEntity(BaseModel):
|
||||
|
||||
@ -49,7 +49,7 @@ class RequestInvokeTool(BaseModel):
|
||||
tool_type: Literal["builtin", "workflow", "api", "mcp"]
|
||||
provider: str
|
||||
tool: str
|
||||
tool_parameters: dict
|
||||
tool_parameters: dict[str, Any]
|
||||
credential_id: str | None = None
|
||||
|
||||
|
||||
@ -209,7 +209,7 @@ class RequestInvokeEncrypt(BaseModel):
|
||||
opt: Literal["encrypt", "decrypt", "clear"]
|
||||
namespace: Literal["endpoint"]
|
||||
identity: str
|
||||
data: dict = Field(default_factory=dict)
|
||||
data: dict[str, Any] = Field(default_factory=dict)
|
||||
config: list[BasicProviderConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
|
||||
@ -26,7 +26,7 @@ class PluginDatasourceManager(BasePluginClient):
|
||||
Fetch datasource providers for the given tenant.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
|
||||
if json_response.get("data"):
|
||||
for provider in json_response.get("data", []):
|
||||
declaration = provider.get("declaration", {}) or {}
|
||||
@ -68,7 +68,7 @@ class PluginDatasourceManager(BasePluginClient):
|
||||
Fetch datasource providers for the given tenant.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
|
||||
if json_response.get("data"):
|
||||
for provider in json_response.get("data", []):
|
||||
declaration = provider.get("declaration", {}) or {}
|
||||
@ -110,7 +110,7 @@ class PluginDatasourceManager(BasePluginClient):
|
||||
|
||||
tool_provider_id = DatasourceProviderID(provider_id)
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
|
||||
data = json_response.get("data")
|
||||
if data:
|
||||
for datasource in data.get("declaration", {}).get("datasources", []):
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from core.plugin.entities.endpoint import EndpointEntityWithInstance
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.plugin.impl.exc import PluginDaemonInternalServerError
|
||||
@ -5,7 +7,12 @@ from core.plugin.impl.exc import PluginDaemonInternalServerError
|
||||
|
||||
class PluginEndpointClient(BasePluginClient):
|
||||
def create_endpoint(
|
||||
self, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_unique_identifier: str,
|
||||
name: str,
|
||||
settings: dict[str, Any],
|
||||
) -> bool:
|
||||
"""
|
||||
Create an endpoint for the given plugin.
|
||||
@ -49,7 +56,9 @@ class PluginEndpointClient(BasePluginClient):
|
||||
params={"plugin_id": plugin_id, "page": page, "page_size": page_size},
|
||||
)
|
||||
|
||||
def update_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict):
|
||||
def update_endpoint(
|
||||
self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
Update the settings of the given endpoint.
|
||||
"""
|
||||
|
||||
@ -50,7 +50,7 @@ class PluginModelClient(BasePluginClient):
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
) -> AIModelEntity | None:
|
||||
"""
|
||||
Get model schema
|
||||
@ -80,7 +80,7 @@ class PluginModelClient(BasePluginClient):
|
||||
return None
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict
|
||||
self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
@ -118,7 +118,7 @@ class PluginModelClient(BasePluginClient):
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
) -> bool:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
@ -157,9 +157,9 @@ class PluginModelClient(BasePluginClient):
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict | None = None,
|
||||
model_parameters: dict[str, Any] | None = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
@ -206,7 +206,7 @@ class PluginModelClient(BasePluginClient):
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
) -> int:
|
||||
@ -248,7 +248,7 @@ class PluginModelClient(BasePluginClient):
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
texts: list[str],
|
||||
input_type: str,
|
||||
) -> EmbeddingResult:
|
||||
@ -290,7 +290,7 @@ class PluginModelClient(BasePluginClient):
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
documents: list[dict],
|
||||
input_type: str,
|
||||
) -> EmbeddingResult:
|
||||
@ -332,7 +332,7 @@ class PluginModelClient(BasePluginClient):
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
texts: list[str],
|
||||
) -> list[int]:
|
||||
"""
|
||||
@ -372,7 +372,7 @@ class PluginModelClient(BasePluginClient):
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: float | None = None,
|
||||
@ -418,7 +418,7 @@ class PluginModelClient(BasePluginClient):
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
query: MultimodalRerankInput,
|
||||
docs: list[MultimodalRerankInput],
|
||||
score_threshold: float | None = None,
|
||||
@ -463,7 +463,7 @@ class PluginModelClient(BasePluginClient):
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
content_text: str,
|
||||
voice: str,
|
||||
) -> Generator[bytes, None, None]:
|
||||
@ -508,7 +508,7 @@ class PluginModelClient(BasePluginClient):
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
language: str | None = None,
|
||||
):
|
||||
"""
|
||||
@ -552,7 +552,7 @@ class PluginModelClient(BasePluginClient):
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
file: IO[bytes],
|
||||
) -> str:
|
||||
"""
|
||||
@ -592,7 +592,7 @@ class PluginModelClient(BasePluginClient):
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
text: str,
|
||||
) -> bool:
|
||||
"""
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from requests import HTTPError
|
||||
|
||||
@ -263,7 +264,7 @@ class PluginInstaller(BasePluginClient):
|
||||
original_plugin_unique_identifier: str,
|
||||
new_plugin_unique_identifier: str,
|
||||
source: PluginInstallationSource,
|
||||
meta: dict,
|
||||
meta: dict[str, Any],
|
||||
) -> PluginInstallTaskStartResponse:
|
||||
"""
|
||||
Upgrade a plugin.
|
||||
|
||||
@ -875,7 +875,11 @@ class DatasetRetrieval:
|
||||
return retrieval_resource_list
|
||||
|
||||
def _on_retrieval_end(
|
||||
self, flask_app: Flask, documents: list[Document], message_id: str | None = None, timer: dict | None = None
|
||||
self,
|
||||
flask_app: Flask,
|
||||
documents: list[Document],
|
||||
message_id: str | None = None,
|
||||
timer: dict[str, Any] | None = None,
|
||||
):
|
||||
"""Handle retrieval end."""
|
||||
with flask_app.app_context():
|
||||
@ -980,7 +984,7 @@ class DatasetRetrieval:
|
||||
|
||||
self._send_trace_task(message_id, documents, timer)
|
||||
|
||||
def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict | None):
|
||||
def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict[str, Any] | None):
|
||||
"""Send trace task if trace manager is available."""
|
||||
trace_manager: TraceQueueManager | None = (
|
||||
self.application_generate_entity.trace_manager if self.application_generate_entity else None
|
||||
@ -1142,7 +1146,7 @@ class DatasetRetrieval:
|
||||
invoke_from: InvokeFrom,
|
||||
hit_callback: DatasetIndexToolCallbackHandler,
|
||||
user_id: str,
|
||||
inputs: dict,
|
||||
inputs: dict[str, Any],
|
||||
) -> list[DatasetRetrieverBaseTool] | None:
|
||||
"""
|
||||
A dataset tool is a tool that can be used to retrieve information from a dataset
|
||||
@ -1337,7 +1341,7 @@ class DatasetRetrieval:
|
||||
metadata_filtering_mode: str,
|
||||
metadata_model_config: ModelConfig,
|
||||
metadata_filtering_conditions: MetadataFilteringCondition | None,
|
||||
inputs: dict,
|
||||
inputs: dict[str, Any],
|
||||
) -> tuple[dict[str, list[str]] | None, MetadataFilteringCondition | None]:
|
||||
document_query = select(DatasetDocument).where(
|
||||
DatasetDocument.dataset_id.in_(dataset_ids),
|
||||
@ -1417,7 +1421,7 @@ class DatasetRetrieval:
|
||||
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
|
||||
return metadata_filter_document_ids, metadata_condition
|
||||
|
||||
def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
|
||||
def _replace_metadata_filter_value(self, text: str, inputs: dict[str, Any]) -> str:
|
||||
if not inputs:
|
||||
return text
|
||||
|
||||
|
||||
@ -233,7 +233,7 @@ class DatasetService:
|
||||
embedding_model_provider: str | None = None,
|
||||
embedding_model_name: str | None = None,
|
||||
retrieval_model: RetrievalModel | None = None,
|
||||
summary_index_setting: dict | None = None,
|
||||
summary_index_setting: dict[str, Any] | None = None,
|
||||
):
|
||||
# check if dataset name already exists
|
||||
if db.session.scalar(select(Dataset).where(Dataset.name == name, Dataset.tenant_id == tenant_id).limit(1)):
|
||||
@ -2493,7 +2493,7 @@ class DocumentService:
|
||||
data_source_type: str,
|
||||
document_form: str,
|
||||
document_language: str,
|
||||
data_source_info: dict,
|
||||
data_source_info: dict[str, Any],
|
||||
created_from: str,
|
||||
position: int,
|
||||
account: Account,
|
||||
@ -2850,7 +2850,7 @@ class DocumentService:
|
||||
raise ValueError("Process rule segmentation max_tokens is invalid")
|
||||
|
||||
@classmethod
|
||||
def estimate_args_validate(cls, args: dict):
|
||||
def estimate_args_validate(cls, args: dict[str, Any]):
|
||||
if "info_list" not in args or not args["info_list"]:
|
||||
raise ValueError("Data source info is required")
|
||||
|
||||
@ -3132,7 +3132,7 @@ class DocumentService:
|
||||
|
||||
class SegmentService:
|
||||
@classmethod
|
||||
def segment_create_args_validate(cls, args: dict, document: Document):
|
||||
def segment_create_args_validate(cls, args: dict[str, Any], document: Document):
|
||||
if document.doc_form == IndexStructureType.QA_INDEX:
|
||||
if "answer" not in args or not args["answer"]:
|
||||
raise ValueError("Answer is required")
|
||||
@ -3149,7 +3149,7 @@ class SegmentService:
|
||||
raise ValueError(f"Exceeded maximum attachment limit of {single_chunk_attachment_limit}")
|
||||
|
||||
@classmethod
|
||||
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
|
||||
def create_segment(cls, args: dict[str, Any], document: Document, dataset: Dataset):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
|
||||
@ -47,7 +47,7 @@ class ExternalDatasetService:
|
||||
return external_knowledge_apis.items, external_knowledge_apis.total
|
||||
|
||||
@classmethod
|
||||
def validate_api_list(cls, api_settings: dict):
|
||||
def validate_api_list(cls, api_settings: dict[str, Any]):
|
||||
if not api_settings:
|
||||
raise ValueError("api list is empty")
|
||||
if not api_settings.get("endpoint"):
|
||||
@ -56,7 +56,7 @@ class ExternalDatasetService:
|
||||
raise ValueError("api_key is required")
|
||||
|
||||
@staticmethod
|
||||
def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict) -> ExternalKnowledgeApis:
|
||||
def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict[str, Any]) -> ExternalKnowledgeApis:
|
||||
settings = args.get("settings")
|
||||
if settings is None:
|
||||
raise ValueError("settings is required")
|
||||
@ -75,7 +75,7 @@ class ExternalDatasetService:
|
||||
return external_knowledge_api
|
||||
|
||||
@staticmethod
|
||||
def check_endpoint_and_api_key(settings: dict):
|
||||
def check_endpoint_and_api_key(settings: dict[str, Any]):
|
||||
if "endpoint" not in settings or not settings["endpoint"]:
|
||||
raise ValueError("endpoint is required")
|
||||
if "api_key" not in settings or not settings["api_key"]:
|
||||
@ -178,7 +178,9 @@ class ExternalDatasetService:
|
||||
return external_knowledge_binding
|
||||
|
||||
@staticmethod
|
||||
def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
|
||||
def document_create_args_validate(
|
||||
tenant_id: str, external_knowledge_api_id: str, process_parameter: dict[str, Any]
|
||||
):
|
||||
external_knowledge_api = db.session.scalar(
|
||||
select(ExternalKnowledgeApis)
|
||||
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
|
||||
@ -222,7 +224,7 @@ class ExternalDatasetService:
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def assembling_headers(authorization: Authorization, headers: dict | None = None) -> dict[str, Any]:
|
||||
def assembling_headers(authorization: Authorization, headers: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
authorization = deepcopy(authorization)
|
||||
if headers:
|
||||
headers = deepcopy(headers)
|
||||
@ -248,11 +250,11 @@ class ExternalDatasetService:
|
||||
return headers
|
||||
|
||||
@staticmethod
|
||||
def get_external_knowledge_api_settings(settings: dict) -> ExternalKnowledgeApiSetting:
|
||||
def get_external_knowledge_api_settings(settings: dict[str, Any]) -> ExternalKnowledgeApiSetting:
|
||||
return ExternalKnowledgeApiSetting.model_validate(settings)
|
||||
|
||||
@staticmethod
|
||||
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
|
||||
def create_external_dataset(tenant_id: str, user_id: str, args: dict[str, Any]) -> Dataset:
|
||||
# check if dataset name already exists
|
||||
if db.session.scalar(
|
||||
select(Dataset).where(Dataset.name == args.get("name"), Dataset.tenant_id == tenant_id).limit(1)
|
||||
@ -304,7 +306,7 @@ class ExternalDatasetService:
|
||||
tenant_id: str,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
external_retrieval_parameters: dict,
|
||||
external_retrieval_parameters: dict[str, Any],
|
||||
metadata_condition: MetadataFilteringCondition | None = None,
|
||||
):
|
||||
external_knowledge_binding = db.session.scalar(
|
||||
|
||||
@ -92,7 +92,7 @@ class ApiToolManageService:
|
||||
|
||||
@staticmethod
|
||||
def convert_schema_to_tool_bundles(
|
||||
schema: str, extra_info: dict | None = None
|
||||
schema: str, extra_info: dict[str, Any] | None = None
|
||||
) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
|
||||
"""
|
||||
convert schema to tool bundles
|
||||
@ -109,8 +109,8 @@ class ApiToolManageService:
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
provider_name: str,
|
||||
icon: dict,
|
||||
credentials: dict,
|
||||
icon: dict[str, Any],
|
||||
credentials: dict[str, Any],
|
||||
schema_type: ApiProviderSchemaType,
|
||||
schema: str,
|
||||
privacy_policy: str,
|
||||
@ -244,8 +244,8 @@ class ApiToolManageService:
|
||||
tenant_id: str,
|
||||
provider_name: str,
|
||||
original_provider: str,
|
||||
icon: dict,
|
||||
credentials: dict,
|
||||
icon: dict[str, Any],
|
||||
credentials: dict[str, Any],
|
||||
_schema_type: ApiProviderSchemaType,
|
||||
schema: str,
|
||||
privacy_policy: str | None,
|
||||
@ -356,8 +356,8 @@ class ApiToolManageService:
|
||||
tenant_id: str,
|
||||
provider_name: str,
|
||||
tool_name: str,
|
||||
credentials: dict,
|
||||
parameters: dict,
|
||||
credentials: dict[str, Any],
|
||||
parameters: dict[str, Any],
|
||||
schema_type: ApiProviderSchemaType,
|
||||
schema: str,
|
||||
):
|
||||
|
||||
@ -147,7 +147,7 @@ class BuiltinToolManageService:
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
credential_id: str,
|
||||
credentials: dict | None = None,
|
||||
credentials: dict[str, Any] | None = None,
|
||||
name: str | None = None,
|
||||
):
|
||||
"""
|
||||
@ -177,7 +177,7 @@ class BuiltinToolManageService:
|
||||
)
|
||||
|
||||
original_credentials = encrypter.decrypt(db_provider.credentials)
|
||||
new_credentials: dict = {
|
||||
new_credentials: dict[str, Any] = {
|
||||
key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
|
||||
for key, value in credentials.items()
|
||||
}
|
||||
@ -216,7 +216,7 @@ class BuiltinToolManageService:
|
||||
api_type: CredentialType,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
expires_at: int = -1,
|
||||
name: str | None = None,
|
||||
):
|
||||
@ -657,7 +657,7 @@ class BuiltinToolManageService:
|
||||
def save_custom_oauth_client_params(
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
client_params: dict | None = None,
|
||||
client_params: dict[str, Any] | None = None,
|
||||
enable_oauth_custom_client: bool | None = None,
|
||||
):
|
||||
"""
|
||||
|
||||
@ -69,7 +69,9 @@ class ToolTransformService:
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def repack_provider(tenant_id: str, provider: dict | ToolProviderApiEntity | PluginDatasourceProviderEntity):
|
||||
def repack_provider(
|
||||
tenant_id: str, provider: dict[str, Any] | ToolProviderApiEntity | PluginDatasourceProviderEntity
|
||||
):
|
||||
"""
|
||||
repack provider
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from sqlalchemy import delete, or_, select
|
||||
@ -35,7 +36,7 @@ class WorkflowToolManageService:
|
||||
workflow_app_id: str,
|
||||
name: str,
|
||||
label: str,
|
||||
icon: dict,
|
||||
icon: dict[str, Any],
|
||||
description: str,
|
||||
parameters: list[WorkflowToolParameterConfiguration],
|
||||
privacy_policy: str = "",
|
||||
@ -117,7 +118,7 @@ class WorkflowToolManageService:
|
||||
workflow_tool_id: str,
|
||||
name: str,
|
||||
label: str,
|
||||
icon: dict,
|
||||
icon: dict[str, Any],
|
||||
description: str,
|
||||
parameters: list[WorkflowToolParameterConfiguration],
|
||||
privacy_policy: str = "",
|
||||
|
||||
@ -91,7 +91,7 @@ class WebsiteCrawlApiRequest:
|
||||
return CrawlRequest(url=self.url, provider=self.provider, options=options)
|
||||
|
||||
@classmethod
|
||||
def from_args(cls, args: dict) -> WebsiteCrawlApiRequest:
|
||||
def from_args(cls, args: dict[str, Any]) -> WebsiteCrawlApiRequest:
|
||||
"""Create from Flask-RESTful parsed arguments."""
|
||||
provider = args.get("provider")
|
||||
url = args.get("url")
|
||||
@ -115,7 +115,7 @@ class WebsiteCrawlStatusApiRequest:
|
||||
job_id: str
|
||||
|
||||
@classmethod
|
||||
def from_args(cls, args: dict, job_id: str) -> WebsiteCrawlStatusApiRequest:
|
||||
def from_args(cls, args: dict[str, Any], job_id: str) -> WebsiteCrawlStatusApiRequest:
|
||||
"""Create from Flask-RESTful parsed arguments."""
|
||||
provider = args.get("provider")
|
||||
if not provider:
|
||||
@ -163,7 +163,7 @@ class WebsiteService:
|
||||
raise ValueError("Invalid provider")
|
||||
|
||||
@classmethod
|
||||
def _get_decrypted_api_key(cls, tenant_id: str, config: dict) -> str:
|
||||
def _get_decrypted_api_key(cls, tenant_id: str, config: dict[str, Any]) -> str:
|
||||
"""Decrypt and return the API key from config."""
|
||||
api_key = config.get("api_key")
|
||||
if not api_key:
|
||||
@ -171,7 +171,7 @@ class WebsiteService:
|
||||
return encrypter.decrypt_token(tenant_id=tenant_id, token=api_key)
|
||||
|
||||
@classmethod
|
||||
def document_create_args_validate(cls, args: dict):
|
||||
def document_create_args_validate(cls, args: dict[str, Any]):
|
||||
"""Validate arguments for document creation."""
|
||||
try:
|
||||
WebsiteCrawlApiRequest.from_args(args)
|
||||
@ -195,7 +195,7 @@ class WebsiteService:
|
||||
raise ValueError("Invalid provider")
|
||||
|
||||
@classmethod
|
||||
def _crawl_with_firecrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]:
|
||||
def _crawl_with_firecrawl(cls, request: CrawlRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
|
||||
|
||||
params: dict[str, Any]
|
||||
@ -225,7 +225,7 @@ class WebsiteService:
|
||||
return {"status": "active", "job_id": job_id}
|
||||
|
||||
@classmethod
|
||||
def _crawl_with_watercrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]:
|
||||
def _crawl_with_watercrawl(cls, request: CrawlRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
|
||||
# Convert CrawlOptions back to dict format for WaterCrawlProvider
|
||||
options = {
|
||||
"limit": request.options.limit,
|
||||
@ -290,7 +290,7 @@ class WebsiteService:
|
||||
raise ValueError("Invalid provider")
|
||||
|
||||
@classmethod
|
||||
def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> CrawlStatusDict:
|
||||
def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict[str, Any]) -> CrawlStatusDict:
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
|
||||
result: CrawlStatusResponse = firecrawl_app.check_crawl_status(job_id)
|
||||
crawl_status_data: CrawlStatusDict = {
|
||||
@ -364,7 +364,9 @@ class WebsiteService:
|
||||
raise ValueError("Invalid provider")
|
||||
|
||||
@classmethod
|
||||
def _get_firecrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None:
|
||||
def _get_firecrawl_url_data(
|
||||
cls, job_id: str, url: str, api_key: str, config: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
crawl_data: list[FirecrawlDocumentData] | None = None
|
||||
file_key = "website_files/" + job_id + ".txt"
|
||||
if storage.exists(file_key):
|
||||
@ -438,7 +440,7 @@ class WebsiteService:
|
||||
raise ValueError("Invalid provider")
|
||||
|
||||
@classmethod
|
||||
def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
|
||||
def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
|
||||
params = {"onlyMainContent": request.only_main_content}
|
||||
return dict(firecrawl_app.scrape_url(url=request.url, params=params))
|
||||
|
||||
@ -11,6 +11,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy import delete
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from extensions.ext_redis import redis_client
|
||||
@ -28,12 +29,12 @@ class TestCreateSegmentToIndexTask:
|
||||
"""Clean up database and Redis before each test to ensure isolation."""
|
||||
|
||||
# Clear all test data using fixture session
|
||||
db_session_with_containers.query(DocumentSegment).delete()
|
||||
db_session_with_containers.query(Document).delete()
|
||||
db_session_with_containers.query(Dataset).delete()
|
||||
db_session_with_containers.query(TenantAccountJoin).delete()
|
||||
db_session_with_containers.query(Tenant).delete()
|
||||
db_session_with_containers.query(Account).delete()
|
||||
db_session_with_containers.execute(delete(DocumentSegment))
|
||||
db_session_with_containers.execute(delete(Document))
|
||||
db_session_with_containers.execute(delete(Dataset))
|
||||
db_session_with_containers.execute(delete(TenantAccountJoin))
|
||||
db_session_with_containers.execute(delete(Tenant))
|
||||
db_session_with_containers.execute(delete(Account))
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Clear Redis cache
|
||||
|
||||
@ -14,6 +14,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy import delete
|
||||
|
||||
from libs.email_i18n import EmailType
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
@ -41,9 +42,9 @@ class TestSendEmailCodeLoginMailTask:
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
# Clear all test data
|
||||
db_session_with_containers.query(TenantAccountJoin).delete()
|
||||
db_session_with_containers.query(Tenant).delete()
|
||||
db_session_with_containers.query(Account).delete()
|
||||
db_session_with_containers.execute(delete(TenantAccountJoin))
|
||||
db_session_with_containers.execute(delete(Tenant))
|
||||
db_session_with_containers.execute(delete(Account))
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Clear Redis cache
|
||||
|
||||
@ -6,6 +6,7 @@ import pytest
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.nodes.human_input.entities import HumanInputNodeData
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from sqlalchemy import delete
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
@ -30,14 +31,14 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_database(db_session_with_containers):
|
||||
db_session_with_containers.query(HumanInputFormRecipient).delete()
|
||||
db_session_with_containers.query(HumanInputDelivery).delete()
|
||||
db_session_with_containers.query(HumanInputForm).delete()
|
||||
db_session_with_containers.query(WorkflowPause).delete()
|
||||
db_session_with_containers.query(WorkflowRun).delete()
|
||||
db_session_with_containers.query(TenantAccountJoin).delete()
|
||||
db_session_with_containers.query(Tenant).delete()
|
||||
db_session_with_containers.query(Account).delete()
|
||||
db_session_with_containers.execute(delete(HumanInputFormRecipient))
|
||||
db_session_with_containers.execute(delete(HumanInputDelivery))
|
||||
db_session_with_containers.execute(delete(HumanInputForm))
|
||||
db_session_with_containers.execute(delete(WorkflowPause))
|
||||
db_session_with_containers.execute(delete(WorkflowRun))
|
||||
db_session_with_containers.execute(delete(TenantAccountJoin))
|
||||
db_session_with_containers.execute(delete(Tenant))
|
||||
db_session_with_containers.execute(delete(Account))
|
||||
db_session_with_containers.commit()
|
||||
|
||||
|
||||
|
||||
@ -17,6 +17,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.email_i18n import EmailType
|
||||
@ -44,9 +45,9 @@ class TestMailInviteMemberTask:
|
||||
def cleanup_database(self, db_session_with_containers):
|
||||
"""Clean up database before each test to ensure isolation."""
|
||||
# Clear all test data
|
||||
db_session_with_containers.query(TenantAccountJoin).delete()
|
||||
db_session_with_containers.query(Tenant).delete()
|
||||
db_session_with_containers.query(Account).delete()
|
||||
db_session_with_containers.execute(delete(TenantAccountJoin))
|
||||
db_session_with_containers.execute(delete(Tenant))
|
||||
db_session_with_containers.execute(delete(Account))
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Clear Redis cache
|
||||
@ -491,10 +492,10 @@ class TestMailInviteMemberTask:
|
||||
assert tenant.name is not None
|
||||
|
||||
# Verify tenant relationship exists
|
||||
tenant_join = (
|
||||
db_session_with_containers.query(TenantAccountJoin)
|
||||
.filter_by(tenant_id=tenant.id, account_id=pending_account.id)
|
||||
.first()
|
||||
tenant_join = db_session_with_containers.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == pending_account.id)
|
||||
.limit(1)
|
||||
)
|
||||
assert tenant_join is not None
|
||||
assert tenant_join.role == TenantAccountRole.NORMAL
|
||||
|
||||
@ -4,6 +4,7 @@ from unittest.mock import ANY, call, patch
|
||||
import pytest
|
||||
from graphon.variables.segments import StringSegment
|
||||
from graphon.variables.types import SegmentType
|
||||
from sqlalchemy import delete, func, select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from extensions.storage.storage_type import StorageType
|
||||
@ -20,11 +21,11 @@ from tasks.remove_app_and_related_data_task import (
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_database(db_session_with_containers):
|
||||
db_session_with_containers.query(WorkflowDraftVariable).delete()
|
||||
db_session_with_containers.query(WorkflowDraftVariableFile).delete()
|
||||
db_session_with_containers.query(UploadFile).delete()
|
||||
db_session_with_containers.query(App).delete()
|
||||
db_session_with_containers.query(Tenant).delete()
|
||||
db_session_with_containers.execute(delete(WorkflowDraftVariable))
|
||||
db_session_with_containers.execute(delete(WorkflowDraftVariableFile))
|
||||
db_session_with_containers.execute(delete(UploadFile))
|
||||
db_session_with_containers.execute(delete(App))
|
||||
db_session_with_containers.execute(delete(Tenant))
|
||||
db_session_with_containers.commit()
|
||||
|
||||
|
||||
@ -127,21 +128,21 @@ class TestDeleteDraftVariablesBatch:
|
||||
result = delete_draft_variables_batch(app1.id, batch_size=100)
|
||||
|
||||
assert result == 150
|
||||
app1_remaining = db_session_with_containers.query(WorkflowDraftVariable).where(
|
||||
WorkflowDraftVariable.app_id == app1.id
|
||||
app1_remaining_count = db_session_with_containers.scalar(
|
||||
select(func.count()).select_from(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app1.id)
|
||||
)
|
||||
app2_remaining = db_session_with_containers.query(WorkflowDraftVariable).where(
|
||||
WorkflowDraftVariable.app_id == app2.id
|
||||
app2_remaining_count = db_session_with_containers.scalar(
|
||||
select(func.count()).select_from(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app2.id)
|
||||
)
|
||||
assert app1_remaining.count() == 0
|
||||
assert app2_remaining.count() == 100
|
||||
assert app1_remaining_count == 0
|
||||
assert app2_remaining_count == 100
|
||||
|
||||
def test_delete_draft_variables_batch_empty_result(self, db_session_with_containers):
|
||||
"""Test deletion when no draft variables exist for the app."""
|
||||
result = delete_draft_variables_batch(str(uuid.uuid4()), 1000)
|
||||
|
||||
assert result == 0
|
||||
assert db_session_with_containers.query(WorkflowDraftVariable).count() == 0
|
||||
assert db_session_with_containers.scalar(select(func.count()).select_from(WorkflowDraftVariable)) == 0
|
||||
|
||||
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
|
||||
@patch("tasks.remove_app_and_related_data_task.logger")
|
||||
@ -190,12 +191,16 @@ class TestDeleteDraftVariableOffloadData:
|
||||
expected_storage_calls = [call(storage_key) for storage_key in upload_file_keys]
|
||||
mock_storage.delete.assert_has_calls(expected_storage_calls, any_order=True)
|
||||
|
||||
remaining_var_files = db_session_with_containers.query(WorkflowDraftVariableFile).where(
|
||||
WorkflowDraftVariableFile.id.in_(file_ids)
|
||||
remaining_var_files_count = db_session_with_containers.scalar(
|
||||
select(func.count())
|
||||
.select_from(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_(file_ids))
|
||||
)
|
||||
remaining_upload_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids))
|
||||
assert remaining_var_files.count() == 0
|
||||
assert remaining_upload_files.count() == 0
|
||||
remaining_upload_files_count = db_session_with_containers.scalar(
|
||||
select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids))
|
||||
)
|
||||
assert remaining_var_files_count == 0
|
||||
assert remaining_upload_files_count == 0
|
||||
|
||||
@patch("extensions.ext_storage.storage")
|
||||
@patch("tasks.remove_app_and_related_data_task.logging")
|
||||
@ -217,9 +222,13 @@ class TestDeleteDraftVariableOffloadData:
|
||||
assert result == 1
|
||||
mock_logging.exception.assert_called_once_with("Failed to delete storage object %s", storage_keys[0])
|
||||
|
||||
remaining_var_files = db_session_with_containers.query(WorkflowDraftVariableFile).where(
|
||||
WorkflowDraftVariableFile.id.in_(file_ids)
|
||||
remaining_var_files_count = db_session_with_containers.scalar(
|
||||
select(func.count())
|
||||
.select_from(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_(file_ids))
|
||||
)
|
||||
remaining_upload_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids))
|
||||
assert remaining_var_files.count() == 0
|
||||
assert remaining_upload_files.count() == 0
|
||||
remaining_upload_files_count = db_session_with_containers.scalar(
|
||||
select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids))
|
||||
)
|
||||
assert remaining_var_files_count == 0
|
||||
assert remaining_upload_files_count == 0
|
||||
|
||||
@ -11,6 +11,7 @@ from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
@ -40,9 +41,9 @@ def tenant_and_account(db_session_with_containers: Session) -> Generator[tuple[T
|
||||
yield tenant, account
|
||||
|
||||
# Cleanup
|
||||
db_session_with_containers.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
|
||||
db_session_with_containers.query(Account).filter_by(id=account.id).delete()
|
||||
db_session_with_containers.query(Tenant).filter_by(id=tenant.id).delete()
|
||||
db_session_with_containers.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == tenant.id))
|
||||
db_session_with_containers.execute(delete(Account).where(Account.id == account.id))
|
||||
db_session_with_containers.execute(delete(Tenant).where(Tenant.id == tenant.id))
|
||||
db_session_with_containers.commit()
|
||||
|
||||
|
||||
@ -93,14 +94,14 @@ def app_model(
|
||||
)
|
||||
from models.workflow import Workflow
|
||||
|
||||
db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app.id).delete()
|
||||
db_session_with_containers.query(WorkflowSchedulePlan).filter_by(app_id=app.id).delete()
|
||||
db_session_with_containers.query(WorkflowWebhookTrigger).filter_by(app_id=app.id).delete()
|
||||
db_session_with_containers.query(WorkflowPluginTrigger).filter_by(app_id=app.id).delete()
|
||||
db_session_with_containers.query(AppTrigger).filter_by(app_id=app.id).delete()
|
||||
db_session_with_containers.query(TriggerSubscription).filter_by(tenant_id=tenant.id).delete()
|
||||
db_session_with_containers.query(Workflow).filter_by(app_id=app.id).delete()
|
||||
db_session_with_containers.query(App).filter_by(id=app.id).delete()
|
||||
db_session_with_containers.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.app_id == app.id))
|
||||
db_session_with_containers.execute(delete(WorkflowSchedulePlan).where(WorkflowSchedulePlan.app_id == app.id))
|
||||
db_session_with_containers.execute(delete(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.app_id == app.id))
|
||||
db_session_with_containers.execute(delete(WorkflowPluginTrigger).where(WorkflowPluginTrigger.app_id == app.id))
|
||||
db_session_with_containers.execute(delete(AppTrigger).where(AppTrigger.app_id == app.id))
|
||||
db_session_with_containers.execute(delete(TriggerSubscription).where(TriggerSubscription.tenant_id == tenant.id))
|
||||
db_session_with_containers.execute(delete(Workflow).where(Workflow.app_id == app.id))
|
||||
db_session_with_containers.execute(delete(App).where(App.id == app.id))
|
||||
db_session_with_containers.commit()
|
||||
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ import pytest
|
||||
from flask import Flask, Response
|
||||
from flask.testing import FlaskClient
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
@ -227,7 +228,9 @@ def test_webhook_trigger_creates_trigger_log(
|
||||
assert response.status_code == 200
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all()
|
||||
logs = db_session_with_containers.scalars(
|
||||
select(WorkflowTriggerLog).where(WorkflowTriggerLog.app_id == app_model.id)
|
||||
).all()
|
||||
assert logs, "Webhook trigger should create trigger log"
|
||||
|
||||
|
||||
@ -611,7 +614,9 @@ def test_schedule_trigger_creates_trigger_log(
|
||||
|
||||
# Verify WorkflowTriggerLog was created
|
||||
db_session_with_containers.expire_all()
|
||||
logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all()
|
||||
logs = db_session_with_containers.scalars(
|
||||
select(WorkflowTriggerLog).where(WorkflowTriggerLog.app_id == app_model.id)
|
||||
).all()
|
||||
assert logs, "Schedule trigger should create WorkflowTriggerLog"
|
||||
assert logs[0].trigger_type == AppTriggerType.TRIGGER_SCHEDULE
|
||||
assert logs[0].root_node_id == schedule_node_id
|
||||
@ -786,11 +791,12 @@ def test_plugin_trigger_full_chain_with_db_verification(
|
||||
|
||||
# Verify database records exist
|
||||
db_session_with_containers.expire_all()
|
||||
plugin_triggers = (
|
||||
db_session_with_containers.query(WorkflowPluginTrigger)
|
||||
.filter_by(app_id=app_model.id, node_id=plugin_node_id)
|
||||
.all()
|
||||
)
|
||||
plugin_triggers = db_session_with_containers.scalars(
|
||||
select(WorkflowPluginTrigger).where(
|
||||
WorkflowPluginTrigger.app_id == app_model.id,
|
||||
WorkflowPluginTrigger.node_id == plugin_node_id,
|
||||
)
|
||||
).all()
|
||||
assert plugin_triggers, "WorkflowPluginTrigger record should exist"
|
||||
assert plugin_triggers[0].provider_id == provider_id
|
||||
assert plugin_triggers[0].event_name == "test_event"
|
||||
|
||||
Reference in New Issue
Block a user