refactor: add missing @override decorators to method overrides

Added @override decorator to methods that override parent class methods
but were missing the decorator, as flagged by pyright's
missing-override-decorator check.

Fixes #36406
This commit is contained in:
EvanYao826
2026-05-22 11:02:13 +08:00
parent b95e6f6a7a
commit eb5804e55b
53 changed files with 219 additions and 34 deletions

View File

@ -1,6 +1,6 @@
import logging
from pathlib import Path
from typing import Any
from typing import Any, override
from pydantic.fields import FieldInfo
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict, TomlConfigSettingsSource
@ -25,6 +25,7 @@ class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
def __init__(self, settings_cls: type[BaseSettings]):
super().__init__(settings_cls)
@override
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
raise NotImplementedError
@ -90,6 +91,7 @@ class DifyConfig(
# Thanks for your concentration and consideration.
@classmethod
@override
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any
from typing import Any, override
from pydantic import Field
from pydantic.fields import FieldInfo
@ -48,6 +48,7 @@ class ApolloSettingsSource(RemoteSettingsSource):
self.namespace = configs["APOLLO_NAMESPACE"]
self.remote_configs = self.client.get_all_dicts(self.namespace)
@override
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
if not isinstance(self.remote_configs, dict):
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")

View File

@ -1,7 +1,7 @@
import logging
import os
from collections.abc import Mapping
from typing import Any
from typing import Any, override
from pydantic.fields import FieldInfo
@ -41,6 +41,7 @@ class NacosSettingsSource(RemoteSettingsSource):
except Exception as e:
raise RuntimeError(f"Failed to parse config: {e}")
@override
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
field_value = self.remote_configs.get(field_name)
if field_value is None:

View File

@ -10,7 +10,7 @@ import threading
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Protocol, final, runtime_checkable
from typing import Any, Protocol, final, runtime_checkable, override
from pydantic import BaseModel
@ -133,10 +133,12 @@ class NullAppContext(AppContext):
self._config = config or {}
self._extensions: dict[str, Any] = {}
@override
def get_config(self, key: str, default: Any = None) -> Any:
"""Get configuration value by key."""
return self._config.get(key, default)
@override
def get_extension(self, name: str) -> Any:
"""Get extension by name."""
return self._extensions.get(name)
@ -146,6 +148,7 @@ class NullAppContext(AppContext):
self._extensions[name] = extension
@contextmanager
@override
def enter(self) -> Generator[None, None, None]:
"""Enter null context (no-op)."""
yield

View File

@ -6,7 +6,7 @@ import contextvars
import threading
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, final
from typing import Any, final, override
from flask import Flask, current_app, g
@ -30,15 +30,18 @@ class FlaskAppContext(AppContext):
"""
self._flask_app = flask_app
@override
def get_config(self, key: str, default: Any = None) -> Any:
"""Get configuration value from Flask app config."""
return self._flask_app.config.get(key, default)
@override
def get_extension(self, name: str) -> Any:
"""Get Flask extension by name."""
return self._flask_app.extensions.get(name)
@contextmanager
@override
def enter(self) -> Generator[None, None, None]:
"""Enter Flask app context."""
with self._flask_app.app_context():

View File

@ -1,7 +1,7 @@
import logging
from collections.abc import Mapping
from datetime import datetime
from typing import Literal
from typing import Literal, override
from dateutil.parser import isoparse
from flask import request
@ -79,11 +79,13 @@ def _enum_value(value):
class WorkflowRunStatusField(fields.Raw):
@override
def output(self, key, obj: WorkflowRun, **kwargs):
return _enum_value(obj.status)
class WorkflowRunOutputsField(fields.Raw):
@override
def output(self, key, obj: WorkflowRun, **kwargs):
status = _enum_value(obj.status)
if status == WorkflowExecutionStatus.PAUSED.value:

View File

@ -1,3 +1,5 @@
from typing import override
import json
from core.agent.cot_agent_runner import CotAgentRunner
@ -66,6 +68,7 @@ class CotChatAgentRunner(CotAgentRunner):
return prompt_messages
@override
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize

View File

@ -1,3 +1,5 @@
from typing import override
import json
from core.agent.cot_agent_runner import CotAgentRunner
@ -51,6 +53,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
return historic_prompt
@override
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize prompt messages

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Sequence
from typing import Any
from typing import Any, override
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
@ -23,6 +23,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
self.declaration = declaration
self.meta_version = meta_version
@override
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
return self.declaration.parameters
@ -34,6 +35,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name))
return params
@override
def _invoke(
self,
params: dict[str, Any],

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any, cast
from typing import Any, cast, override
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@ -20,6 +20,7 @@ class AdvancedChatAppGenerateResponseConverter(
AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
):
@classmethod
@override
def convert_blocking_full_response(
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
) -> dict[str, Any]:
@ -59,6 +60,7 @@ class AdvancedChatAppGenerateResponseConverter(
return response
@classmethod
@override
def convert_blocking_simple_response(
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
) -> dict[str, Any]:
@ -76,6 +78,7 @@ class AdvancedChatAppGenerateResponseConverter(
return response
@classmethod
@override
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, Any, None]:
@ -107,6 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(
yield response_chunk
@classmethod
@override
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, Any, None]:

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any, cast
from typing import Any, cast, override
from pydantic import JsonValue
@ -16,6 +16,7 @@ from core.app.entities.task_entities import (
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
@classmethod
@override
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking full response.
@ -37,6 +38,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
return response
@classmethod
@override
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking simple response.
@ -54,6 +56,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
return response
@classmethod
@override
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:
@ -85,6 +88,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
yield response_chunk
@classmethod
@override
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any, cast
from typing import Any, cast, override
from pydantic import JsonValue
@ -16,6 +16,7 @@ from core.app.entities.task_entities import (
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
@classmethod
@override
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking full response.
@ -37,6 +38,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
return response
@classmethod
@override
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking simple response.
@ -54,6 +56,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
return response
@classmethod
@override
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:
@ -85,6 +88,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
yield response_chunk
@classmethod
@override
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any, cast
from typing import Any, cast, override
from pydantic import JsonValue
@ -16,6 +16,7 @@ from core.app.entities.task_entities import (
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]):
@classmethod
@override
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):
"""
Convert blocking full response.
@ -36,6 +37,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
return response
@classmethod
@override
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):
"""
Convert blocking simple response.
@ -53,6 +55,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
return response
@classmethod
@override
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:
@ -83,6 +86,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
yield response_chunk
@classmethod
@override
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import abc
from collections.abc import Mapping
from typing import Any, Protocol
from typing import Any, Protocol, override
from graphon.enums import NodeType
@ -29,5 +29,6 @@ class DraftVariableSaverFactory(Protocol):
class NoopDraftVariableSaver(DraftVariableSaver):
@override
def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None:
return None

View File

@ -1,3 +1,5 @@
from typing import override
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
@ -21,6 +23,7 @@ class MessageBasedAppQueueManager(AppQueueManager):
self._app_mode = app_mode
self._message_id = str(message_id)
@override
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue

View File

@ -1,3 +1,5 @@
from typing import override
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
@ -19,6 +21,7 @@ class PipelineQueueManager(AppQueueManager):
self._app_mode = app_mode
@override
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
Publish event to queue

View File

@ -1,3 +1,5 @@
from typing import override
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
@ -19,6 +21,7 @@ class WorkflowAppQueueManager(AppQueueManager):
self._app_mode = app_mode
@override
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any, cast
from typing import Any, cast, override
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@ -18,6 +18,7 @@ class WorkflowAppGenerateResponseConverter(
AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse]
):
@classmethod
@override
def convert_blocking_full_response(
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
) -> dict[str, Any]:
@ -29,6 +30,7 @@ class WorkflowAppGenerateResponseConverter(
return dict(blocking_response.model_dump())
@classmethod
@override
def convert_blocking_simple_response(
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
) -> dict[str, Any]:
@ -40,6 +42,7 @@ class WorkflowAppGenerateResponseConverter(
return cls.convert_blocking_full_response(blocking_response)
@classmethod
@override
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:
@ -69,6 +72,7 @@ class WorkflowAppGenerateResponseConverter(
yield response_chunk
@classmethod
@override
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:

View File

@ -1,3 +1,5 @@
from typing import override
from __future__ import annotations
from collections.abc import Callable
@ -30,9 +32,11 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
) -> None:
self._scope_getter = scope_getter
@override
def current_scope(self) -> FileAccessScope | None:
return self._scope_getter()
@override
def apply_upload_file_filters(
self,
stmt: Select[tuple[UploadFile]],
@ -52,6 +56,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
UploadFile.created_by == resolved_scope.user_id,
)
@override
def apply_tool_file_filters(
self,
stmt: Select[tuple[ToolFile]],
@ -68,6 +73,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
return scoped_stmt.where(ToolFile.user_id == resolved_scope.user_id)
@override
def get_upload_file(
self,
*,
@ -85,6 +91,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
)
return session.scalar(stmt)
@override
def get_tool_file(
self,
*,

View File

@ -7,6 +7,8 @@ core, listens to those generic events, and persists only the `conversation.*`
scope updates that matter to chat applications.
"""
from typing import override
import logging
from core.workflow.system_variables import SystemVariableKey, get_system_text
@ -23,9 +25,11 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
super().__init__()
self._conversation_variable_updater = conversation_variable_updater
@override
def on_graph_start(self) -> None:
pass
@override
def on_event(self, event: GraphEngineEvent) -> None:
if not isinstance(event, NodeRunVariableUpdatedEvent):
return
@ -44,5 +48,6 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
self._conversation_variable_updater.update(conversation_id=conversation_id, variable=event.variable)
@override
def on_graph_end(self, error: Exception | None) -> None:
pass

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Annotated, Literal, Self
from typing import Annotated, Literal, Self, override
from pydantic import BaseModel, Field
from sqlalchemy import Engine
@ -83,6 +83,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
def _get_repo(self) -> APIWorkflowRunRepository:
return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)
@override
def on_graph_start(self) -> None:
"""
Called when graph execution starts.
@ -92,6 +93,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
"""
pass
@override
def on_event(self, event: GraphEngineEvent) -> None:
"""
Called for every event emitted by the engine.
@ -132,6 +134,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
pause_reasons=event.reasons,
)
@override
def on_graph_end(self, error: Exception | None) -> None:
"""
Called when graph execution ends.

