mirror of
https://github.com/langgenius/dify.git
synced 2026-05-30 05:37:48 +08:00
refactor: add missing @override decorators to method overrides
Added @override decorator to methods that override parent class methods but were missing the decorator, as flagged by pyright's missing-override-decorator check. Fixes #36406
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict, TomlConfigSettingsSource
|
||||
@ -25,6 +25,7 @@ class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
|
||||
def __init__(self, settings_cls: type[BaseSettings]):
|
||||
super().__init__(settings_cls)
|
||||
|
||||
@override
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -90,6 +91,7 @@ class DifyConfig(
|
||||
# Thanks for your concentration and consideration.
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls: type[BaseSettings],
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic.fields import FieldInfo
|
||||
@ -48,6 +48,7 @@ class ApolloSettingsSource(RemoteSettingsSource):
|
||||
self.namespace = configs["APOLLO_NAMESPACE"]
|
||||
self.remote_configs = self.client.get_all_dicts(self.namespace)
|
||||
|
||||
@override
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
if not isinstance(self.remote_configs, dict):
|
||||
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
@ -41,6 +41,7 @@ class NacosSettingsSource(RemoteSettingsSource):
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to parse config: {e}")
|
||||
|
||||
@override
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
field_value = self.remote_configs.get(field_name)
|
||||
if field_value is None:
|
||||
|
||||
@ -10,7 +10,7 @@ import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from typing import Any, Protocol, final, runtime_checkable
|
||||
from typing import Any, Protocol, final, runtime_checkable, override
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -133,10 +133,12 @@ class NullAppContext(AppContext):
|
||||
self._config = config or {}
|
||||
self._extensions: dict[str, Any] = {}
|
||||
|
||||
@override
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""Get configuration value by key."""
|
||||
return self._config.get(key, default)
|
||||
|
||||
@override
|
||||
def get_extension(self, name: str) -> Any:
|
||||
"""Get extension by name."""
|
||||
return self._extensions.get(name)
|
||||
@ -146,6 +148,7 @@ class NullAppContext(AppContext):
|
||||
self._extensions[name] = extension
|
||||
|
||||
@contextmanager
|
||||
@override
|
||||
def enter(self) -> Generator[None, None, None]:
|
||||
"""Enter null context (no-op)."""
|
||||
yield
|
||||
|
||||
@ -6,7 +6,7 @@ import contextvars
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, final
|
||||
from typing import Any, final, override
|
||||
|
||||
from flask import Flask, current_app, g
|
||||
|
||||
@ -30,15 +30,18 @@ class FlaskAppContext(AppContext):
|
||||
"""
|
||||
self._flask_app = flask_app
|
||||
|
||||
@override
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""Get configuration value from Flask app config."""
|
||||
return self._flask_app.config.get(key, default)
|
||||
|
||||
@override
|
||||
def get_extension(self, name: str) -> Any:
|
||||
"""Get Flask extension by name."""
|
||||
return self._flask_app.extensions.get(name)
|
||||
|
||||
@contextmanager
|
||||
@override
|
||||
def enter(self) -> Generator[None, None, None]:
|
||||
"""Enter Flask app context."""
|
||||
with self._flask_app.app_context():
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
from typing import Literal, override
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
@ -79,11 +79,13 @@ def _enum_value(value):
|
||||
|
||||
|
||||
class WorkflowRunStatusField(fields.Raw):
|
||||
@override
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
return _enum_value(obj.status)
|
||||
|
||||
|
||||
class WorkflowRunOutputsField(fields.Raw):
|
||||
@override
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
status = _enum_value(obj.status)
|
||||
if status == WorkflowExecutionStatus.PAUSED.value:
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
import json
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
@ -66,6 +68,7 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
|
||||
return prompt_messages
|
||||
|
||||
@override
|
||||
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
import json
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
@ -51,6 +53,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
|
||||
|
||||
return historic_prompt
|
||||
|
||||
@override
|
||||
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize prompt messages
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
|
||||
@ -23,6 +23,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
|
||||
self.declaration = declaration
|
||||
self.meta_version = meta_version
|
||||
|
||||
@override
|
||||
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
||||
return self.declaration.parameters
|
||||
|
||||
@ -34,6 +35,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
|
||||
params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name))
|
||||
return params
|
||||
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
@ -20,6 +20,7 @@ class AdvancedChatAppGenerateResponseConverter(
|
||||
AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
|
||||
):
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_full_response(
|
||||
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
@ -59,6 +60,7 @@ class AdvancedChatAppGenerateResponseConverter(
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_simple_response(
|
||||
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
@ -76,6 +78,7 @@ class AdvancedChatAppGenerateResponseConverter(
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, Any, None]:
|
||||
@ -107,6 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, Any, None]:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
@ -16,6 +16,7 @@ from core.app.entities.task_entities import (
|
||||
|
||||
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking full response.
|
||||
@ -37,6 +38,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
@ -54,6 +56,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
@ -85,6 +88,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
@ -16,6 +16,7 @@ from core.app.entities.task_entities import (
|
||||
|
||||
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking full response.
|
||||
@ -37,6 +38,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
@ -54,6 +56,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
@ -85,6 +88,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
@ -16,6 +16,7 @@ from core.app.entities.task_entities import (
|
||||
|
||||
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]):
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking full response.
|
||||
@ -36,6 +37,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
@ -53,6 +55,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
@ -83,6 +86,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
|
||||
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol
|
||||
from typing import Any, Protocol, override
|
||||
|
||||
from graphon.enums import NodeType
|
||||
|
||||
@ -29,5 +29,6 @@ class DraftVariableSaverFactory(Protocol):
|
||||
|
||||
|
||||
class NoopDraftVariableSaver(DraftVariableSaver):
|
||||
@override
|
||||
def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None:
|
||||
return None
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -21,6 +23,7 @@ class MessageBasedAppQueueManager(AppQueueManager):
|
||||
self._app_mode = app_mode
|
||||
self._message_id = str(message_id)
|
||||
|
||||
@override
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
|
||||
"""
|
||||
Publish event to queue
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -19,6 +21,7 @@ class PipelineQueueManager(AppQueueManager):
|
||||
|
||||
self._app_mode = app_mode
|
||||
|
||||
@override
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish event to queue
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -19,6 +21,7 @@ class WorkflowAppQueueManager(AppQueueManager):
|
||||
|
||||
self._app_mode = app_mode
|
||||
|
||||
@override
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
|
||||
"""
|
||||
Publish event to queue
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
@ -18,6 +18,7 @@ class WorkflowAppGenerateResponseConverter(
|
||||
AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse]
|
||||
):
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_full_response(
|
||||
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
@ -29,6 +30,7 @@ class WorkflowAppGenerateResponseConverter(
|
||||
return dict(blocking_response.model_dump())
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_simple_response(
|
||||
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
@ -40,6 +42,7 @@ class WorkflowAppGenerateResponseConverter(
|
||||
return cls.convert_blocking_full_response(blocking_response)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
@ -69,6 +72,7 @@ class WorkflowAppGenerateResponseConverter(
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
@ -30,9 +32,11 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
|
||||
) -> None:
|
||||
self._scope_getter = scope_getter
|
||||
|
||||
@override
|
||||
def current_scope(self) -> FileAccessScope | None:
|
||||
return self._scope_getter()
|
||||
|
||||
@override
|
||||
def apply_upload_file_filters(
|
||||
self,
|
||||
stmt: Select[tuple[UploadFile]],
|
||||
@ -52,6 +56,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
|
||||
UploadFile.created_by == resolved_scope.user_id,
|
||||
)
|
||||
|
||||
@override
|
||||
def apply_tool_file_filters(
|
||||
self,
|
||||
stmt: Select[tuple[ToolFile]],
|
||||
@ -68,6 +73,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
|
||||
|
||||
return scoped_stmt.where(ToolFile.user_id == resolved_scope.user_id)
|
||||
|
||||
@override
|
||||
def get_upload_file(
|
||||
self,
|
||||
*,
|
||||
@ -85,6 +91,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
|
||||
)
|
||||
return session.scalar(stmt)
|
||||
|
||||
@override
|
||||
def get_tool_file(
|
||||
self,
|
||||
*,
|
||||
|
||||
@ -7,6 +7,8 @@ core, listens to those generic events, and persists only the `conversation.*`
|
||||
scope updates that matter to chat applications.
|
||||
"""
|
||||
|
||||
from typing import override
|
||||
|
||||
import logging
|
||||
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
@ -23,9 +25,11 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
|
||||
super().__init__()
|
||||
self._conversation_variable_updater = conversation_variable_updater
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
pass
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
if not isinstance(event, NodeRunVariableUpdatedEvent):
|
||||
return
|
||||
@ -44,5 +48,6 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
|
||||
|
||||
self._conversation_variable_updater.update(conversation_id=conversation_id, variable=event.variable)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
pass
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Literal, Self
|
||||
from typing import Annotated, Literal, Self, override
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import Engine
|
||||
@ -83,6 +83,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
def _get_repo(self) -> APIWorkflowRunRepository:
|
||||
return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
"""
|
||||
Called when graph execution starts.
|
||||
@ -92,6 +93,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
"""
|
||||
pass
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Called for every event emitted by the engine.
|
||||
@ -132,6 +134,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
pause_reasons=event.reasons,
|
||||
)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""
|
||||
Called when graph execution ends.
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent
|
||||
|
||||
@ -9,9 +11,11 @@ class SuspendLayer(GraphEngineLayer):
|
||||
super().__init__()
|
||||
self._paused = False
|
||||
|
||||
@override
|
||||
def on_graph_start(self):
|
||||
self._paused = False
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent):
|
||||
"""
|
||||
Handle the paused event, stash runtime state into storage and wait for resume.
|
||||
@ -19,6 +23,7 @@ class SuspendLayer(GraphEngineLayer):
|
||||
if isinstance(event, GraphRunPausedEvent):
|
||||
self._paused = True
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None):
|
||||
""" """
|
||||
self._paused = False
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import uuid
|
||||
from typing import ClassVar
|
||||
from typing import ClassVar, override
|
||||
|
||||
from apscheduler.schedulers.background import BackgroundScheduler # type: ignore
|
||||
|
||||
@ -63,6 +63,7 @@ class TimeSliceLayer(GraphEngineLayer):
|
||||
except Exception:
|
||||
logger.exception("scheduler error during check if the workflow need to be suspended")
|
||||
|
||||
@override
|
||||
def on_graph_start(self):
|
||||
"""
|
||||
Start timer to check if the workflow need to be suspended.
|
||||
@ -78,9 +79,11 @@ class TimeSliceLayer(GraphEngineLayer):
|
||||
id=self.schedule_id,
|
||||
)
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent):
|
||||
pass
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
self.stopped = True
|
||||
# remove the scheduler
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, ClassVar
|
||||
from typing import Any, ClassVar, override
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
@ -37,9 +37,11 @@ class TriggerPostLayer(GraphEngineLayer):
|
||||
self.start_time = start_time
|
||||
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
|
||||
|
||||
@override
|
||||
def on_graph_start(self):
|
||||
pass
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent):
|
||||
"""
|
||||
Update trigger log with success or failure.
|
||||
@ -82,5 +84,6 @@ class TriggerPostLayer(GraphEngineLayer):
|
||||
repo.update(trigger_log)
|
||||
session.commit()
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
pass
|
||||
|
||||
@ -7,7 +7,7 @@ import os
|
||||
import time
|
||||
import urllib.parse
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from typing import TYPE_CHECKING, Literal, override
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
|
||||
@ -40,15 +40,19 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
self._file_access_controller = file_access_controller
|
||||
|
||||
@property
|
||||
@override
|
||||
def multimodal_send_format(self) -> str:
|
||||
return dify_config.MULTIMODAL_SEND_FORMAT
|
||||
|
||||
@override
|
||||
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
|
||||
return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
|
||||
|
||||
@override
|
||||
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
|
||||
return storage.load(path, stream=stream)
|
||||
|
||||
@override
|
||||
def load_file_bytes(self, *, file: File) -> bytes:
|
||||
storage_key = self._resolve_storage_key(file=file)
|
||||
data = storage.load(storage_key, stream=False)
|
||||
@ -56,6 +60,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
raise ValueError(f"file {storage_key} is not a bytes object")
|
||||
return data
|
||||
|
||||
@override
|
||||
def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None:
|
||||
if file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
return file.remote_url
|
||||
@ -86,6 +91,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
)
|
||||
return None
|
||||
|
||||
@override
|
||||
def resolve_upload_file_url(
|
||||
self,
|
||||
*,
|
||||
@ -101,10 +107,12 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
query["as_attachment"] = "true"
|
||||
return f"{url}?{urllib.parse.urlencode(query)}"
|
||||
|
||||
@override
|
||||
def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
|
||||
self._assert_tool_file_access(tool_file_id=tool_file_id)
|
||||
return sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external)
|
||||
|
||||
@override
|
||||
def verify_preview_signature(
|
||||
self,
|
||||
*,
|
||||
|
||||
@ -12,7 +12,7 @@ state.
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Union
|
||||
from typing import Any, Union, override
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
@ -97,12 +97,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
# ------------------------------------------------------------------
|
||||
# GraphEngineLayer lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
self._workflow_execution = None
|
||||
self._node_execution_cache.clear()
|
||||
self._node_snapshots.clear()
|
||||
self._node_sequence = 0
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self._handle_graph_run_started()
|
||||
@ -151,6 +153,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
if isinstance(event, NodeRunPauseRequestedEvent):
|
||||
self._handle_node_pause_requested(event)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
return
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
@ -22,8 +24,10 @@ class LocalFileDatasourcePlugin(DatasourcePlugin):
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@override
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
@override
|
||||
def get_icon_url(self, tenant_id: str) -> str:
|
||||
return self.icon
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
@ -19,12 +19,14 @@ class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderContro
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
@override
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
@ -67,5 +67,6 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
@override
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.ONLINE_DOCUMENT
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
@ -17,6 +19,7 @@ class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderC
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
@ -67,5 +69,6 @@ class OnlineDriveDatasourcePlugin(DatasourcePlugin):
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
@override
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.ONLINE_DRIVE
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
@ -17,6 +19,7 @@ class OnlineDriveDatasourcePluginProviderController(DatasourcePluginProviderCont
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
@ -47,5 +47,6 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
@override
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.WEBSITE_CRAWL
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
@ -21,6 +23,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
@ -6,7 +6,7 @@ import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator, Sequence
|
||||
from json import JSONDecodeError
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
from sqlalchemy import func, select
|
||||
@ -1889,6 +1889,7 @@ class ProviderConfigurations(BaseModel):
|
||||
key = str(ModelProviderID(key))
|
||||
return key in self.configurations
|
||||
|
||||
@override
|
||||
def __iter__(self):
|
||||
# Return an iterator of (key, value) tuples to match BaseModel's __iter__
|
||||
yield from self.configurations.items()
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypedDict, override
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@ -29,6 +29,7 @@ class ApiExternalDataTool(ExternalDataTool):
|
||||
"""the unique name of external data tool"""
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
@ -50,6 +51,7 @@ class ApiExternalDataTool(ExternalDataTool):
|
||||
if not api_based_extension:
|
||||
raise ValueError("api_based_extension_id is invalid")
|
||||
|
||||
@override
|
||||
def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str:
|
||||
"""
|
||||
Query the external data tool.
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from textwrap import dedent
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
@ -6,10 +8,12 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
|
||||
class JavascriptCodeProvider(CodeNodeProvider):
|
||||
@staticmethod
|
||||
@override
|
||||
def get_language() -> str:
|
||||
return CodeLanguage.JAVASCRIPT
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def get_default_code(cls) -> str:
|
||||
return dedent(
|
||||
"""
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from textwrap import dedent
|
||||
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
@ -5,6 +7,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
class NodeJsTemplateTransformer(TemplateTransformer):
|
||||
@classmethod
|
||||
@override
|
||||
def get_runner_script(cls) -> str:
|
||||
runner_script = dedent(f""" {cls._code_placeholder}
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Mapping
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
@ -10,6 +10,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
|
||||
_template_b64_placeholder: str = "{{template_b64}}"
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def transform_response(cls, response: str):
|
||||
"""
|
||||
Transform response to dict
|
||||
@ -19,6 +20,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
|
||||
return {"result": cls.extract_result_str_from_response(response)}
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Override base class to use base64 encoding for template code.
|
||||
@ -34,6 +36,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
|
||||
return script
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def get_runner_script(cls) -> str:
|
||||
runner_script = dedent(f"""
|
||||
import jinja2
|
||||
@ -61,6 +64,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
|
||||
return runner_script
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def get_preload_script(cls) -> str:
|
||||
preload_script = dedent("""
|
||||
import jinja2
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from textwrap import dedent
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
@ -6,10 +8,12 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
|
||||
class Python3CodeProvider(CodeNodeProvider):
|
||||
@staticmethod
|
||||
@override
|
||||
def get_language() -> str:
|
||||
return CodeLanguage.PYTHON3
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def get_default_code(cls) -> str:
|
||||
return dedent(
|
||||
"""
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from textwrap import dedent
|
||||
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
@ -5,6 +7,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
class Python3TemplateTransformer(TemplateTransformer):
|
||||
@classmethod
|
||||
@override
|
||||
def get_runner_script(cls) -> str:
|
||||
runner_script = dedent(f""" {cls._code_placeholder}
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from json import JSONDecodeError
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
@ -47,6 +47,7 @@ class SingletonProviderCredentialsCache(ProviderCredentialsCache):
|
||||
provider_identity=provider_identity,
|
||||
)
|
||||
|
||||
@override
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider_type = kwargs["provider_type"]
|
||||
@ -61,6 +62,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
|
||||
def __init__(self, tenant_id: str, provider: str, credential_id: str):
|
||||
super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id)
|
||||
|
||||
@override
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider = kwargs["provider"]
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
"""Logging filters for structured logging."""
|
||||
|
||||
from typing import override
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
|
||||
@ -15,6 +17,7 @@ class TraceContextFilter(logging.Filter):
|
||||
Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id.
|
||||
"""
|
||||
|
||||
@override
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
# Get trace context from OpenTelemetry
|
||||
trace_id, span_id = self._get_otel_context()
|
||||
@ -54,6 +57,7 @@ class IdentityContextFilter(logging.Filter):
|
||||
Extracts tenant_id, user_id, and user_type from Flask-Login current_user.
|
||||
"""
|
||||
|
||||
@override
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
identity = self._extract_identity()
|
||||
record.tenant_id = identity.get("tenant_id", "")
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import logging
|
||||
import traceback
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
from typing import Any, NotRequired, TypedDict, override
|
||||
|
||||
import orjson
|
||||
|
||||
@ -58,6 +58,7 @@ class StructuredJSONFormatter(logging.Formatter):
|
||||
super().__init__()
|
||||
self._service_name = service_name or dify_config.APPLICATION_NAME
|
||||
|
||||
@override
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
log_dict = self._build_log_dict(record)
|
||||
try:
|
||||
|
||||
@ -7,7 +7,7 @@ authentication failures and retries operations after refreshing tokens.
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -159,6 +159,7 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
# Reset retry flag after operation completes
|
||||
self._has_retried = False
|
||||
|
||||
@override
|
||||
def __enter__(self):
|
||||
"""Enter the context manager with retry support."""
|
||||
|
||||
@ -168,6 +169,7 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
|
||||
return self._execute_with_retry(initialize_with_retry)
|
||||
|
||||
@override
|
||||
def list_tools(self) -> list[Tool]:
|
||||
"""
|
||||
List available tools from the MCP server with auth retry.
|
||||
@ -180,6 +182,7 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
"""
|
||||
return self._execute_with_retry(super().list_tools)
|
||||
|
||||
@override
|
||||
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
|
||||
"""
|
||||
Invoke a tool on the MCP server with auth retry.
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import queue
|
||||
from datetime import timedelta
|
||||
from typing import Any, Protocol
|
||||
from typing import Any, Protocol, override
|
||||
|
||||
from pydantic import AnyUrl, TypeAdapter
|
||||
|
||||
@ -159,6 +159,7 @@ class ClientSession(
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
@override
|
||||
def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None):
|
||||
"""Send a progress notification."""
|
||||
self.send_notification(
|
||||
@ -326,6 +327,7 @@ class ClientSession(
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]):
|
||||
ctx = RequestContext[ClientSession, Any](
|
||||
request_id=responder.request_id,
|
||||
@ -351,6 +353,7 @@ class ClientSession(
|
||||
with responder:
|
||||
return responder.respond(types.ClientResult(root=types.EmptyResult()))
|
||||
|
||||
@override
|
||||
def _handle_incoming(
|
||||
self,
|
||||
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
@ -358,6 +361,7 @@ class ClientSession(
|
||||
"""Handle incoming messages by forwarding to the message handler."""
|
||||
self._message_handler(req)
|
||||
|
||||
@override
|
||||
def _received_notification(self, notification: types.ServerNotification):
|
||||
"""Handle notifications from the server."""
|
||||
# Process specific notification types
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
@ -25,6 +25,7 @@ class ApiModeration(Moderation):
|
||||
name: str = "api"
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
@ -43,6 +44,7 @@ class ApiModeration(Moderation):
|
||||
if not extension:
|
||||
raise ValueError("API-based Extension not found. Please check it again.")
|
||||
|
||||
@override
|
||||
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
@ -59,6 +61,7 @@ class ApiModeration(Moderation):
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
@override
|
||||
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
|
||||
|
||||
@ -8,6 +8,7 @@ class KeywordsModeration(Moderation):
|
||||
name: str = "keywords"
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
@ -28,6 +29,7 @@ class KeywordsModeration(Moderation):
|
||||
if len(keywords_row_len) > 100:
|
||||
raise ValueError("the number of rows for the keywords must be less than 100")
|
||||
|
||||
@override
|
||||
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
@ -49,6 +51,7 @@ class KeywordsModeration(Moderation):
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
@override
|
||||
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
|
||||
@ -9,6 +9,7 @@ class OpenAIModeration(Moderation):
|
||||
name: str = "openai_moderation"
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
@ -19,6 +20,7 @@ class OpenAIModeration(Moderation):
|
||||
"""
|
||||
cls._validate_inputs_and_outputs_config(config, True)
|
||||
|
||||
@override
|
||||
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
@ -36,6 +38,7 @@ class OpenAIModeration(Moderation):
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
@override
|
||||
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from collections.abc import Mapping
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
@ -11,6 +13,7 @@ class PluginDaemonError(Exception):
|
||||
def __init__(self, description: str):
|
||||
self.description = description
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
# returns the class name and description
|
||||
return f"req_id: {get_request_id()} {self.__class__.__name__}: {self.description}"
|
||||
|
||||
@ -4,7 +4,7 @@ import hashlib
|
||||
import logging
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from threading import Lock
|
||||
from typing import IO, Any, Literal, cast, overload
|
||||
from typing import IO, Any, Literal, cast, overload, override
|
||||
|
||||
from pydantic import ValidationError
|
||||
from redis import RedisError
|
||||
@ -118,6 +118,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
self._provider_entities = None
|
||||
self._provider_entities_lock = Lock()
|
||||
|
||||
@override
|
||||
def fetch_model_providers(self) -> Sequence[ProviderEntity]:
|
||||
if self._provider_entities is not None:
|
||||
return self._provider_entities
|
||||
@ -130,6 +131,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
|
||||
return self._provider_entities
|
||||
|
||||
@override
|
||||
def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
|
||||
provider_schema = self._get_provider_schema(provider)
|
||||
|
||||
@ -172,6 +174,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
mime_type = image_mime_types.get(extension, "image/png")
|
||||
return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type
|
||||
|
||||
@override
|
||||
def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
self.client.validate_provider_credentials(
|
||||
@ -182,6 +185,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
@override
|
||||
def validate_model_credentials(
|
||||
self,
|
||||
*,
|
||||
@ -201,6 +205,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_model_schema(
|
||||
self,
|
||||
*,
|
||||
@ -267,6 +272,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
return schema
|
||||
|
||||
@overload
|
||||
@override
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
@ -281,6 +287,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
) -> LLMResult: ...
|
||||
|
||||
@overload
|
||||
@override
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
@ -294,6 +301,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
stream: Literal[True],
|
||||
) -> Generator[LLMResultChunk, None, None]: ...
|
||||
|
||||
@override
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
@ -330,6 +338,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
)
|
||||
|
||||
@overload
|
||||
@override
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
@ -344,6 +353,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
) -> LLMResultWithStructuredOutput: ...
|
||||
|
||||
@overload
|
||||
@override
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
@ -357,6 +367,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
stream: Literal[True],
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
@override
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
@ -396,6 +407,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_llm_num_tokens(
|
||||
self,
|
||||
*,
|
||||
@ -422,6 +434,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
tools=list(tools) if tools else None,
|
||||
)
|
||||
|
||||
@override
|
||||
def invoke_text_embedding(
|
||||
self,
|
||||
*,
|
||||
@ -443,6 +456,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
input_type=input_type,
|
||||
)
|
||||
|
||||
@override
|
||||
def invoke_multimodal_embedding(
|
||||
self,
|
||||
*,
|
||||
@ -464,6 +478,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
input_type=input_type,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_text_embedding_num_tokens(
|
||||
self,
|
||||
*,
|
||||
@ -483,6 +498,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
texts=texts,
|
||||
)
|
||||
|
||||
@override
|
||||
def invoke_rerank(
|
||||
self,
|
||||
*,
|
||||
@ -508,6 +524,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
top_n=top_n,
|
||||
)
|
||||
|
||||
@override
|
||||
def invoke_multimodal_rerank(
|
||||
self,
|
||||
*,
|
||||
@ -533,6 +550,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
top_n=top_n,
|
||||
)
|
||||
|
||||
@override
|
||||
def invoke_tts(
|
||||
self,
|
||||
*,
|
||||
@ -554,6 +572,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_tts_model_voices(
|
||||
self,
|
||||
*,
|
||||
@ -573,6 +592,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
language=language,
|
||||
)
|
||||
|
||||
@override
|
||||
def invoke_speech_to_text(
|
||||
self,
|
||||
*,
|
||||
@ -592,6 +612,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
file=file,
|
||||
)
|
||||
|
||||
@override
|
||||
def invoke_moderation(
|
||||
self,
|
||||
*,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypedDict, override
|
||||
|
||||
import orjson
|
||||
from pydantic import BaseModel
|
||||
@ -29,6 +29,7 @@ class Jieba(BaseKeyword):
|
||||
super().__init__(dataset)
|
||||
self._config = KeywordTableConfig()
|
||||
|
||||
@override
|
||||
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
|
||||
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
@ -48,6 +49,7 @@ class Jieba(BaseKeyword):
|
||||
|
||||
return self
|
||||
|
||||
@override
|
||||
def add_texts(self, texts: list[Document], **kwargs):
|
||||
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
@ -72,12 +74,14 @@ class Jieba(BaseKeyword):
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
@override
|
||||
def text_exists(self, id: str) -> bool:
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
if keyword_table is None:
|
||||
return False
|
||||
return id in set.union(*keyword_table.values())
|
||||
|
||||
@override
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
@ -87,6 +91,7 @@ class Jieba(BaseKeyword):
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
@override
|
||||
def search(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
|
||||
@ -122,6 +127,7 @@ class Jieba(BaseKeyword):
|
||||
|
||||
return documents
|
||||
|
||||
@override
|
||||
def delete(self):
|
||||
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
|
||||
@ -2,7 +2,7 @@ import base64
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@ -72,15 +72,19 @@ class _LazyEmbeddings(Embeddings):
|
||||
self._real = CacheEmbedding(embedding_model)
|
||||
return self._real
|
||||
|
||||
@override
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
return self._ensure().embed_documents(texts)
|
||||
|
||||
@override
|
||||
def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
|
||||
return self._ensure().embed_multimodal_documents(multimodel_documents)
|
||||
|
||||
@override
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
return self._ensure().embed_query(text)
|
||||
|
||||
@override
|
||||
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
|
||||
return self._ensure().embed_multimodal_query(multimodel_document)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user