From a7175198226a1e453451d625f138d01cbcaabf0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9B=90=E7=B2=92=20Yanli?= Date: Tue, 17 Mar 2026 17:50:51 +0800 Subject: [PATCH] refactor(api): tighten phase 1 shared type contracts (#33453) --- api/AGENTS.md | 2 +- .../middleware/cache/redis_pubsub_config.py | 17 +-- api/dify_graph/variables/types.py | 4 +- api/extensions/ext_fastopenapi.py | 8 +- api/factories/variable_factory.py | 18 ++- api/libs/helper.py | 41 ++++-- api/libs/login.py | 8 +- api/libs/module_loading.py | 13 +- api/libs/oauth.py | 101 ++++++++++--- api/libs/oauth_data_source.py | 133 ++++++++++++------ api/models/trigger.py | 17 ++- api/models/workflow.py | 110 +++++++-------- api/pyrefly-local-excludes.txt | 15 -- ..._api_workflow_node_execution_repository.py | 22 ++- 14 files changed, 313 insertions(+), 196 deletions(-) diff --git a/api/AGENTS.md b/api/AGENTS.md index d43d2528b8..8e5d9f600d 100644 --- a/api/AGENTS.md +++ b/api/AGENTS.md @@ -78,7 +78,7 @@ class UserProfile(TypedDict): nickname: NotRequired[str] ``` -- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance: +- For classes, declare all member variables explicitly with types at the top of the class body (before `__init__`), even when the class is not a dataclass or Pydantic model, so the class shape is obvious at a glance: ```python from datetime import datetime diff --git a/api/configs/middleware/cache/redis_pubsub_config.py b/api/configs/middleware/cache/redis_pubsub_config.py index d30831a0ec..0a166818b3 100644 --- a/api/configs/middleware/cache/redis_pubsub_config.py +++ b/api/configs/middleware/cache/redis_pubsub_config.py @@ -1,4 +1,4 @@ -from typing import Literal, Protocol +from typing import Literal, Protocol, cast from urllib.parse import quote_plus, urlunparse from pydantic import AliasChoices, Field @@ -12,16 +12,13 @@ class RedisConfigDefaults(Protocol): REDIS_PASSWORD: str | None REDIS_DB: int REDIS_USE_SSL: bool - REDIS_USE_SENTINEL: bool | None - REDIS_USE_CLUSTERS: bool -class RedisConfigDefaultsMixin: - def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults: - return self +def _redis_defaults(config: object) -> RedisConfigDefaults: + return cast(RedisConfigDefaults, config) -class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin): +class RedisPubSubConfig(BaseSettings): """ Configuration settings for event transport between API and workers. @@ -74,7 +71,7 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin): ) def _build_default_pubsub_url(self) -> str: - defaults = self._redis_defaults() + defaults = _redis_defaults(self) if not defaults.REDIS_HOST or not defaults.REDIS_PORT: raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed") @@ -91,11 +88,9 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin): if userinfo: userinfo = f"{userinfo}@" - host = defaults.REDIS_HOST - port = defaults.REDIS_PORT db = defaults.REDIS_DB - netloc = f"{userinfo}{host}:{port}" + netloc = f"{userinfo}{defaults.REDIS_HOST}:{defaults.REDIS_PORT}" return urlunparse((scheme, netloc, f"/{db}", "", "", "")) @property diff --git a/api/dify_graph/variables/types.py b/api/dify_graph/variables/types.py index df8430de5d..53bf495a27 100644 --- a/api/dify_graph/variables/types.py +++ b/api/dify_graph/variables/types.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any from dify_graph.file.models import File if TYPE_CHECKING: - pass + from dify_graph.variables.segments import Segment class ArrayValidation(StrEnum): @@ -219,7 +219,7 @@ class SegmentType(StrEnum): return _ARRAY_ELEMENT_TYPES_MAPPING.get(self) @staticmethod - def get_zero_value(t: SegmentType): + def get_zero_value(t: SegmentType) -> Segment: # Lazy import to avoid circular dependency from factories import variable_factory diff --git a/api/extensions/ext_fastopenapi.py b/api/extensions/ext_fastopenapi.py index ab4d23a072..569203e974 100644 --- a/api/extensions/ext_fastopenapi.py +++ b/api/extensions/ext_fastopenapi.py @@ -1,3 +1,5 @@ +from typing import Protocol, cast + from fastopenapi.routers import FlaskRouter from flask_cors import CORS @@ -9,6 +11,10 @@ from extensions.ext_blueprints import AUTHENTICATED_HEADERS, EXPOSED_HEADERS DOCS_PREFIX = "/fastopenapi" +class SupportsIncludeRouter(Protocol): + def include_router(self, router: object, *, prefix: str = "") -> None: ... + + def init_app(app: DifyApp) -> None: docs_enabled = dify_config.SWAGGER_UI_ENABLED docs_url = f"{DOCS_PREFIX}/docs" if docs_enabled else None @@ -36,7 +42,7 @@ def init_app(app: DifyApp) -> None: _ = remote_files _ = setup - router.include_router(console_router, prefix="/console/api") + cast(SupportsIncludeRouter, router).include_router(console_router, prefix="/console/api") CORS( app, resources={r"/console/api/.*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}}, diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 255e5cde83..14a56bf4a2 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -55,7 +55,7 @@ class TypeMismatchError(Exception): # Define the constant -SEGMENT_TO_VARIABLE_MAP = { +SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[VariableBase]] = { ArrayAnySegment: ArrayAnyVariable, ArrayBooleanSegment: ArrayBooleanVariable, ArrayFileSegment: ArrayFileVariable, @@ -296,13 +296,11 @@ def segment_to_variable( raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] - return cast( - VariableBase, - variable_class( - id=id, - name=name, - description=description, - value=segment.value, - selector=list(selector), - ), + return variable_class( + id=id, + name=name, + description=description, + value_type=segment.value_type, + value=segment.value, + selector=list(selector), ) diff --git a/api/libs/helper.py b/api/libs/helper.py index 6151eb0940..e7572cc025 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -32,6 +32,11 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def _stream_with_request_context(response: object) -> Any: + """Bridge Flask's loosely-typed streaming helper without leaking casts into callers.""" + return cast(Any, stream_with_context)(response) + + def escape_like_pattern(pattern: str) -> str: """ Escape special characters in a string for safe use in SQL LIKE patterns. @@ -286,22 +291,32 @@ def generate_text_hash(text: str) -> str: return sha256(hash_text.encode()).hexdigest() -def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response: - if isinstance(response, dict): +def compact_generate_response( + response: Mapping[str, Any] | Generator[str, None, None] | RateLimitGenerator, +) -> Response: + if isinstance(response, Mapping): return Response( response=json.dumps(jsonable_encoder(response)), status=200, content_type="application/json; charset=utf-8", ) else: + stream_response = response - def generate() -> Generator: - yield from response + def generate() -> Generator[str, None, None]: + yield from stream_response - return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream") + return Response( + _stream_with_request_context(generate()), + status=200, + mimetype="text/event-stream", + ) -def length_prefixed_response(magic_number: int, response: Union[Mapping, Generator, RateLimitGenerator]) -> Response: +def length_prefixed_response( + magic_number: int, + response: Mapping[str, Any] | BaseModel | Generator[str | bytes, None, None] | RateLimitGenerator, +) -> Response: """ This function is used to return a response with a length prefix. Magic number is a one byte number that indicates the type of the response. @@ -332,7 +347,7 @@ def length_prefixed_response(magic_number: int, response: Union[Mapping, Generat # | Magic Number 1byte | Reserved 1byte | Header Length 2bytes | Data Length 4bytes | Reserved 6bytes | Data return struct.pack(" Generator: - for chunk in response: + stream_response = response + + def generate() -> Generator[bytes, None, None]: + for chunk in stream_response: if isinstance(chunk, str): yield pack_response_with_length_prefix(chunk.encode("utf-8")) else: yield pack_response_with_length_prefix(chunk) - return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream") + return Response( + _stream_with_request_context(generate()), + status=200, + mimetype="text/event-stream", + ) class TokenManager: diff --git a/api/libs/login.py b/api/libs/login.py index 69e2b58426..bd5cb5f30d 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -77,12 +77,14 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue] @wraps(func) def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue: if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: - pass - elif current_user is not None and not current_user.is_authenticated: + return current_app.ensure_sync(func)(*args, **kwargs) + + user = _get_user() + if user is None or not user.is_authenticated: return current_app.login_manager.unauthorized() # type: ignore # we put csrf validation here for less conflicts # TODO: maybe find a better place for it. - check_csrf_token(request, current_user.id) + check_csrf_token(request, user.id) return current_app.ensure_sync(func)(*args, **kwargs) return decorated_view diff --git a/api/libs/module_loading.py b/api/libs/module_loading.py index 9f74943433..7063a115b0 100644 --- a/api/libs/module_loading.py +++ b/api/libs/module_loading.py @@ -7,9 +7,10 @@ https://github.com/django/django/blob/main/django/utils/module_loading.py import sys from importlib import import_module +from typing import Any -def cached_import(module_path: str, class_name: str): +def cached_import(module_path: str, class_name: str) -> Any: """ Import a module and return the named attribute/class from it, with caching. @@ -20,16 +21,14 @@ def cached_import(module_path: str, class_name: str): Returns: The imported attribute/class """ - if not ( - (module := sys.modules.get(module_path)) - and (spec := getattr(module, "__spec__", None)) - and getattr(spec, "_initializing", False) is False - ): + module = sys.modules.get(module_path) + spec = getattr(module, "__spec__", None) if module is not None else None + if module is None or getattr(spec, "_initializing", False): module = import_module(module_path) return getattr(module, class_name) -def import_string(dotted_path: str): +def import_string(dotted_path: str) -> Any: """ Import a dotted module path and return the attribute/class designated by the last name in the path. Raise ImportError if the import failed. diff --git a/api/libs/oauth.py b/api/libs/oauth.py index 889a5a3248..efce13f6f1 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -1,7 +1,48 @@ +import sys import urllib.parse from dataclasses import dataclass +from typing import NotRequired import httpx +from pydantic import TypeAdapter + +if sys.version_info >= (3, 12): + from typing import TypedDict +else: + from typing_extensions import TypedDict + +JsonObject = dict[str, object] +JsonObjectList = list[JsonObject] + +JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject) +JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList) + + +class AccessTokenResponse(TypedDict, total=False): + access_token: str + + +class GitHubEmailRecord(TypedDict, total=False): + email: str + primary: bool + + +class GitHubRawUserInfo(TypedDict): + id: int | str + login: str + name: NotRequired[str] + email: NotRequired[str] + + +class GoogleRawUserInfo(TypedDict): + sub: str + email: str + + +ACCESS_TOKEN_RESPONSE_ADAPTER = TypeAdapter(AccessTokenResponse) +GITHUB_RAW_USER_INFO_ADAPTER = TypeAdapter(GitHubRawUserInfo) +GITHUB_EMAIL_RECORDS_ADAPTER = TypeAdapter(list[GitHubEmailRecord]) +GOOGLE_RAW_USER_INFO_ADAPTER = TypeAdapter(GoogleRawUserInfo) @dataclass @@ -11,26 +52,38 @@ class OAuthUserInfo: email: str +def _json_object(response: httpx.Response) -> JsonObject: + return JSON_OBJECT_ADAPTER.validate_python(response.json()) + + +def _json_list(response: httpx.Response) -> JsonObjectList: + return JSON_OBJECT_LIST_ADAPTER.validate_python(response.json()) + + class OAuth: + client_id: str + client_secret: str + redirect_uri: str + def __init__(self, client_id: str, client_secret: str, redirect_uri: str): self.client_id = client_id self.client_secret = client_secret self.redirect_uri = redirect_uri - def get_authorization_url(self): + def get_authorization_url(self, invite_token: str | None = None) -> str: raise NotImplementedError() - def get_access_token(self, code: str): + def get_access_token(self, code: str) -> str: raise NotImplementedError() - def get_raw_user_info(self, token: str): + def get_raw_user_info(self, token: str) -> JsonObject: raise NotImplementedError() def get_user_info(self, token: str) -> OAuthUserInfo: raw_info = self.get_raw_user_info(token) return self._transform_user_info(raw_info) - def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: + def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo: raise NotImplementedError() @@ -40,7 +93,7 @@ class GitHubOAuth(OAuth): _USER_INFO_URL = "https://api.github.com/user" _EMAIL_INFO_URL = "https://api.github.com/user/emails" - def get_authorization_url(self, invite_token: str | None = None): + def get_authorization_url(self, invite_token: str | None = None) -> str: params = { "client_id": self.client_id, "redirect_uri": self.redirect_uri, @@ -50,7 +103,7 @@ class GitHubOAuth(OAuth): params["state"] = invite_token return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" - def get_access_token(self, code: str): + def get_access_token(self, code: str) -> str: data = { "client_id": self.client_id, "client_secret": self.client_secret, @@ -60,7 +113,7 @@ class GitHubOAuth(OAuth): headers = {"Accept": "application/json"} response = httpx.post(self._TOKEN_URL, data=data, headers=headers) - response_json = response.json() + response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response)) access_token = response_json.get("access_token") if not access_token: @@ -68,23 +121,24 @@ class GitHubOAuth(OAuth): return access_token - def get_raw_user_info(self, token: str): + def get_raw_user_info(self, token: str) -> JsonObject: headers = {"Authorization": f"token {token}"} response = httpx.get(self._USER_INFO_URL, headers=headers) response.raise_for_status() - user_info = response.json() + user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response)) email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers) - email_info = email_response.json() - primary_email: dict = next((email for email in email_info if email["primary"] == True), {}) + email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response)) + primary_email = next((email for email in email_info if email.get("primary") is True), None) - return {**user_info, "email": primary_email.get("email", "")} + return {**user_info, "email": primary_email.get("email", "") if primary_email else ""} - def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: - email = raw_info.get("email") + def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo: + payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info) + email = payload.get("email") if not email: - email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com" - return OAuthUserInfo(id=str(raw_info["id"]), name=raw_info["name"], email=email) + email = f"{payload['id']}+{payload['login']}@users.noreply.github.com" + return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email) class GoogleOAuth(OAuth): @@ -92,7 +146,7 @@ class GoogleOAuth(OAuth): _TOKEN_URL = "https://oauth2.googleapis.com/token" _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo" - def get_authorization_url(self, invite_token: str | None = None): + def get_authorization_url(self, invite_token: str | None = None) -> str: params = { "client_id": self.client_id, "response_type": "code", @@ -103,7 +157,7 @@ class GoogleOAuth(OAuth): params["state"] = invite_token return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" - def get_access_token(self, code: str): + def get_access_token(self, code: str) -> str: data = { "client_id": self.client_id, "client_secret": self.client_secret, @@ -114,7 +168,7 @@ class GoogleOAuth(OAuth): headers = {"Accept": "application/json"} response = httpx.post(self._TOKEN_URL, data=data, headers=headers) - response_json = response.json() + response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response)) access_token = response_json.get("access_token") if not access_token: @@ -122,11 +176,12 @@ class GoogleOAuth(OAuth): return access_token - def get_raw_user_info(self, token: str): + def get_raw_user_info(self, token: str) -> JsonObject: headers = {"Authorization": f"Bearer {token}"} response = httpx.get(self._USER_INFO_URL, headers=headers) response.raise_for_status() - return response.json() + return _json_object(response) - def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: - return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"]) + def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo: + payload = GOOGLE_RAW_USER_INFO_ADAPTER.validate_python(raw_info) + return OAuthUserInfo(id=str(payload["sub"]), name="", email=payload["email"]) diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index ae0ae3bcb6..d5dc35ac97 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -1,25 +1,57 @@ +import sys import urllib.parse -from typing import Any +from typing import Any, Literal import httpx from flask_login import current_user +from pydantic import TypeAdapter from sqlalchemy import select from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.source import DataSourceOauthBinding +if sys.version_info >= (3, 12): + from typing import TypedDict +else: + from typing_extensions import TypedDict + + +class NotionPageSummary(TypedDict): + page_id: str + page_name: str + page_icon: dict[str, str] | None + parent_id: str + type: Literal["page", "database"] + + +class NotionSourceInfo(TypedDict): + workspace_name: str | None + workspace_icon: str | None + workspace_id: str | None + pages: list[NotionPageSummary] + total: int + + +SOURCE_INFO_STORAGE_ADAPTER = TypeAdapter(dict[str, object]) +NOTION_SOURCE_INFO_ADAPTER = TypeAdapter(NotionSourceInfo) +NOTION_PAGE_SUMMARY_ADAPTER = TypeAdapter(NotionPageSummary) + class OAuthDataSource: + client_id: str + client_secret: str + redirect_uri: str + def __init__(self, client_id: str, client_secret: str, redirect_uri: str): self.client_id = client_id self.client_secret = client_secret self.redirect_uri = redirect_uri - def get_authorization_url(self): + def get_authorization_url(self) -> str: raise NotImplementedError() - def get_access_token(self, code: str): + def get_access_token(self, code: str) -> None: raise NotImplementedError() @@ -30,7 +62,7 @@ class NotionOAuth(OAuthDataSource): _NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks" _NOTION_BOT_USER = "https://api.notion.com/v1/users/me" - def get_authorization_url(self): + def get_authorization_url(self) -> str: params = { "client_id": self.client_id, "response_type": "code", @@ -39,7 +71,7 @@ class NotionOAuth(OAuthDataSource): } return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" - def get_access_token(self, code: str): + def get_access_token(self, code: str) -> None: data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri} headers = {"Accept": "application/json"} auth = (self.client_id, self.client_secret) @@ -54,13 +86,12 @@ class NotionOAuth(OAuthDataSource): workspace_id = response_json.get("workspace_id") # get all authorized pages pages = self.get_authorized_pages(access_token) - source_info = { - "workspace_name": workspace_name, - "workspace_icon": workspace_icon, - "workspace_id": workspace_id, - "pages": pages, - "total": len(pages), - } + source_info = self._build_source_info( + workspace_name=workspace_name, + workspace_icon=workspace_icon, + workspace_id=workspace_id, + pages=pages, + ) # save data source binding data_source_binding = db.session.scalar( select(DataSourceOauthBinding).where( @@ -70,7 +101,7 @@ class NotionOAuth(OAuthDataSource): ) ) if data_source_binding: - data_source_binding.source_info = source_info + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) data_source_binding.disabled = False data_source_binding.updated_at = naive_utc_now() db.session.commit() @@ -78,25 +109,24 @@ class NotionOAuth(OAuthDataSource): new_data_source_binding = DataSourceOauthBinding( tenant_id=current_user.current_tenant_id, access_token=access_token, - source_info=source_info, + source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), provider="notion", ) db.session.add(new_data_source_binding) db.session.commit() - def save_internal_access_token(self, access_token: str): + def save_internal_access_token(self, access_token: str) -> None: workspace_name = self.notion_workspace_name(access_token) workspace_icon = None workspace_id = current_user.current_tenant_id # get all authorized pages pages = self.get_authorized_pages(access_token) - source_info = { - "workspace_name": workspace_name, - "workspace_icon": workspace_icon, - "workspace_id": workspace_id, - "pages": pages, - "total": len(pages), - } + source_info = self._build_source_info( + workspace_name=workspace_name, + workspace_icon=workspace_icon, + workspace_id=workspace_id, + pages=pages, + ) # save data source binding data_source_binding = db.session.scalar( select(DataSourceOauthBinding).where( @@ -106,7 +136,7 @@ class NotionOAuth(OAuthDataSource): ) ) if data_source_binding: - data_source_binding.source_info = source_info + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) data_source_binding.disabled = False data_source_binding.updated_at = naive_utc_now() db.session.commit() @@ -114,13 +144,13 @@ class NotionOAuth(OAuthDataSource): new_data_source_binding = DataSourceOauthBinding( tenant_id=current_user.current_tenant_id, access_token=access_token, - source_info=source_info, + source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), provider="notion", ) db.session.add(new_data_source_binding) db.session.commit() - def sync_data_source(self, binding_id: str): + def sync_data_source(self, binding_id: str) -> None: # save data source binding data_source_binding = db.session.scalar( select(DataSourceOauthBinding).where( @@ -134,23 +164,22 @@ class NotionOAuth(OAuthDataSource): if data_source_binding: # get all authorized pages pages = self.get_authorized_pages(data_source_binding.access_token) - source_info = data_source_binding.source_info - new_source_info = { - "workspace_name": source_info["workspace_name"], - "workspace_icon": source_info["workspace_icon"], - "workspace_id": source_info["workspace_id"], - "pages": pages, - "total": len(pages), - } - data_source_binding.source_info = new_source_info + source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info) + new_source_info = self._build_source_info( + workspace_name=source_info["workspace_name"], + workspace_icon=source_info["workspace_icon"], + workspace_id=source_info["workspace_id"], + pages=pages, + ) + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info) data_source_binding.disabled = False data_source_binding.updated_at = naive_utc_now() db.session.commit() else: raise ValueError("Data source binding not found") - def get_authorized_pages(self, access_token: str): - pages = [] + def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]: + pages: list[NotionPageSummary] = [] page_results = self.notion_page_search(access_token) database_results = self.notion_database_search(access_token) # get page detail @@ -187,7 +216,7 @@ class NotionOAuth(OAuthDataSource): "parent_id": parent_id, "type": "page", } - pages.append(page) + pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page)) # get database detail for database_result in database_results: page_id = database_result["id"] @@ -220,11 +249,11 @@ class NotionOAuth(OAuthDataSource): "parent_id": parent_id, "type": "database", } - pages.append(page) + pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page)) return pages - def notion_page_search(self, access_token: str): - results = [] + def notion_page_search(self, access_token: str) -> list[dict[str, Any]]: + results: list[dict[str, Any]] = [] next_cursor = None has_more = True @@ -249,7 +278,7 @@ class NotionOAuth(OAuthDataSource): return results - def notion_block_parent_page_id(self, access_token: str, block_id: str): + def notion_block_parent_page_id(self, access_token: str, block_id: str) -> str: headers = { "Authorization": f"Bearer {access_token}", "Notion-Version": "2022-06-28", @@ -265,7 +294,7 @@ class NotionOAuth(OAuthDataSource): return self.notion_block_parent_page_id(access_token, parent[parent_type]) return parent[parent_type] - def notion_workspace_name(self, access_token: str): + def notion_workspace_name(self, access_token: str) -> str: headers = { "Authorization": f"Bearer {access_token}", "Notion-Version": "2022-06-28", @@ -279,8 +308,8 @@ class NotionOAuth(OAuthDataSource): return user_info["workspace_name"] return "workspace" - def notion_database_search(self, access_token: str): - results = [] + def notion_database_search(self, access_token: str) -> list[dict[str, Any]]: + results: list[dict[str, Any]] = [] next_cursor = None has_more = True @@ -303,3 +332,19 @@ class NotionOAuth(OAuthDataSource): next_cursor = response_json.get("next_cursor", None) return results + + @staticmethod + def _build_source_info( + *, + workspace_name: str | None, + workspace_icon: str | None, + workspace_id: str | None, + pages: list[NotionPageSummary], + ) -> NotionSourceInfo: + return { + "workspace_name": workspace_name, + "workspace_icon": workspace_icon, + "workspace_id": workspace_id, + "pages": pages, + "total": len(pages), + } diff --git a/api/models/trigger.py b/api/models/trigger.py index bb003a71b1..627b854060 100644 --- a/api/models/trigger.py +++ b/api/models/trigger.py @@ -23,6 +23,9 @@ from .enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTr from .model import Account from .types import EnumText, LongText, StringUUID +TriggerJsonObject = dict[str, object] +TriggerCredentials = dict[str, str] + class WorkflowTriggerLogDict(TypedDict): id: str @@ -89,10 +92,14 @@ class TriggerSubscription(TypeBase): String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)" ) endpoint_id: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint") - parameters: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription parameters JSON") - properties: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription properties JSON") + parameters: Mapped[TriggerJsonObject] = mapped_column( + sa.JSON, nullable=False, comment="Subscription parameters JSON" + ) + properties: Mapped[TriggerJsonObject] = mapped_column( + sa.JSON, nullable=False, comment="Subscription properties JSON" + ) - credentials: Mapped[dict[str, Any]] = mapped_column( + credentials: Mapped[TriggerCredentials] = mapped_column( sa.JSON, nullable=False, comment="Subscription credentials JSON" ) credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key") @@ -200,8 +207,8 @@ class TriggerOAuthTenantClient(TypeBase): ) @property - def oauth_params(self) -> Mapping[str, Any]: - return cast(Mapping[str, Any], json.loads(self.encrypted_oauth_params or "{}")) + def oauth_params(self) -> Mapping[str, object]: + return cast(TriggerJsonObject, json.loads(self.encrypted_oauth_params or "{}")) class WorkflowTriggerLog(TypeBase): diff --git a/api/models/workflow.py b/api/models/workflow.py index 95bbc9eaae..f2e8305758 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -19,7 +19,7 @@ from sqlalchemy import ( orm, select, ) -from sqlalchemy.orm import Mapped, declared_attr, mapped_column +from sqlalchemy.orm import Mapped, mapped_column from typing_extensions import deprecated from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE @@ -33,7 +33,7 @@ from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus from dify_graph.file.constants import maybe_file_object from dify_graph.file.models import File from dify_graph.variables import utils as variable_utils -from dify_graph.variables.variables import FloatVariable, IntegerVariable, StringVariable +from dify_graph.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from extensions.ext_storage import Storage from factories.variable_factory import TypeMismatchError, build_segment_with_type from libs.datetime_utils import naive_utc_now @@ -59,6 +59,9 @@ from .types import EnumText, LongText, StringUUID logger = logging.getLogger(__name__) +SerializedWorkflowValue = dict[str, Any] +SerializedWorkflowVariables = dict[str, SerializedWorkflowValue] + class WorkflowContentDict(TypedDict): graph: Mapping[str, Any] @@ -405,7 +408,7 @@ class Workflow(Base): # bug def rag_pipeline_user_input_form(self) -> list: # get user_input_form from start node - variables: list[Any] = self.rag_pipeline_variables + variables: list[SerializedWorkflowValue] = self.rag_pipeline_variables return variables @@ -448,17 +451,13 @@ class Workflow(Base): # bug def environment_variables( self, ) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: - # TODO: find some way to init `self._environment_variables` when instance created. - if self._environment_variables is None: - self._environment_variables = "{}" - # Use workflow.tenant_id to avoid relying on request user in background threads tenant_id = self.tenant_id if not tenant_id: return [] - environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables or "{}") + environment_variables_dict = cast(SerializedWorkflowVariables, json.loads(self._environment_variables or "{}")) results = [ variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values() ] @@ -536,11 +535,7 @@ class Workflow(Base): # bug @property def conversation_variables(self) -> Sequence[VariableBase]: - # TODO: find some way to init `self._conversation_variables` when instance created. - if self._conversation_variables is None: - self._conversation_variables = "{}" - - variables_dict: dict[str, Any] = json.loads(self._conversation_variables) + variables_dict = cast(SerializedWorkflowVariables, json.loads(self._conversation_variables or "{}")) results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()] return results @@ -552,19 +547,20 @@ class Workflow(Base): # bug ) @property - def rag_pipeline_variables(self) -> list[dict]: - # TODO: find some way to init `self._conversation_variables` when instance created. - if self._rag_pipeline_variables is None: - self._rag_pipeline_variables = "{}" - - variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables) - results = list(variables_dict.values()) - return results + def rag_pipeline_variables(self) -> list[SerializedWorkflowValue]: + variables_dict = cast(SerializedWorkflowVariables, json.loads(self._rag_pipeline_variables or "{}")) + return [RAGPipelineVariable.model_validate(item).model_dump(mode="json") for item in variables_dict.values()] @rag_pipeline_variables.setter - def rag_pipeline_variables(self, values: list[dict]) -> None: + def rag_pipeline_variables(self, values: Sequence[Mapping[str, Any] | RAGPipelineVariable]) -> None: self._rag_pipeline_variables = json.dumps( - {item["variable"]: item for item in values}, + { + rag_pipeline_variable.variable: rag_pipeline_variable.model_dump(mode="json") + for rag_pipeline_variable in ( + item if isinstance(item, RAGPipelineVariable) else RAGPipelineVariable.model_validate(item) + for item in values + ) + }, ensure_ascii=False, ) @@ -802,44 +798,36 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo __tablename__ = "workflow_node_executions" - @declared_attr.directive - @classmethod - def __table_args__(cls) -> Any: - return ( - PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), - Index( - "workflow_node_execution_workflow_run_id_idx", - "workflow_run_id", - ), - Index( - "workflow_node_execution_node_run_idx", - "tenant_id", - "app_id", - "workflow_id", - "triggered_from", - "node_id", - ), - Index( - "workflow_node_execution_id_idx", - "tenant_id", - "app_id", - "workflow_id", - "triggered_from", - "node_execution_id", - ), - Index( - # The first argument is the index name, - # which we leave as `None`` to allow auto-generation by the ORM. - None, - cls.tenant_id, - cls.workflow_id, - cls.node_id, - # MyPy may flag the following line because it doesn't recognize that - # the `declared_attr` decorator passes the receiving class as the first - # argument to this method, allowing us to reference class attributes. - cls.created_at.desc(), - ), - ) + __table_args__ = ( + PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), + Index( + "workflow_node_execution_workflow_run_id_idx", + "workflow_run_id", + ), + Index( + "workflow_node_execution_node_run_idx", + "tenant_id", + "app_id", + "workflow_id", + "triggered_from", + "node_id", + ), + Index( + "workflow_node_execution_id_idx", + "tenant_id", + "app_id", + "workflow_id", + "triggered_from", + "node_execution_id", + ), + Index( + None, + "tenant_id", + "workflow_id", + "node_id", + sa.desc("created_at"), + ), + ) id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) tenant_id: Mapped[str] = mapped_column(StringUUID) diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt index c044824a82..9a76de1927 100644 --- a/api/pyrefly-local-excludes.txt +++ b/api/pyrefly-local-excludes.txt @@ -1,4 +1,3 @@ -configs/middleware/cache/redis_pubsub_config.py controllers/console/app/annotation.py controllers/console/app/app.py controllers/console/app/app_import.py @@ -138,8 +137,6 @@ dify_graph/nodes/trigger_webhook/node.py dify_graph/nodes/variable_aggregator/variable_aggregator_node.py dify_graph/nodes/variable_assigner/v1/node.py dify_graph/nodes/variable_assigner/v2/node.py -dify_graph/variables/types.py -extensions/ext_fastopenapi.py extensions/logstore/repositories/logstore_api_workflow_run_repository.py extensions/otel/instrumentation.py extensions/otel/runtime.py @@ -156,19 +153,7 @@ extensions/storage/oracle_oci_storage.py extensions/storage/supabase_storage.py extensions/storage/tencent_cos_storage.py extensions/storage/volcengine_tos_storage.py -factories/variable_factory.py -libs/external_api.py libs/gmpy2_pkcs10aep_cipher.py -libs/helper.py -libs/login.py -libs/module_loading.py -libs/oauth.py -libs/oauth_data_source.py -models/trigger.py -models/workflow.py -repositories/sqlalchemy_api_workflow_node_execution_repository.py -repositories/sqlalchemy_api_workflow_run_repository.py -repositories/sqlalchemy_execution_extra_content_repository.py schedule/queue_monitor_task.py services/account_service.py services/audio_service.py diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 2266c2e646..77e40fc6fc 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -8,7 +8,7 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations. import json from collections.abc import Sequence from datetime import datetime -from typing import cast +from typing import Protocol, cast from sqlalchemy import asc, delete, desc, func, select from sqlalchemy.engine import CursorResult @@ -22,6 +22,20 @@ from repositories.api_workflow_node_execution_repository import ( ) +class _WorkflowNodeExecutionSnapshotRow(Protocol): + id: str + node_execution_id: str | None + node_id: str + node_type: str + title: str + index: int + status: WorkflowNodeExecutionStatus + elapsed_time: float | None + created_at: datetime + finished_at: datetime | None + execution_metadata: str | None + + class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository): """ SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository. @@ -40,6 +54,8 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut - Thread-safe database operations using session-per-request pattern """ + _session_maker: sessionmaker[Session] + def __init__(self, session_maker: sessionmaker[Session]): """ Initialize the repository with a sessionmaker. @@ -156,12 +172,12 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut ) with self._session_maker() as session: - rows = session.execute(stmt).all() + rows = cast(Sequence[_WorkflowNodeExecutionSnapshotRow], session.execute(stmt).all()) return [self._row_to_snapshot(row) for row in rows] @staticmethod - def _row_to_snapshot(row: object) -> WorkflowNodeExecutionSnapshot: + def _row_to_snapshot(row: _WorkflowNodeExecutionSnapshotRow) -> WorkflowNodeExecutionSnapshot: metadata: dict[str, object] = {} execution_metadata = getattr(row, "execution_metadata", None) if execution_metadata: