diff --git a/api/configs/app_config.py b/api/configs/app_config.py index 831f0a49e0..4f77b25240 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Any +from typing import Any, override from pydantic.fields import FieldInfo from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict, TomlConfigSettingsSource @@ -25,6 +25,7 @@ class RemoteSettingsSourceFactory(PydanticBaseSettingsSource): def __init__(self, settings_cls: type[BaseSettings]): super().__init__(settings_cls) + @override def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: raise NotImplementedError @@ -90,6 +91,7 @@ class DifyConfig( # Thanks for your concentration and consideration. @classmethod + @override def settings_customise_sources( cls, settings_cls: type[BaseSettings], diff --git a/api/configs/remote_settings_sources/apollo/__init__.py b/api/configs/remote_settings_sources/apollo/__init__.py index 55c14ead56..d017b86ad5 100644 --- a/api/configs/remote_settings_sources/apollo/__init__.py +++ b/api/configs/remote_settings_sources/apollo/__init__.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any +from typing import Any, override from pydantic import Field from pydantic.fields import FieldInfo @@ -48,6 +48,7 @@ class ApolloSettingsSource(RemoteSettingsSource): self.namespace = configs["APOLLO_NAMESPACE"] self.remote_configs = self.client.get_all_dicts(self.namespace) + @override def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: if not isinstance(self.remote_configs, dict): raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}") diff --git a/api/configs/remote_settings_sources/nacos/__init__.py b/api/configs/remote_settings_sources/nacos/__init__.py index f3e6306753..ddef8a5f49 100644 --- a/api/configs/remote_settings_sources/nacos/__init__.py +++ b/api/configs/remote_settings_sources/nacos/__init__.py @@ -1,7 +1,7 @@ import logging import os from collections.abc import Mapping -from typing import Any +from typing import Any, override from pydantic.fields import FieldInfo @@ -41,6 +41,7 @@ class NacosSettingsSource(RemoteSettingsSource): except Exception as e: raise RuntimeError(f"Failed to parse config: {e}") + @override def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: field_value = self.remote_configs.get(field_name) if field_value is None: diff --git a/api/context/execution_context.py b/api/context/execution_context.py index e687dfc4b1..e768297b2f 100644 --- a/api/context/execution_context.py +++ b/api/context/execution_context.py @@ -10,7 +10,7 @@ import threading from abc import ABC, abstractmethod from collections.abc import Callable, Generator from contextlib import AbstractContextManager, contextmanager -from typing import Any, Protocol, final, runtime_checkable +from typing import Any, Protocol, final, runtime_checkable, override from pydantic import BaseModel @@ -133,10 +133,12 @@ class NullAppContext(AppContext): self._config = config or {} self._extensions: dict[str, Any] = {} + @override def get_config(self, key: str, default: Any = None) -> Any: """Get configuration value by key.""" return self._config.get(key, default) + @override def get_extension(self, name: str) -> Any: """Get extension by name.""" return self._extensions.get(name) @@ -146,6 +148,7 @@ class NullAppContext(AppContext): self._extensions[name] = extension @contextmanager + @override def enter(self) -> Generator[None, None, None]: """Enter null context (no-op).""" yield diff --git a/api/context/flask_app_context.py b/api/context/flask_app_context.py index eddd6448d8..1201bad041 100644 --- a/api/context/flask_app_context.py +++ b/api/context/flask_app_context.py @@ -6,7 +6,7 @@ import contextvars import threading from collections.abc import Generator from contextlib import contextmanager -from typing import Any, final +from typing import Any, final, override from flask import Flask, current_app, g @@ -30,15 +30,18 @@ class FlaskAppContext(AppContext): """ self._flask_app = flask_app + @override def get_config(self, key: str, default: Any = None) -> Any: """Get configuration value from Flask app config.""" return self._flask_app.config.get(key, default) + @override def get_extension(self, name: str) -> Any: """Get Flask extension by name.""" return self._flask_app.extensions.get(name) @contextmanager + @override def enter(self) -> Generator[None, None, None]: """Enter Flask app context.""" with self._flask_app.app_context(): diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index cc763fa89c..0fcfd3aa18 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -1,7 +1,7 @@ import logging from collections.abc import Mapping from datetime import datetime -from typing import Literal +from typing import Literal, override from dateutil.parser import isoparse from flask import request @@ -79,11 +79,13 @@ def _enum_value(value): class WorkflowRunStatusField(fields.Raw): + @override def output(self, key, obj: WorkflowRun, **kwargs): return _enum_value(obj.status) class WorkflowRunOutputsField(fields.Raw): + @override def output(self, key, obj: WorkflowRun, **kwargs): status = _enum_value(obj.status) if status == WorkflowExecutionStatus.PAUSED.value: diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index a2186be100..04b46c6ebb 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -1,3 +1,5 @@ +from typing import override + import json from core.agent.cot_agent_runner import CotAgentRunner @@ -66,6 +68,7 @@ class CotChatAgentRunner(CotAgentRunner): return prompt_messages + @override def _organize_prompt_messages(self) -> list[PromptMessage]: """ Organize diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 51a30998ae..0b51cff986 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -1,3 +1,5 @@ +from typing import override + import json from core.agent.cot_agent_runner import CotAgentRunner @@ -51,6 +53,7 @@ class CotCompletionAgentRunner(CotAgentRunner): return historic_prompt + @override def _organize_prompt_messages(self) -> list[PromptMessage]: """ Organize prompt messages diff --git a/api/core/agent/strategy/plugin.py b/api/core/agent/strategy/plugin.py index a3cc798352..a06595ac16 100644 --- a/api/core/agent/strategy/plugin.py +++ b/api/core/agent/strategy/plugin.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Sequence -from typing import Any +from typing import Any, override from core.agent.entities import AgentInvokeMessage from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter @@ -23,6 +23,7 @@ class PluginAgentStrategy(BaseAgentStrategy): self.declaration = declaration self.meta_version = meta_version + @override def get_parameters(self) -> Sequence[AgentStrategyParameter]: return self.declaration.parameters @@ -34,6 +35,7 @@ class PluginAgentStrategy(BaseAgentStrategy): params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name)) return params + @override def _invoke( self, params: dict[str, Any], diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py index 7cb0c9a8d3..4f3c74deea 100644 --- a/api/core/app/apps/advanced_chat/generate_response_converter.py +++ b/api/core/app/apps/advanced_chat/generate_response_converter.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, cast +from typing import Any, cast, override from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( @@ -20,6 +20,7 @@ class AdvancedChatAppGenerateResponseConverter( AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse] ): @classmethod + @override def convert_blocking_full_response( cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse ) -> dict[str, Any]: @@ -59,6 +60,7 @@ class AdvancedChatAppGenerateResponseConverter( return response @classmethod + @override def convert_blocking_simple_response( cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse ) -> dict[str, Any]: @@ -76,6 +78,7 @@ class AdvancedChatAppGenerateResponseConverter( return response @classmethod + @override def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] ) -> Generator[dict[str, Any] | str, Any, None]: @@ -107,6 +110,7 @@ class AdvancedChatAppGenerateResponseConverter( yield response_chunk @classmethod + @override def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] ) -> Generator[dict[str, Any] | str, Any, None]: diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index 03bc0a9108..618509101a 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, cast +from typing import Any, cast, override from pydantic import JsonValue @@ -16,6 +16,7 @@ from core.app.entities.task_entities import ( class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]): @classmethod + @override def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): """ Convert blocking full response. @@ -37,6 +38,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot return response @classmethod + @override def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): """ Convert blocking simple response. @@ -54,6 +56,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot return response @classmethod + @override def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] ) -> Generator[dict[str, Any] | str, None, None]: @@ -85,6 +88,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot yield response_chunk @classmethod + @override def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] ) -> Generator[dict[str, Any] | str, None, None]: diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index 26efcbfafd..0869f0405b 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, cast +from typing import Any, cast, override from pydantic import JsonValue @@ -16,6 +16,7 @@ from core.app.entities.task_entities import ( class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]): @classmethod + @override def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): """ Convert blocking full response. @@ -37,6 +38,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl return response @classmethod + @override def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): """ Convert blocking simple response. @@ -54,6 +56,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl return response @classmethod + @override def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] ) -> Generator[dict[str, Any] | str, None, None]: @@ -85,6 +88,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl yield response_chunk @classmethod + @override def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] ) -> Generator[dict[str, Any] | str, None, None]: diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py index ad978f58e0..806575c256 100644 --- a/api/core/app/apps/completion/generate_response_converter.py +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, cast +from typing import Any, cast, override from pydantic import JsonValue @@ -16,6 +16,7 @@ from core.app.entities.task_entities import ( class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]): @classmethod + @override def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): """ Convert blocking full response. @@ -36,6 +37,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple return response @classmethod + @override def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): """ Convert blocking simple response. @@ -53,6 +55,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple return response @classmethod + @override def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] ) -> Generator[dict[str, Any] | str, None, None]: @@ -83,6 +86,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple yield response_chunk @classmethod + @override def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] ) -> Generator[dict[str, Any] | str, None, None]: diff --git a/api/core/app/apps/draft_variable_saver.py b/api/core/app/apps/draft_variable_saver.py index 24018012c5..0048989e79 100644 --- a/api/core/app/apps/draft_variable_saver.py +++ b/api/core/app/apps/draft_variable_saver.py @@ -2,7 +2,7 @@ from __future__ import annotations import abc from collections.abc import Mapping -from typing import Any, Protocol +from typing import Any, Protocol, override from graphon.enums import NodeType @@ -29,5 +29,6 @@ class DraftVariableSaverFactory(Protocol): class NoopDraftVariableSaver(DraftVariableSaver): + @override def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: return None diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py index 67fc016cba..0b97809bf3 100644 --- a/api/core/app/apps/message_based_app_queue_manager.py +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -1,3 +1,5 @@ +from typing import override + from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom @@ -21,6 +23,7 @@ class MessageBasedAppQueueManager(AppQueueManager): self._app_mode = app_mode self._message_id = str(message_id) + @override def _publish(self, event: AppQueueEvent, pub_from: PublishFrom): """ Publish event to queue diff --git a/api/core/app/apps/pipeline/pipeline_queue_manager.py b/api/core/app/apps/pipeline/pipeline_queue_manager.py index 151b50f238..c34b51c98c 100644 --- a/api/core/app/apps/pipeline/pipeline_queue_manager.py +++ b/api/core/app/apps/pipeline/pipeline_queue_manager.py @@ -1,3 +1,5 @@ +from typing import override + from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom @@ -19,6 +21,7 @@ class PipelineQueueManager(AppQueueManager): self._app_mode = app_mode + @override def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: """ Publish event to queue diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py index 9985e2d275..fcdd1465d4 100644 --- a/api/core/app/apps/workflow/app_queue_manager.py +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -1,3 +1,5 @@ +from typing import override + from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom @@ -19,6 +21,7 @@ class WorkflowAppQueueManager(AppQueueManager): self._app_mode = app_mode + @override def _publish(self, event: AppQueueEvent, pub_from: PublishFrom): """ Publish event to queue diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index 4037388798..93c876c0c4 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, cast +from typing import Any, cast, override from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( @@ -18,6 +18,7 @@ class WorkflowAppGenerateResponseConverter( AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse] ): @classmethod + @override def convert_blocking_full_response( cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse ) -> dict[str, Any]: @@ -29,6 +30,7 @@ class WorkflowAppGenerateResponseConverter( return dict(blocking_response.model_dump()) @classmethod + @override def convert_blocking_simple_response( cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse ) -> dict[str, Any]: @@ -40,6 +42,7 @@ class WorkflowAppGenerateResponseConverter( return cls.convert_blocking_full_response(blocking_response) @classmethod + @override def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] ) -> Generator[dict[str, Any] | str, None, None]: @@ -69,6 +72,7 @@ class WorkflowAppGenerateResponseConverter( yield response_chunk @classmethod + @override def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] ) -> Generator[dict[str, Any] | str, None, None]: diff --git a/api/core/app/file_access/controller.py b/api/core/app/file_access/controller.py index 300c187083..7a4e62a928 100644 --- a/api/core/app/file_access/controller.py +++ b/api/core/app/file_access/controller.py @@ -1,3 +1,5 @@ +from typing import override + from __future__ import annotations from collections.abc import Callable @@ -30,9 +32,11 @@ class DatabaseFileAccessController(FileAccessControllerProtocol): ) -> None: self._scope_getter = scope_getter + @override def current_scope(self) -> FileAccessScope | None: return self._scope_getter() + @override def apply_upload_file_filters( self, stmt: Select[tuple[UploadFile]], @@ -52,6 +56,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol): UploadFile.created_by == resolved_scope.user_id, ) + @override def apply_tool_file_filters( self, stmt: Select[tuple[ToolFile]], @@ -68,6 +73,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol): return scoped_stmt.where(ToolFile.user_id == resolved_scope.user_id) + @override def get_upload_file( self, *, @@ -85,6 +91,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol): ) return session.scalar(stmt) + @override def get_tool_file( self, *, diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py index d5e6b04a4a..765c22edcf 100644 --- a/api/core/app/layers/conversation_variable_persist_layer.py +++ b/api/core/app/layers/conversation_variable_persist_layer.py @@ -7,6 +7,8 @@ core, listens to those generic events, and persists only the `conversation.*` scope updates that matter to chat applications. """ +from typing import override + import logging from core.workflow.system_variables import SystemVariableKey, get_system_text @@ -23,9 +25,11 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer): super().__init__() self._conversation_variable_updater = conversation_variable_updater + @override def on_graph_start(self) -> None: pass + @override def on_event(self, event: GraphEngineEvent) -> None: if not isinstance(event, NodeRunVariableUpdatedEvent): return @@ -44,5 +48,6 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer): self._conversation_variable_updater.update(conversation_id=conversation_id, variable=event.variable) + @override def on_graph_end(self, error: Exception | None) -> None: pass diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index 9811f9f830..d651721899 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Annotated, Literal, Self +from typing import Annotated, Literal, Self, override from pydantic import BaseModel, Field from sqlalchemy import Engine @@ -83,6 +83,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer): def _get_repo(self) -> APIWorkflowRunRepository: return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker) + @override def on_graph_start(self) -> None: """ Called when graph execution starts. @@ -92,6 +93,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer): """ pass + @override def on_event(self, event: GraphEngineEvent) -> None: """ Called for every event emitted by the engine. @@ -132,6 +134,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer): pause_reasons=event.reasons, ) + @override def on_graph_end(self, error: Exception | None) -> None: """ Called when graph execution ends. diff --git a/api/core/app/layers/suspend_layer.py b/api/core/app/layers/suspend_layer.py index 1a79a9f843..3e28303a7d 100644 --- a/api/core/app/layers/suspend_layer.py +++ b/api/core/app/layers/suspend_layer.py @@ -1,3 +1,5 @@ +from typing import override + from graphon.graph_engine.layers import GraphEngineLayer from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent @@ -9,9 +11,11 @@ class SuspendLayer(GraphEngineLayer): super().__init__() self._paused = False + @override def on_graph_start(self): self._paused = False + @override def on_event(self, event: GraphEngineEvent): """ Handle the paused event, stash runtime state into storage and wait for resume. @@ -19,6 +23,7 @@ class SuspendLayer(GraphEngineLayer): if isinstance(event, GraphRunPausedEvent): self._paused = True + @override def on_graph_end(self, error: Exception | None): """ """ self._paused = False diff --git a/api/core/app/layers/timeslice_layer.py b/api/core/app/layers/timeslice_layer.py index bb9fc1b6fa..094c21944d 100644 --- a/api/core/app/layers/timeslice_layer.py +++ b/api/core/app/layers/timeslice_layer.py @@ -1,6 +1,6 @@ import logging import uuid -from typing import ClassVar +from typing import ClassVar, override from apscheduler.schedulers.background import BackgroundScheduler # type: ignore @@ -63,6 +63,7 @@ class TimeSliceLayer(GraphEngineLayer): except Exception: logger.exception("scheduler error during check if the workflow need to be suspended") + @override def on_graph_start(self): """ Start timer to check if the workflow need to be suspended. @@ -78,9 +79,11 @@ class TimeSliceLayer(GraphEngineLayer): id=self.schedule_id, ) + @override def on_event(self, event: GraphEngineEvent): pass + @override def on_graph_end(self, error: Exception | None) -> None: self.stopped = True # remove the scheduler diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index b60fe82ffe..65b8af6706 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ b/api/core/app/layers/trigger_post_layer.py @@ -1,6 +1,6 @@ import logging from datetime import UTC, datetime -from typing import Any, ClassVar +from typing import Any, ClassVar, override from pydantic import TypeAdapter @@ -37,9 +37,11 @@ class TriggerPostLayer(GraphEngineLayer): self.start_time = start_time self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity + @override def on_graph_start(self): pass + @override def on_event(self, event: GraphEngineEvent): """ Update trigger log with success or failure. @@ -82,5 +84,6 @@ class TriggerPostLayer(GraphEngineLayer): repo.update(trigger_log) session.commit() + @override def on_graph_end(self, error: Exception | None) -> None: pass diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py index 3a6f9d575a..84795bea44 100644 --- a/api/core/app/workflow/file_runtime.py +++ b/api/core/app/workflow/file_runtime.py @@ -7,7 +7,7 @@ import os import time import urllib.parse from collections.abc import Generator -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, override from configs import dify_config from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol @@ -40,15 +40,19 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): self._file_access_controller = file_access_controller @property + @override def multimodal_send_format(self) -> str: return dify_config.MULTIMODAL_SEND_FORMAT + @override def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects) + @override def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: return storage.load(path, stream=stream) + @override def load_file_bytes(self, *, file: File) -> bytes: storage_key = self._resolve_storage_key(file=file) data = storage.load(storage_key, stream=False) @@ -56,6 +60,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): raise ValueError(f"file {storage_key} is not a bytes object") return data + @override def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: if file.transfer_method == FileTransferMethod.REMOTE_URL: return file.remote_url @@ -86,6 +91,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): ) return None + @override def resolve_upload_file_url( self, *, @@ -101,10 +107,12 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): query["as_attachment"] = "true" return f"{url}?{urllib.parse.urlencode(query)}" + @override def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: self._assert_tool_file_access(tool_file_id=tool_file_id) return sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external) + @override def verify_preview_signature( self, *, diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index d521304615..1cdbbffa5a 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -12,7 +12,7 @@ state. from collections.abc import Mapping from dataclasses import dataclass from datetime import datetime -from typing import Any, Union +from typing import Any, Union, override from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.ops.entities.trace_entity import TraceTaskName @@ -97,12 +97,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer): # ------------------------------------------------------------------ # GraphEngineLayer lifecycle # ------------------------------------------------------------------ + @override def on_graph_start(self) -> None: self._workflow_execution = None self._node_execution_cache.clear() self._node_snapshots.clear() self._node_sequence = 0 + @override def on_event(self, event: GraphEngineEvent) -> None: if isinstance(event, GraphRunStartedEvent): self._handle_graph_run_started() @@ -151,6 +153,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer): if isinstance(event, NodeRunPauseRequestedEvent): self._handle_node_pause_requested(event) + @override def on_graph_end(self, error: Exception | None) -> None: return diff --git a/api/core/datasource/local_file/local_file_plugin.py b/api/core/datasource/local_file/local_file_plugin.py index 070a89cb2f..7d20d62cbb 100644 --- a/api/core/datasource/local_file/local_file_plugin.py +++ b/api/core/datasource/local_file/local_file_plugin.py @@ -1,3 +1,5 @@ +from typing import override + from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import ( @@ -22,8 +24,10 @@ class LocalFileDatasourcePlugin(DatasourcePlugin): self.tenant_id = tenant_id self.plugin_unique_identifier = plugin_unique_identifier + @override def datasource_provider_type(self) -> str: return DatasourceProviderType.LOCAL_FILE + @override def get_icon_url(self, tenant_id: str) -> str: return self.icon diff --git a/api/core/datasource/local_file/local_file_provider.py b/api/core/datasource/local_file/local_file_provider.py index b2b6f51dd3..6b6f78b33d 100644 --- a/api/core/datasource/local_file/local_file_provider.py +++ b/api/core/datasource/local_file/local_file_provider.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, override from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.__base.datasource_runtime import DatasourceRuntime @@ -19,12 +19,14 @@ class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderContro self.plugin_unique_identifier = plugin_unique_identifier @property + @override def provider_type(self) -> DatasourceProviderType: """ returns the type of the provider """ return DatasourceProviderType.LOCAL_FILE + @override def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: """ validate the credentials of the provider diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py index ce23da1e09..2fbf575d55 100644 --- a/api/core/datasource/online_document/online_document_plugin.py +++ b/api/core/datasource/online_document/online_document_plugin.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any +from typing import Any, override from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime @@ -67,5 +67,6 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin): provider_type=provider_type, ) + @override def datasource_provider_type(self) -> str: return DatasourceProviderType.ONLINE_DOCUMENT diff --git a/api/core/datasource/online_document/online_document_provider.py b/api/core/datasource/online_document/online_document_provider.py index a128b479f4..f1f34c8ba1 100644 --- a/api/core/datasource/online_document/online_document_provider.py +++ b/api/core/datasource/online_document/online_document_provider.py @@ -1,3 +1,5 @@ +from typing import override + from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType @@ -17,6 +19,7 @@ class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderC self.plugin_unique_identifier = plugin_unique_identifier @property + @override def provider_type(self) -> DatasourceProviderType: """ returns the type of the provider diff --git a/api/core/datasource/online_drive/online_drive_plugin.py b/api/core/datasource/online_drive/online_drive_plugin.py index 64715226cc..6cdd3b0fcf 100644 --- a/api/core/datasource/online_drive/online_drive_plugin.py +++ b/api/core/datasource/online_drive/online_drive_plugin.py @@ -1,3 +1,5 @@ +from typing import override + from collections.abc import Generator from core.datasource.__base.datasource_plugin import DatasourcePlugin @@ -67,5 +69,6 @@ class OnlineDriveDatasourcePlugin(DatasourcePlugin): provider_type=provider_type, ) + @override def datasource_provider_type(self) -> str: return DatasourceProviderType.ONLINE_DRIVE diff --git a/api/core/datasource/online_drive/online_drive_provider.py b/api/core/datasource/online_drive/online_drive_provider.py index d0923ed807..d4a6942d09 100644 --- a/api/core/datasource/online_drive/online_drive_provider.py +++ b/api/core/datasource/online_drive/online_drive_provider.py @@ -1,3 +1,5 @@ +from typing import override + from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType @@ -17,6 +19,7 @@ class OnlineDriveDatasourcePluginProviderController(DatasourcePluginProviderCont self.plugin_unique_identifier = plugin_unique_identifier @property + @override def provider_type(self) -> DatasourceProviderType: """ returns the type of the provider diff --git a/api/core/datasource/website_crawl/website_crawl_plugin.py b/api/core/datasource/website_crawl/website_crawl_plugin.py index 087ac65a7a..c5c9b4c0f2 100644 --- a/api/core/datasource/website_crawl/website_crawl_plugin.py +++ b/api/core/datasource/website_crawl/website_crawl_plugin.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Mapping -from typing import Any +from typing import Any, override from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime @@ -47,5 +47,6 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): provider_type=provider_type, ) + @override def datasource_provider_type(self) -> str: return DatasourceProviderType.WEBSITE_CRAWL diff --git a/api/core/datasource/website_crawl/website_crawl_provider.py b/api/core/datasource/website_crawl/website_crawl_provider.py index 8c0f20ce2d..0dfdf3c0dd 100644 --- a/api/core/datasource/website_crawl/website_crawl_provider.py +++ b/api/core/datasource/website_crawl/website_crawl_provider.py @@ -1,3 +1,5 @@ +from typing import override + from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType @@ -21,6 +23,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon self.plugin_unique_identifier = plugin_unique_identifier @property + @override def provider_type(self) -> DatasourceProviderType: """ returns the type of the provider diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 495fd1d898..8053621bed 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -6,7 +6,7 @@ import re from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError -from typing import Any +from typing import Any, override from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from sqlalchemy import func, select @@ -1889,6 +1889,7 @@ class ProviderConfigurations(BaseModel): key = str(ModelProviderID(key)) return key in self.configurations + @override def __iter__(self): # Return an iterator of (key, value) tuples to match BaseModel's __iter__ yield from self.configurations.items() diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 8ce068cfbb..30fcbc230f 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, TypedDict +from typing import Any, TypedDict, override from sqlalchemy import select @@ -29,6 +29,7 @@ class ApiExternalDataTool(ExternalDataTool): """the unique name of external data tool""" @classmethod + @override def validate_config(cls, tenant_id: str, config: dict[str, Any]): """ Validate the incoming form config data. @@ -50,6 +51,7 @@ class ApiExternalDataTool(ExternalDataTool): if not api_based_extension: raise ValueError("api_based_extension_id is invalid") + @override def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str: """ Query the external data tool. diff --git a/api/core/helper/code_executor/javascript/javascript_code_provider.py b/api/core/helper/code_executor/javascript/javascript_code_provider.py index ae324b83a9..eb1c27dc0b 100644 --- a/api/core/helper/code_executor/javascript/javascript_code_provider.py +++ b/api/core/helper/code_executor/javascript/javascript_code_provider.py @@ -1,3 +1,5 @@ +from typing import override + from textwrap import dedent from core.helper.code_executor.code_executor import CodeLanguage @@ -6,10 +8,12 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider class JavascriptCodeProvider(CodeNodeProvider): @staticmethod + @override def get_language() -> str: return CodeLanguage.JAVASCRIPT @classmethod + @override def get_default_code(cls) -> str: return dedent( """ diff --git a/api/core/helper/code_executor/javascript/javascript_transformer.py b/api/core/helper/code_executor/javascript/javascript_transformer.py index e28f027a3a..8ff3dc7a78 100644 --- a/api/core/helper/code_executor/javascript/javascript_transformer.py +++ b/api/core/helper/code_executor/javascript/javascript_transformer.py @@ -1,3 +1,5 @@ +from typing import override + from textwrap import dedent from core.helper.code_executor.template_transformer import TemplateTransformer @@ -5,6 +7,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer class NodeJsTemplateTransformer(TemplateTransformer): @classmethod + @override def get_runner_script(cls) -> str: runner_script = dedent(f""" {cls._code_placeholder} diff --git a/api/core/helper/code_executor/jinja2/jinja2_transformer.py b/api/core/helper/code_executor/jinja2/jinja2_transformer.py index 5e4807401e..9cf5089f7b 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_transformer.py +++ b/api/core/helper/code_executor/jinja2/jinja2_transformer.py @@ -1,6 +1,6 @@ from collections.abc import Mapping from textwrap import dedent -from typing import Any +from typing import Any, override from core.helper.code_executor.template_transformer import TemplateTransformer @@ -10,6 +10,7 @@ class Jinja2TemplateTransformer(TemplateTransformer): _template_b64_placeholder: str = "{{template_b64}}" @classmethod + @override def transform_response(cls, response: str): """ Transform response to dict @@ -19,6 +20,7 @@ class Jinja2TemplateTransformer(TemplateTransformer): return {"result": cls.extract_result_str_from_response(response)} @classmethod + @override def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str: """ Override base class to use base64 encoding for template code. @@ -34,6 +36,7 @@ class Jinja2TemplateTransformer(TemplateTransformer): return script @classmethod + @override def get_runner_script(cls) -> str: runner_script = dedent(f""" import jinja2 @@ -61,6 +64,7 @@ class Jinja2TemplateTransformer(TemplateTransformer): return runner_script @classmethod + @override def get_preload_script(cls) -> str: preload_script = dedent(""" import jinja2 diff --git a/api/core/helper/code_executor/python3/python3_code_provider.py b/api/core/helper/code_executor/python3/python3_code_provider.py index 151bf0e201..bb59af0446 100644 --- a/api/core/helper/code_executor/python3/python3_code_provider.py +++ b/api/core/helper/code_executor/python3/python3_code_provider.py @@ -1,3 +1,5 @@ +from typing import override + from textwrap import dedent from core.helper.code_executor.code_executor import CodeLanguage @@ -6,10 +8,12 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider class Python3CodeProvider(CodeNodeProvider): @staticmethod + @override def get_language() -> str: return CodeLanguage.PYTHON3 @classmethod + @override def get_default_code(cls) -> str: return dedent( """ diff --git a/api/core/helper/code_executor/python3/python3_transformer.py b/api/core/helper/code_executor/python3/python3_transformer.py index ee866eeb81..c1027b382b 100644 --- a/api/core/helper/code_executor/python3/python3_transformer.py +++ b/api/core/helper/code_executor/python3/python3_transformer.py @@ -1,3 +1,5 @@ +from typing import override + from textwrap import dedent from core.helper.code_executor.template_transformer import TemplateTransformer @@ -5,6 +7,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer class Python3TemplateTransformer(TemplateTransformer): @classmethod + @override def get_runner_script(cls) -> str: runner_script = dedent(f""" {cls._code_placeholder} diff --git a/api/core/helper/provider_cache.py b/api/core/helper/provider_cache.py index 9f167ca49c..6ad08dfe17 100644 --- a/api/core/helper/provider_cache.py +++ b/api/core/helper/provider_cache.py @@ -1,7 +1,7 @@ import json from abc import ABC, abstractmethod from json import JSONDecodeError -from typing import Any +from typing import Any, override from extensions.ext_redis import redis_client @@ -47,6 +47,7 @@ class SingletonProviderCredentialsCache(ProviderCredentialsCache): provider_identity=provider_identity, ) + @override def _generate_cache_key(self, **kwargs) -> str: tenant_id = kwargs["tenant_id"] provider_type = kwargs["provider_type"] @@ -61,6 +62,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache): def __init__(self, tenant_id: str, provider: str, credential_id: str): super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id) + @override def _generate_cache_key(self, **kwargs) -> str: tenant_id = kwargs["tenant_id"] provider = kwargs["provider"] diff --git a/api/core/logging/filters.py b/api/core/logging/filters.py index dee1432363..4927149bfb 100644 --- a/api/core/logging/filters.py +++ b/api/core/logging/filters.py @@ -1,5 +1,7 @@ """Logging filters for structured logging.""" +from typing import override + import contextlib import logging @@ -15,6 +17,7 @@ class TraceContextFilter(logging.Filter): Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id. """ + @override def filter(self, record: logging.LogRecord) -> bool: # Get trace context from OpenTelemetry trace_id, span_id = self._get_otel_context() @@ -54,6 +57,7 @@ class IdentityContextFilter(logging.Filter): Extracts tenant_id, user_id, and user_type from Flask-Login current_user. """ + @override def filter(self, record: logging.LogRecord) -> bool: identity = self._extract_identity() record.tenant_id = identity.get("tenant_id", "") diff --git a/api/core/logging/structured_formatter.py b/api/core/logging/structured_formatter.py index ae7be91c17..56ea748242 100644 --- a/api/core/logging/structured_formatter.py +++ b/api/core/logging/structured_formatter.py @@ -3,7 +3,7 @@ import logging import traceback from datetime import UTC, datetime -from typing import Any, NotRequired, TypedDict +from typing import Any, NotRequired, TypedDict, override import orjson @@ -58,6 +58,7 @@ class StructuredJSONFormatter(logging.Formatter): super().__init__() self._service_name = service_name or dify_config.APPLICATION_NAME + @override def format(self, record: logging.LogRecord) -> str: log_dict = self._build_log_dict(record) try: diff --git a/api/core/mcp/auth_client.py b/api/core/mcp/auth_client.py index 173913196e..64596969ef 100644 --- a/api/core/mcp/auth_client.py +++ b/api/core/mcp/auth_client.py @@ -7,7 +7,7 @@ authentication failures and retries operations after refreshing tokens. import logging from collections.abc import Callable -from typing import Any +from typing import Any, override from sqlalchemy.orm import Session @@ -159,6 +159,7 @@ class MCPClientWithAuthRetry(MCPClient): # Reset retry flag after operation completes self._has_retried = False + @override def __enter__(self): """Enter the context manager with retry support.""" @@ -168,6 +169,7 @@ class MCPClientWithAuthRetry(MCPClient): return self._execute_with_retry(initialize_with_retry) + @override def list_tools(self) -> list[Tool]: """ List available tools from the MCP server with auth retry. @@ -180,6 +182,7 @@ class MCPClientWithAuthRetry(MCPClient): """ return self._execute_with_retry(super().list_tools) + @override def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult: """ Invoke a tool on the MCP server with auth retry. diff --git a/api/core/mcp/session/client_session.py b/api/core/mcp/session/client_session.py index d684fe0dd7..f91295a432 100644 --- a/api/core/mcp/session/client_session.py +++ b/api/core/mcp/session/client_session.py @@ -1,6 +1,6 @@ import queue from datetime import timedelta -from typing import Any, Protocol +from typing import Any, Protocol, override from pydantic import AnyUrl, TypeAdapter @@ -159,6 +159,7 @@ class ClientSession( types.EmptyResult, ) + @override def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None): """Send a progress notification.""" self.send_notification( @@ -326,6 +327,7 @@ class ClientSession( ) ) + @override def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]): ctx = RequestContext[ClientSession, Any]( request_id=responder.request_id, @@ -351,6 +353,7 @@ class ClientSession( with responder: return responder.respond(types.ClientResult(root=types.EmptyResult())) + @override def _handle_incoming( self, req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, @@ -358,6 +361,7 @@ class ClientSession( """Handle incoming messages by forwarding to the message handler.""" self._message_handler(req) + @override def _received_notification(self, notification: types.ServerNotification): """Handle notifications from the server.""" # Process specific notification types diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index 28165592fc..ec9a1906f8 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, override from pydantic import BaseModel, Field from sqlalchemy import select @@ -25,6 +25,7 @@ class ApiModeration(Moderation): name: str = "api" @classmethod + @override def validate_config(cls, tenant_id: str, config: dict[str, Any]): """ Validate the incoming form config data. @@ -43,6 +44,7 @@ class ApiModeration(Moderation): if not extension: raise ValueError("API-based Extension not found. Please check it again.") + @override def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" @@ -59,6 +61,7 @@ class ApiModeration(Moderation): flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response ) + @override def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index 7d80d3a53c..339574556d 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Any +from typing import Any, override from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult @@ -8,6 +8,7 @@ class KeywordsModeration(Moderation): name: str = "keywords" @classmethod + @override def validate_config(cls, tenant_id: str, config: dict[str, Any]): """ Validate the incoming form config data. @@ -28,6 +29,7 @@ class KeywordsModeration(Moderation): if len(keywords_row_len) > 100: raise ValueError("the number of rows for the keywords must be less than 100") + @override def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" @@ -49,6 +51,7 @@ class KeywordsModeration(Moderation): flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response ) + @override def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 6e6e94502c..4b7a08eb27 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, override from core.model_manager import ModelManager from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult @@ -9,6 +9,7 @@ class OpenAIModeration(Moderation): name: str = "openai_moderation" @classmethod + @override def validate_config(cls, tenant_id: str, config: dict[str, Any]): """ Validate the incoming form config data. @@ -19,6 +20,7 @@ class OpenAIModeration(Moderation): """ cls._validate_inputs_and_outputs_config(config, True) + @override def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" @@ -36,6 +38,7 @@ class OpenAIModeration(Moderation): flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response ) + @override def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" diff --git a/api/core/plugin/impl/exc.py b/api/core/plugin/impl/exc.py index 4cabdc1732..21402f4e7a 100644 --- a/api/core/plugin/impl/exc.py +++ b/api/core/plugin/impl/exc.py @@ -1,3 +1,5 @@ +from typing import override + from collections.abc import Mapping from pydantic import TypeAdapter @@ -11,6 +13,7 @@ class PluginDaemonError(Exception): def __init__(self, description: str): self.description = description + @override def __str__(self) -> str: # returns the class name and description return f"req_id: {get_request_id()} {self.__class__.__name__}: {self.description}" diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py index 62573ba2f5..b151055e37 100644 --- a/api/core/plugin/impl/model_runtime.py +++ b/api/core/plugin/impl/model_runtime.py @@ -4,7 +4,7 @@ import hashlib import logging from collections.abc import Generator, Iterable, Sequence from threading import Lock -from typing import IO, Any, Literal, cast, overload +from typing import IO, Any, Literal, cast, overload, override from pydantic import ValidationError from redis import RedisError @@ -118,6 +118,7 @@ class PluginModelRuntime(ModelRuntime): self._provider_entities = None self._provider_entities_lock = Lock() + @override def fetch_model_providers(self) -> Sequence[ProviderEntity]: if self._provider_entities is not None: return self._provider_entities @@ -130,6 +131,7 @@ class PluginModelRuntime(ModelRuntime): return self._provider_entities + @override def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: provider_schema = self._get_provider_schema(provider) @@ -172,6 +174,7 @@ class PluginModelRuntime(ModelRuntime): mime_type = image_mime_types.get(extension, "image/png") return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type + @override def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None: plugin_id, provider_name = self._split_provider(provider) self.client.validate_provider_credentials( @@ -182,6 +185,7 @@ class PluginModelRuntime(ModelRuntime): credentials=credentials, ) + @override def validate_model_credentials( self, *, @@ -201,6 +205,7 @@ class PluginModelRuntime(ModelRuntime): credentials=credentials, ) + @override def get_model_schema( self, *, @@ -267,6 +272,7 @@ class PluginModelRuntime(ModelRuntime): return schema @overload + @override def invoke_llm( self, *, @@ -281,6 +287,7 @@ class PluginModelRuntime(ModelRuntime): ) -> LLMResult: ... @overload + @override def invoke_llm( self, *, @@ -294,6 +301,7 @@ class PluginModelRuntime(ModelRuntime): stream: Literal[True], ) -> Generator[LLMResultChunk, None, None]: ... + @override def invoke_llm( self, *, @@ -330,6 +338,7 @@ class PluginModelRuntime(ModelRuntime): ) @overload + @override def invoke_llm_with_structured_output( self, *, @@ -344,6 +353,7 @@ class PluginModelRuntime(ModelRuntime): ) -> LLMResultWithStructuredOutput: ... @overload + @override def invoke_llm_with_structured_output( self, *, @@ -357,6 +367,7 @@ class PluginModelRuntime(ModelRuntime): stream: Literal[True], ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... + @override def invoke_llm_with_structured_output( self, *, @@ -396,6 +407,7 @@ class PluginModelRuntime(ModelRuntime): stream=stream, ) + @override def get_llm_num_tokens( self, *, @@ -422,6 +434,7 @@ class PluginModelRuntime(ModelRuntime): tools=list(tools) if tools else None, ) + @override def invoke_text_embedding( self, *, @@ -443,6 +456,7 @@ class PluginModelRuntime(ModelRuntime): input_type=input_type, ) + @override def invoke_multimodal_embedding( self, *, @@ -464,6 +478,7 @@ class PluginModelRuntime(ModelRuntime): input_type=input_type, ) + @override def get_text_embedding_num_tokens( self, *, @@ -483,6 +498,7 @@ class PluginModelRuntime(ModelRuntime): texts=texts, ) + @override def invoke_rerank( self, *, @@ -508,6 +524,7 @@ class PluginModelRuntime(ModelRuntime): top_n=top_n, ) + @override def invoke_multimodal_rerank( self, *, @@ -533,6 +550,7 @@ class PluginModelRuntime(ModelRuntime): top_n=top_n, ) + @override def invoke_tts( self, *, @@ -554,6 +572,7 @@ class PluginModelRuntime(ModelRuntime): voice=voice, ) + @override def get_tts_model_voices( self, *, @@ -573,6 +592,7 @@ class PluginModelRuntime(ModelRuntime): language=language, ) + @override def invoke_speech_to_text( self, *, @@ -592,6 +612,7 @@ class PluginModelRuntime(ModelRuntime): file=file, ) + @override def invoke_moderation( self, *, diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index 392af351b6..f6ece01d99 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Any, TypedDict +from typing import Any, TypedDict, override import orjson from pydantic import BaseModel @@ -29,6 +29,7 @@ class Jieba(BaseKeyword): super().__init__(dataset) self._config = KeywordTableConfig() + @override def create(self, texts: list[Document], **kwargs) -> BaseKeyword: lock_name = f"keyword_indexing_lock_{self.dataset.id}" with redis_client.lock(lock_name, timeout=600): @@ -48,6 +49,7 @@ class Jieba(BaseKeyword): return self + @override def add_texts(self, texts: list[Document], **kwargs): lock_name = f"keyword_indexing_lock_{self.dataset.id}" with redis_client.lock(lock_name, timeout=600): @@ -72,12 +74,14 @@ class Jieba(BaseKeyword): self._save_dataset_keyword_table(keyword_table) + @override def text_exists(self, id: str) -> bool: keyword_table = self._get_dataset_keyword_table() if keyword_table is None: return False return id in set.union(*keyword_table.values()) + @override def delete_by_ids(self, ids: list[str]): lock_name = f"keyword_indexing_lock_{self.dataset.id}" with redis_client.lock(lock_name, timeout=600): @@ -87,6 +91,7 @@ class Jieba(BaseKeyword): self._save_dataset_keyword_table(keyword_table) + @override def search(self, query: str, **kwargs: Any) -> list[Document]: keyword_table = self._get_dataset_keyword_table() @@ -122,6 +127,7 @@ class Jieba(BaseKeyword): return documents + @override def delete(self): lock_name = f"keyword_indexing_lock_{self.dataset.id}" with redis_client.lock(lock_name, timeout=600): diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 1f82f7a081..a132571dcc 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -2,7 +2,7 @@ import base64 import logging import time from abc import ABC, abstractmethod -from typing import Any +from typing import Any, override from sqlalchemy import select @@ -72,15 +72,19 @@ class _LazyEmbeddings(Embeddings): self._real = CacheEmbedding(embedding_model) return self._real + @override def embed_documents(self, texts: list[str]) -> list[list[float]]: return self._ensure().embed_documents(texts) + @override def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]: return self._ensure().embed_multimodal_documents(multimodel_documents) + @override def embed_query(self, text: str) -> list[float]: return self._ensure().embed_query(text) + @override def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]: return self._ensure().embed_multimodal_query(multimodel_document)