View File

@ -1,3 +1,5 @@
from typing import override
from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent
@ -9,9 +11,11 @@ class SuspendLayer(GraphEngineLayer):
super().__init__()
self._paused = False
@override
def on_graph_start(self):
self._paused = False
@override
def on_event(self, event: GraphEngineEvent):
"""
Handle the paused event, stash runtime state into storage and wait for resume.
@ -19,6 +23,7 @@ class SuspendLayer(GraphEngineLayer):
if isinstance(event, GraphRunPausedEvent):
self._paused = True
@override
def on_graph_end(self, error: Exception | None):
""" """
self._paused = False

View File

@ -1,6 +1,6 @@
import logging
import uuid
from typing import ClassVar
from typing import ClassVar, override
from apscheduler.schedulers.background import BackgroundScheduler # type: ignore
@ -63,6 +63,7 @@ class TimeSliceLayer(GraphEngineLayer):
except Exception:
logger.exception("scheduler error during check if the workflow need to be suspended")
@override
def on_graph_start(self):
"""
Start timer to check if the workflow need to be suspended.
@ -78,9 +79,11 @@ class TimeSliceLayer(GraphEngineLayer):
id=self.schedule_id,
)
@override
def on_event(self, event: GraphEngineEvent):
pass
@override
def on_graph_end(self, error: Exception | None) -> None:
self.stopped = True
# remove the scheduler

View File

@ -1,6 +1,6 @@
import logging
from datetime import UTC, datetime
from typing import Any, ClassVar
from typing import Any, ClassVar, override
from pydantic import TypeAdapter
@ -37,9 +37,11 @@ class TriggerPostLayer(GraphEngineLayer):
self.start_time = start_time
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
@override
def on_graph_start(self):
pass
@override
def on_event(self, event: GraphEngineEvent):
"""
Update trigger log with success or failure.
@ -82,5 +84,6 @@ class TriggerPostLayer(GraphEngineLayer):
repo.update(trigger_log)
session.commit()
@override
def on_graph_end(self, error: Exception | None) -> None:
pass

View File

@ -7,7 +7,7 @@ import os
import time
import urllib.parse
from collections.abc import Generator
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, override
from configs import dify_config
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
@ -40,15 +40,19 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
self._file_access_controller = file_access_controller
@property
@override
def multimodal_send_format(self) -> str:
return dify_config.MULTIMODAL_SEND_FORMAT
@override
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
@override
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
return storage.load(path, stream=stream)
@override
def load_file_bytes(self, *, file: File) -> bytes:
storage_key = self._resolve_storage_key(file=file)
data = storage.load(storage_key, stream=False)
@ -56,6 +60,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
raise ValueError(f"file {storage_key} is not a bytes object")
return data
@override
def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None:
if file.transfer_method == FileTransferMethod.REMOTE_URL:
return file.remote_url
@ -86,6 +91,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
)
return None
@override
def resolve_upload_file_url(
self,
*,
@ -101,10 +107,12 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
query["as_attachment"] = "true"
return f"{url}?{urllib.parse.urlencode(query)}"
@override
def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
self._assert_tool_file_access(tool_file_id=tool_file_id)
return sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external)
@override
def verify_preview_signature(
self,
*,

View File

@ -12,7 +12,7 @@ state.
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Union
from typing import Any, Union, override
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.ops.entities.trace_entity import TraceTaskName
@ -97,12 +97,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
# ------------------------------------------------------------------
# GraphEngineLayer lifecycle
# ------------------------------------------------------------------
@override
def on_graph_start(self) -> None:
self._workflow_execution = None
self._node_execution_cache.clear()
self._node_snapshots.clear()
self._node_sequence = 0
@override
def on_event(self, event: GraphEngineEvent) -> None:
if isinstance(event, GraphRunStartedEvent):
self._handle_graph_run_started()
@ -151,6 +153,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
if isinstance(event, NodeRunPauseRequestedEvent):
self._handle_node_pause_requested(event)
@override
def on_graph_end(self, error: Exception | None) -> None:
return

View File

@ -1,3 +1,5 @@
from typing import override
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
@ -22,8 +24,10 @@ class LocalFileDatasourcePlugin(DatasourcePlugin):
self.tenant_id = tenant_id
self.plugin_unique_identifier = plugin_unique_identifier
@override
def datasource_provider_type(self) -> str:
return DatasourceProviderType.LOCAL_FILE
@override
def get_icon_url(self, tenant_id: str) -> str:
return self.icon

View File

@ -1,4 +1,4 @@
from typing import Any
from typing import Any, override
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
@ -19,12 +19,14 @@ class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderContro
self.plugin_unique_identifier = plugin_unique_identifier
@property
@override
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.LOCAL_FILE
@override
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any
from typing import Any, override
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
@ -67,5 +67,6 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
provider_type=provider_type,
)
@override
def datasource_provider_type(self) -> str:
return DatasourceProviderType.ONLINE_DOCUMENT

View File

@ -1,3 +1,5 @@
from typing import override
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
@ -17,6 +19,7 @@ class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderC
self.plugin_unique_identifier = plugin_unique_identifier
@property
@override
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider

View File

@ -1,3 +1,5 @@
from typing import override
from collections.abc import Generator
from core.datasource.__base.datasource_plugin import DatasourcePlugin
@ -67,5 +69,6 @@ class OnlineDriveDatasourcePlugin(DatasourcePlugin):
provider_type=provider_type,
)
@override
def datasource_provider_type(self) -> str:
return DatasourceProviderType.ONLINE_DRIVE

View File

@ -1,3 +1,5 @@
from typing import override
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
@ -17,6 +19,7 @@ class OnlineDriveDatasourcePluginProviderController(DatasourcePluginProviderCont
self.plugin_unique_identifier = plugin_unique_identifier
@property
@override
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping
from typing import Any
from typing import Any, override
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
@ -47,5 +47,6 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
provider_type=provider_type,
)
@override
def datasource_provider_type(self) -> str:
return DatasourceProviderType.WEBSITE_CRAWL

View File

@ -1,3 +1,5 @@
from typing import override
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
@ -21,6 +23,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon
self.plugin_unique_identifier = plugin_unique_identifier
@property
@override
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider

View File

@ -6,7 +6,7 @@ import re
from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from typing import Any
from typing import Any, override
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from sqlalchemy import func, select
@ -1889,6 +1889,7 @@ class ProviderConfigurations(BaseModel):
key = str(ModelProviderID(key))
return key in self.configurations
@override
def __iter__(self):
# Return an iterator of (key, value) tuples to match BaseModel's __iter__
yield from self.configurations.items()

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, TypedDict
from typing import Any, TypedDict, override
from sqlalchemy import select
@ -29,6 +29,7 @@ class ApiExternalDataTool(ExternalDataTool):
"""the unique name of external data tool"""
@classmethod
@override
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -50,6 +51,7 @@ class ApiExternalDataTool(ExternalDataTool):
if not api_based_extension:
raise ValueError("api_based_extension_id is invalid")
@override
def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str:
"""
Query the external data tool.

View File

@ -1,3 +1,5 @@
from typing import override
from textwrap import dedent
from core.helper.code_executor.code_executor import CodeLanguage
@ -6,10 +8,12 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
class JavascriptCodeProvider(CodeNodeProvider):
@staticmethod
@override
def get_language() -> str:
return CodeLanguage.JAVASCRIPT
@classmethod
@override
def get_default_code(cls) -> str:
return dedent(
"""

View File

@ -1,3 +1,5 @@
from typing import override
from textwrap import dedent
from core.helper.code_executor.template_transformer import TemplateTransformer
@ -5,6 +7,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
class NodeJsTemplateTransformer(TemplateTransformer):
@classmethod
@override
def get_runner_script(cls) -> str:
runner_script = dedent(f""" {cls._code_placeholder}

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping
from textwrap import dedent
from typing import Any
from typing import Any, override
from core.helper.code_executor.template_transformer import TemplateTransformer
@ -10,6 +10,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
_template_b64_placeholder: str = "{{template_b64}}"
@classmethod
@override
def transform_response(cls, response: str):
"""
Transform response to dict
@ -19,6 +20,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
return {"result": cls.extract_result_str_from_response(response)}
@classmethod
@override
def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str:
"""
Override base class to use base64 encoding for template code.
@ -34,6 +36,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
return script
@classmethod
@override
def get_runner_script(cls) -> str:
runner_script = dedent(f"""
import jinja2
@ -61,6 +64,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
return runner_script
@classmethod
@override
def get_preload_script(cls) -> str:
preload_script = dedent("""
import jinja2

View File

@ -1,3 +1,5 @@
from typing import override
from textwrap import dedent
from core.helper.code_executor.code_executor import CodeLanguage
@ -6,10 +8,12 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
class Python3CodeProvider(CodeNodeProvider):
@staticmethod
@override
def get_language() -> str:
return CodeLanguage.PYTHON3
@classmethod
@override
def get_default_code(cls) -> str:
return dedent(
"""

View File

@ -1,3 +1,5 @@
from typing import override
from textwrap import dedent
from core.helper.code_executor.template_transformer import TemplateTransformer
@ -5,6 +7,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
class Python3TemplateTransformer(TemplateTransformer):
@classmethod
@override
def get_runner_script(cls) -> str:
runner_script = dedent(f""" {cls._code_placeholder}

View File

@ -1,7 +1,7 @@
import json
from abc import ABC, abstractmethod
from json import JSONDecodeError
from typing import Any
from typing import Any, override
from extensions.ext_redis import redis_client
@ -47,6 +47,7 @@ class SingletonProviderCredentialsCache(ProviderCredentialsCache):
provider_identity=provider_identity,
)
@override
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider_type = kwargs["provider_type"]
@ -61,6 +62,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
def __init__(self, tenant_id: str, provider: str, credential_id: str):
super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id)
@override
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider = kwargs["provider"]

View File

@ -1,5 +1,7 @@
"""Logging filters for structured logging."""
from typing import override
import contextlib
import logging
@ -15,6 +17,7 @@ class TraceContextFilter(logging.Filter):
Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id.
"""
@override
def filter(self, record: logging.LogRecord) -> bool:
# Get trace context from OpenTelemetry
trace_id, span_id = self._get_otel_context()
@ -54,6 +57,7 @@ class IdentityContextFilter(logging.Filter):
Extracts tenant_id, user_id, and user_type from Flask-Login current_user.
"""
@override
def filter(self, record: logging.LogRecord) -> bool:
identity = self._extract_identity()
record.tenant_id = identity.get("tenant_id", "")

View File

@ -3,7 +3,7 @@
import logging
import traceback
from datetime import UTC, datetime
from typing import Any, NotRequired, TypedDict
from typing import Any, NotRequired, TypedDict, override
import orjson
@ -58,6 +58,7 @@ class StructuredJSONFormatter(logging.Formatter):
super().__init__()
self._service_name = service_name or dify_config.APPLICATION_NAME
@override
def format(self, record: logging.LogRecord) -> str:
log_dict = self._build_log_dict(record)
try:

View File

@ -7,7 +7,7 @@ authentication failures and retries operations after refreshing tokens.
import logging
from collections.abc import Callable
from typing import Any
from typing import Any, override
from sqlalchemy.orm import Session
@ -159,6 +159,7 @@ class MCPClientWithAuthRetry(MCPClient):
# Reset retry flag after operation completes
self._has_retried = False
@override
def __enter__(self):
"""Enter the context manager with retry support."""
@ -168,6 +169,7 @@ class MCPClientWithAuthRetry(MCPClient):
return self._execute_with_retry(initialize_with_retry)
@override
def list_tools(self) -> list[Tool]:
"""
List available tools from the MCP server with auth retry.
@ -180,6 +182,7 @@ class MCPClientWithAuthRetry(MCPClient):
"""
return self._execute_with_retry(super().list_tools)
@override
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
"""
Invoke a tool on the MCP server with auth retry.

View File

@ -1,6 +1,6 @@
import queue
from datetime import timedelta
from typing import Any, Protocol
from typing import Any, Protocol, override
from pydantic import AnyUrl, TypeAdapter
@ -159,6 +159,7 @@ class ClientSession(
types.EmptyResult,
)
@override
def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None):
"""Send a progress notification."""
self.send_notification(
@ -326,6 +327,7 @@ class ClientSession(
)
)
@override
def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]):
ctx = RequestContext[ClientSession, Any](
request_id=responder.request_id,
@ -351,6 +353,7 @@ class ClientSession(
with responder:
return responder.respond(types.ClientResult(root=types.EmptyResult()))
@override
def _handle_incoming(
self,
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
@ -358,6 +361,7 @@ class ClientSession(
"""Handle incoming messages by forwarding to the message handler."""
self._message_handler(req)
@override
def _received_notification(self, notification: types.ServerNotification):
"""Handle notifications from the server."""
# Process specific notification types

View File

@ -1,4 +1,4 @@
from typing import Any
from typing import Any, override
from pydantic import BaseModel, Field
from sqlalchemy import select
@ -25,6 +25,7 @@ class ApiModeration(Moderation):
name: str = "api"
@classmethod
@override
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -43,6 +44,7 @@ class ApiModeration(Moderation):
if not extension:
raise ValueError("API-based Extension not found. Please check it again.")
@override
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
@ -59,6 +61,7 @@ class ApiModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
@override
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Any
from typing import Any, override
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
@ -8,6 +8,7 @@ class KeywordsModeration(Moderation):
name: str = "keywords"
@classmethod
@override
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -28,6 +29,7 @@ class KeywordsModeration(Moderation):
if len(keywords_row_len) > 100:
raise ValueError("the number of rows for the keywords must be less than 100")
@override
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
@ -49,6 +51,7 @@ class KeywordsModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
@override
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""

View File

@ -1,4 +1,4 @@
from typing import Any
from typing import Any, override
from core.model_manager import ModelManager
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
@ -9,6 +9,7 @@ class OpenAIModeration(Moderation):
name: str = "openai_moderation"
@classmethod
@override
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -19,6 +20,7 @@ class OpenAIModeration(Moderation):
"""
cls._validate_inputs_and_outputs_config(config, True)
@override
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
@ -36,6 +38,7 @@ class OpenAIModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
@override
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""

View File

@ -1,3 +1,5 @@
from typing import override
from collections.abc import Mapping
from pydantic import TypeAdapter
@ -11,6 +13,7 @@ class PluginDaemonError(Exception):
def __init__(self, description: str):
self.description = description
@override
def __str__(self) -> str:
# returns the class name and description
return f"req_id: {get_request_id()} {self.__class__.__name__}: {self.description}"

View File

@ -4,7 +4,7 @@ import hashlib
import logging
from collections.abc import Generator, Iterable, Sequence
from threading import Lock
from typing import IO, Any, Literal, cast, overload
from typing import IO, Any, Literal, cast, overload, override
from pydantic import ValidationError
from redis import RedisError
@ -118,6 +118,7 @@ class PluginModelRuntime(ModelRuntime):
self._provider_entities = None
self._provider_entities_lock = Lock()
@override
def fetch_model_providers(self) -> Sequence[ProviderEntity]:
if self._provider_entities is not None:
return self._provider_entities
@ -130,6 +131,7 @@ class PluginModelRuntime(ModelRuntime):
return self._provider_entities
@override
def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
provider_schema = self._get_provider_schema(provider)
@ -172,6 +174,7 @@ class PluginModelRuntime(ModelRuntime):
mime_type = image_mime_types.get(extension, "image/png")
return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type
@override
def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None:
plugin_id, provider_name = self._split_provider(provider)
self.client.validate_provider_credentials(
@ -182,6 +185,7 @@ class PluginModelRuntime(ModelRuntime):
credentials=credentials,
)
@override
def validate_model_credentials(
self,
*,
@ -201,6 +205,7 @@ class PluginModelRuntime(ModelRuntime):
credentials=credentials,
)
@override
def get_model_schema(
self,
*,
@ -267,6 +272,7 @@ class PluginModelRuntime(ModelRuntime):
return schema
@overload
@override
def invoke_llm(
self,
*,
@ -281,6 +287,7 @@ class PluginModelRuntime(ModelRuntime):
) -> LLMResult: ...
@overload
@override
def invoke_llm(
self,
*,
@ -294,6 +301,7 @@ class PluginModelRuntime(ModelRuntime):
stream: Literal[True],
) -> Generator[LLMResultChunk, None, None]: ...
@override
def invoke_llm(
self,
*,
@ -330,6 +338,7 @@ class PluginModelRuntime(ModelRuntime):
)
@overload
@override
def invoke_llm_with_structured_output(
self,
*,
@ -344,6 +353,7 @@ class PluginModelRuntime(ModelRuntime):
) -> LLMResultWithStructuredOutput: ...
@overload
@override
def invoke_llm_with_structured_output(
self,
*,
@ -357,6 +367,7 @@ class PluginModelRuntime(ModelRuntime):
stream: Literal[True],
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
@override
def invoke_llm_with_structured_output(
self,
*,
@ -396,6 +407,7 @@ class PluginModelRuntime(ModelRuntime):
stream=stream,
)
@override
def get_llm_num_tokens(
self,
*,
@ -422,6 +434,7 @@ class PluginModelRuntime(ModelRuntime):
tools=list(tools) if tools else None,
)
@override
def invoke_text_embedding(
self,
*,
@ -443,6 +456,7 @@ class PluginModelRuntime(ModelRuntime):
input_type=input_type,
)
@override
def invoke_multimodal_embedding(
self,
*,
@ -464,6 +478,7 @@ class PluginModelRuntime(ModelRuntime):
input_type=input_type,
)
@override
def get_text_embedding_num_tokens(
self,
*,
@ -483,6 +498,7 @@ class PluginModelRuntime(ModelRuntime):
texts=texts,
)
@override
def invoke_rerank(
self,
*,
@ -508,6 +524,7 @@ class PluginModelRuntime(ModelRuntime):
top_n=top_n,
)
@override
def invoke_multimodal_rerank(
self,
*,
@ -533,6 +550,7 @@ class PluginModelRuntime(ModelRuntime):
top_n=top_n,
)
@override
def invoke_tts(
self,
*,
@ -554,6 +572,7 @@ class PluginModelRuntime(ModelRuntime):
voice=voice,
)
@override
def get_tts_model_voices(
self,
*,
@ -573,6 +592,7 @@ class PluginModelRuntime(ModelRuntime):
language=language,
)
@override
def invoke_speech_to_text(
self,
*,
@ -592,6 +612,7 @@ class PluginModelRuntime(ModelRuntime):
file=file,
)
@override
def invoke_moderation(
self,
*,

View File

@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Any, TypedDict
from typing import Any, TypedDict, override
import orjson
from pydantic import BaseModel
@ -29,6 +29,7 @@ class Jieba(BaseKeyword):
super().__init__(dataset)
self._config = KeywordTableConfig()
@override
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
with redis_client.lock(lock_name, timeout=600):
@ -48,6 +49,7 @@ class Jieba(BaseKeyword):
return self
@override
def add_texts(self, texts: list[Document], **kwargs):
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
with redis_client.lock(lock_name, timeout=600):
@ -72,12 +74,14 @@ class Jieba(BaseKeyword):
self._save_dataset_keyword_table(keyword_table)
@override
def text_exists(self, id: str) -> bool:
keyword_table = self._get_dataset_keyword_table()
if keyword_table is None:
return False
return id in set.union(*keyword_table.values())
@override
def delete_by_ids(self, ids: list[str]):
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
with redis_client.lock(lock_name, timeout=600):
@ -87,6 +91,7 @@ class Jieba(BaseKeyword):
self._save_dataset_keyword_table(keyword_table)
@override
def search(self, query: str, **kwargs: Any) -> list[Document]:
keyword_table = self._get_dataset_keyword_table()
@ -122,6 +127,7 @@ class Jieba(BaseKeyword):
return documents
@override
def delete(self):
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
with redis_client.lock(lock_name, timeout=600):

View File

@ -2,7 +2,7 @@ import base64
import logging
import time
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, override
from sqlalchemy import select
@ -72,15 +72,19 @@ class _LazyEmbeddings(Embeddings):
self._real = CacheEmbedding(embedding_model)
return self._real
@override
def embed_documents(self, texts: list[str]) -> list[list[float]]:
return self._ensure().embed_documents(texts)
@override
def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
return self._ensure().embed_multimodal_documents(multimodel_documents)
@override
def embed_query(self, text: str) -> list[float]:
return self._ensure().embed_query(text)
@override
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
return self._ensure().embed_multimodal_query(multimodel_document)