mirror of
https://github.com/langgenius/dify.git
synced 2026-05-30 13:47:52 +08:00
Compare commits
2 Commits
feat/mcp-t
...
fix/worksp
| Author | SHA1 | Date | |
|---|---|---|---|
| 58da51c1ba | |||
| 9d093f71ed |
@ -209,11 +209,6 @@ class MCPProviderBasePayload(BaseModel):
|
||||
configuration: dict[str, Any] | None = Field(default_factory=dict)
|
||||
headers: dict[str, Any] | None = Field(default_factory=dict)
|
||||
authentication: dict[str, Any] | None = Field(default_factory=dict)
|
||||
# M3 — user-identity forwarding (M2 backend already supports these on the
|
||||
# service layer). Defaults preserve pre-M3 behavior for clients that don't
|
||||
# send the fields yet.
|
||||
forward_user_identity: bool = False
|
||||
identity_mode: Literal["off", "idp_token"] = "off"
|
||||
|
||||
|
||||
class MCPProviderCreatePayload(MCPProviderBasePayload):
|
||||
@ -990,8 +985,6 @@ class ToolProviderMCPApi(Resource):
|
||||
headers=payload.headers or {},
|
||||
configuration=configuration,
|
||||
authentication=authentication,
|
||||
forward_user_identity=payload.forward_user_identity,
|
||||
identity_mode=payload.identity_mode,
|
||||
)
|
||||
|
||||
# 2) Try to fetch tools immediately after creation so they appear without a second save.
|
||||
@ -1059,8 +1052,6 @@ class ToolProviderMCPApi(Resource):
|
||||
configuration=configuration,
|
||||
authentication=authentication,
|
||||
validation_result=validation_result,
|
||||
forward_user_identity=payload.forward_user_identity,
|
||||
identity_mode=payload.identity_mode,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -29,7 +29,7 @@ from controllers.console.wraps import (
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import TimestampField, to_timestamp
|
||||
from libs.helper import TimestampField, dump_response, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
|
||||
from services.account_service import TenantService
|
||||
@ -56,6 +56,11 @@ class WorkspaceCustomConfigPayload(BaseModel):
|
||||
replace_webapp_logo: str | None = None
|
||||
|
||||
|
||||
class WorkspaceCustomConfigResponse(ResponseModel):
|
||||
remove_webapp_brand: bool | None = None
|
||||
replace_webapp_logo: str | None = None
|
||||
|
||||
|
||||
class WorkspaceInfoPayload(BaseModel):
|
||||
name: str
|
||||
|
||||
@ -69,7 +74,7 @@ class TenantInfoResponse(ResponseModel):
|
||||
role: str | None = None
|
||||
in_trial: bool | None = None
|
||||
trial_end_reason: str | None = None
|
||||
custom_config: dict | None = None
|
||||
custom_config: WorkspaceCustomConfigResponse | None = None
|
||||
trial_credits: int | None = None
|
||||
trial_credits_used: int | None = None
|
||||
next_credit_reset_date: int | None = None
|
||||
@ -101,9 +106,13 @@ register_schema_models(
|
||||
SwitchWorkspacePayload,
|
||||
WorkspaceCustomConfigPayload,
|
||||
WorkspaceInfoPayload,
|
||||
TenantInfoResponse,
|
||||
)
|
||||
register_response_schema_models(console_ns, WorkspacePermissionResponse)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
TenantInfoResponse,
|
||||
WorkspaceCustomConfigResponse,
|
||||
WorkspacePermissionResponse,
|
||||
)
|
||||
|
||||
provider_fields = {
|
||||
"provider_name": fields.String,
|
||||
@ -238,13 +247,7 @@ class TenantApi(Resource):
|
||||
else:
|
||||
raise Unauthorized("workspace is archived")
|
||||
|
||||
return (
|
||||
TenantInfoResponse.model_validate(
|
||||
WorkspaceService.get_tenant_info(tenant),
|
||||
from_attributes=True,
|
||||
).model_dump(mode="json"),
|
||||
200,
|
||||
)
|
||||
return dump_response(TenantInfoResponse, WorkspaceService.get_tenant_info(tenant)), 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/switch")
|
||||
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import json
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -76,14 +76,6 @@ class MCPProviderEntity(BaseModel):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
# M2 — user-identity forwarding. When forward_user_identity is True AND
|
||||
# identity_mode is "idp_token", the MCP tool runtime asks dify-enterprise
|
||||
# to mint a fresh SSO id_token for the calling user and stamps it on the
|
||||
# outbound MCP request as `Authorization: Bearer <token>`. Defaults keep
|
||||
# pre-M2 providers unchanged (no forwarding).
|
||||
forward_user_identity: bool = False
|
||||
identity_mode: Literal["off", "idp_token"] = "off"
|
||||
|
||||
@classmethod
|
||||
def from_db_model(cls, db_provider: MCPToolProvider) -> MCPProviderEntity:
|
||||
"""Create entity from database model with decryption"""
|
||||
@ -104,8 +96,6 @@ class MCPProviderEntity(BaseModel):
|
||||
icon=db_provider.icon or "",
|
||||
created_at=db_provider.created_at,
|
||||
updated_at=db_provider.updated_at,
|
||||
forward_user_identity=db_provider.forward_user_identity,
|
||||
identity_mode=db_provider.identity_mode, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
@property
|
||||
@ -180,8 +170,6 @@ class MCPProviderEntity(BaseModel):
|
||||
"updated_at": int(self.updated_at.timestamp()),
|
||||
"label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(),
|
||||
"description": I18nObject(en_US="", zh_Hans="").to_dict(),
|
||||
"forward_user_identity": self.forward_user_identity,
|
||||
"identity_mode": self.identity_mode,
|
||||
}
|
||||
|
||||
# Add configuration
|
||||
|
||||
@ -54,12 +54,6 @@ class ToolProviderApiEntity(BaseModel):
|
||||
configuration: MCPConfiguration | None = Field(
|
||||
default=None, description="The timeout and sse_read_timeout of the MCP tool"
|
||||
)
|
||||
# M3 — user-identity forwarding flags. Round-tripped through the console
|
||||
# API so the create/edit modal can hydrate the toggle state.
|
||||
forward_user_identity: bool = Field(
|
||||
default=False, description="Whether Dify forwards the calling user's SSO identity to this MCP server"
|
||||
)
|
||||
identity_mode: str = Field(default="off", description="Identity-forwarding mechanism: 'off' or 'idp_token'")
|
||||
# Workflow
|
||||
workflow_app_id: str | None = Field(default=None, description="The app id of the workflow tool")
|
||||
|
||||
@ -98,10 +92,6 @@ class ToolProviderApiEntity(BaseModel):
|
||||
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
|
||||
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
|
||||
optional_fields.update(self.optional_field("original_headers", self.original_headers))
|
||||
# M3 — forwarding flags. Always emit (False/"off" are valid
|
||||
# values that the UI must hydrate, not skip).
|
||||
optional_fields["forward_user_identity"] = self.forward_user_identity
|
||||
optional_fields["identity_mode"] = self.identity_mode
|
||||
case ToolProviderType.WORKFLOW:
|
||||
optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id))
|
||||
case _:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, Literal, Self
|
||||
from typing import Any, Self
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.mcp.types import Tool as RemoteMCPTool
|
||||
@ -28,8 +28,6 @@ class MCPToolProviderController(ToolProviderController):
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
forward_user_identity: bool = False,
|
||||
identity_mode: Literal["off", "idp_token"] = "off",
|
||||
):
|
||||
super().__init__(entity)
|
||||
self.entity: ToolProviderEntityWithPlugin = entity
|
||||
@ -39,8 +37,6 @@ class MCPToolProviderController(ToolProviderController):
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
self.forward_user_identity = forward_user_identity
|
||||
self.identity_mode: Literal["off", "idp_token"] = identity_mode
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
@ -109,8 +105,6 @@ class MCPToolProviderController(ToolProviderController):
|
||||
headers=entity.headers,
|
||||
timeout=entity.timeout,
|
||||
sse_read_timeout=entity.sse_read_timeout,
|
||||
forward_user_identity=entity.forward_user_identity,
|
||||
identity_mode=entity.identity_mode,
|
||||
)
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
@ -140,8 +134,6 @@ class MCPToolProviderController(ToolProviderController):
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
forward_user_identity=self.forward_user_identity,
|
||||
identity_mode=self.identity_mode,
|
||||
)
|
||||
|
||||
def get_tools(self) -> list[MCPTool]:
|
||||
@ -159,8 +151,6 @@ class MCPToolProviderController(ToolProviderController):
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
forward_user_identity=self.forward_user_identity,
|
||||
identity_mode=self.identity_mode,
|
||||
)
|
||||
for tool_entity in self.entity.tools
|
||||
]
|
||||
|
||||
@ -4,7 +4,7 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPConnectionError
|
||||
@ -38,8 +38,6 @@ class MCPTool(Tool):
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
forward_user_identity: bool = False,
|
||||
identity_mode: Literal["off", "idp_token"] = "off",
|
||||
):
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
@ -49,8 +47,6 @@ class MCPTool(Tool):
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
self.forward_user_identity = forward_user_identity
|
||||
self.identity_mode: Literal["off", "idp_token"] = identity_mode
|
||||
self._latest_usage = LLMUsage.empty_usage()
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
@ -64,7 +60,7 @@ class MCPTool(Tool):
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
result = self.invoke_remote_mcp_tool(tool_parameters, user_id=user_id, app_id=app_id)
|
||||
result = self.invoke_remote_mcp_tool(tool_parameters)
|
||||
|
||||
# Extract usage metadata from MCP protocol's _meta field
|
||||
self._latest_usage = self._derive_usage_from_result(result)
|
||||
@ -238,8 +234,6 @@ class MCPTool(Tool):
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
forward_user_identity=self.forward_user_identity,
|
||||
identity_mode=self.identity_mode,
|
||||
)
|
||||
|
||||
def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:
|
||||
@ -252,12 +246,7 @@ class MCPTool(Tool):
|
||||
if value is not None and not (isinstance(value, str) and value.strip() == "")
|
||||
}
|
||||
|
||||
def invoke_remote_mcp_tool(
|
||||
self,
|
||||
tool_parameters: dict[str, Any],
|
||||
user_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
) -> CallToolResult:
|
||||
def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any]) -> CallToolResult:
|
||||
headers = self.headers.copy() if self.headers else {}
|
||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||
|
||||
@ -282,14 +271,6 @@ class MCPTool(Tool):
|
||||
if tokens and tokens.access_token:
|
||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||
|
||||
# User-identity forwarding: if enabled on this provider, ask the
|
||||
# enterprise side to mint a fresh SSO id_token (audience-scoped to
|
||||
# the MCP server's URL per RFC 8707) and stamp it as Authorization.
|
||||
# This OVERRIDES any Authorization already on the request — the
|
||||
# forwarded identity is what the MCP server should trust.
|
||||
if self.forward_user_identity and self.identity_mode == "idp_token" and user_id:
|
||||
self._inject_forwarded_identity(headers, user_id=user_id, app_id=app_id, audience=server_url)
|
||||
|
||||
# Step 2: Session is now closed, perform network operations without holding database connection
|
||||
# MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed
|
||||
try:
|
||||
@ -305,31 +286,3 @@ class MCPTool(Tool):
|
||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||
|
||||
def _inject_forwarded_identity(
|
||||
self,
|
||||
headers: dict[str, str],
|
||||
*,
|
||||
user_id: str,
|
||||
app_id: str | None,
|
||||
audience: str,
|
||||
) -> None:
|
||||
"""Call the enterprise IssueMCPToken endpoint and stamp Authorization.
|
||||
|
||||
Errors are surfaced as ToolInvokeError so the workflow halts with a
|
||||
clear message instead of silently dropping identity and hitting the
|
||||
MCP server unauthenticated.
|
||||
"""
|
||||
from services.enterprise.base import MCPTokenError
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
|
||||
try:
|
||||
token, _expires_at = EnterpriseService.issue_mcp_token(
|
||||
user_id=user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=app_id,
|
||||
audience=audience,
|
||||
)
|
||||
except MCPTokenError as e:
|
||||
raise ToolInvokeError(f"Failed to obtain forwarded identity token: {e}") from e
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
@ -1,56 +0,0 @@
|
||||
"""add identity mode to mcp tool provider
|
||||
|
||||
Revision ID: 3df4dbcc1e21
|
||||
Revises: 7885bd53f9a9
|
||||
Create Date: 2026-05-29 15:00:00.000000
|
||||
|
||||
Adds two columns to `tool_mcp_providers` that drive the M2 MCP user-identity
|
||||
forwarding feature:
|
||||
|
||||
* `forward_user_identity` (bool, default false) — master switch per provider.
|
||||
* `identity_mode` (string, default "off") — which forwarding mechanism to use:
|
||||
"off" — no header forwarded (default; pre-M2 behaviour).
|
||||
"idp_token" — call dify-enterprise /inner/api/mcp/issue-token, stamp
|
||||
the returned id_token on the outbound MCP request as
|
||||
`Authorization: Bearer <token>`.
|
||||
|
||||
The columns are filled with safe defaults for existing rows so older providers
|
||||
keep their current behaviour (no identity forwarding) until an admin opts in.
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
import models as models
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3df4dbcc1e21"
|
||||
down_revision = "7885bd53f9a9"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.add_column(
|
||||
"tool_mcp_providers",
|
||||
sa.Column(
|
||||
"forward_user_identity",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"tool_mcp_providers",
|
||||
sa.Column(
|
||||
"identity_mode",
|
||||
sa.String(length=32),
|
||||
nullable=False,
|
||||
server_default=sa.text("'off'"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_column("tool_mcp_providers", "identity_mode")
|
||||
op.drop_column("tool_mcp_providers", "forward_user_identity")
|
||||
@ -343,21 +343,6 @@ class MCPToolProvider(TypeBase):
|
||||
# encrypted headers for MCP server requests
|
||||
encrypted_headers: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
|
||||
# M2 (MCP user-identity forwarding) — master switch per provider. When True
|
||||
# AND identity_mode is "idp_token", workflows that invoke tools on this
|
||||
# provider will have the caller's SSO id_token stamped on the outbound
|
||||
# request as `Authorization: Bearer …`. Off by default so existing
|
||||
# providers retain pre-M2 behaviour.
|
||||
forward_user_identity: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=sa.text("false"), default=False
|
||||
)
|
||||
# M2 — which identity-forwarding mechanism to use. Reserved values:
|
||||
# "off" — no forwarding (default).
|
||||
# "idp_token" — forward a Bearer id_token minted by dify-enterprise.
|
||||
identity_mode: Mapped[str] = mapped_column(
|
||||
sa.String(32), nullable=False, server_default=sa.text("'off'"), default="off"
|
||||
)
|
||||
|
||||
def load_user(self) -> Account | None:
|
||||
return db.session.scalar(select(Account).where(Account.id == self.user_id))
|
||||
|
||||
|
||||
@ -13761,12 +13761,10 @@ Enum class for large language model mode.
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| authentication | object | | No |
|
||||
| configuration | object | | No |
|
||||
| forward_user_identity | boolean | | No |
|
||||
| headers | object | | No |
|
||||
| icon | string | | Yes |
|
||||
| icon_background | string | | No |
|
||||
| icon_type | string | | Yes |
|
||||
| identity_mode | string | *Enum:* `"idp_token"`, `"off"` | No |
|
||||
| name | string | | Yes |
|
||||
| server_identifier | string | | Yes |
|
||||
| server_url | string | | Yes |
|
||||
@ -13783,12 +13781,10 @@ Enum class for large language model mode.
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| authentication | object | | No |
|
||||
| configuration | object | | No |
|
||||
| forward_user_identity | boolean | | No |
|
||||
| headers | object | | No |
|
||||
| icon | string | | Yes |
|
||||
| icon_background | string | | No |
|
||||
| icon_type | string | | Yes |
|
||||
| identity_mode | string | *Enum:* `"idp_token"`, `"off"` | No |
|
||||
| name | string | | Yes |
|
||||
| provider_id | string | | Yes |
|
||||
| server_identifier | string | | Yes |
|
||||
@ -15211,7 +15207,7 @@ Tag type
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| created_at | integer | | No |
|
||||
| custom_config | object | | No |
|
||||
| custom_config | [WorkspaceCustomConfigResponse](#workspacecustomconfigresponse) | | No |
|
||||
| id | string | | Yes |
|
||||
| in_trial | boolean | | No |
|
||||
| name | string | | No |
|
||||
@ -16334,6 +16330,13 @@ Workflow tool configuration
|
||||
| remove_webapp_brand | boolean | | No |
|
||||
| replace_webapp_logo | string | | No |
|
||||
|
||||
#### WorkspaceCustomConfigResponse
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| remove_webapp_brand | boolean | | No |
|
||||
| replace_webapp_logo | string | | No |
|
||||
|
||||
#### WorkspaceInfoPayload
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
|
||||
@ -12,34 +12,8 @@ from services.errors.enterprise import (
|
||||
EnterpriseAPIForbiddenError,
|
||||
EnterpriseAPINotFoundError,
|
||||
EnterpriseAPIUnauthorizedError,
|
||||
EnterpriseServiceError,
|
||||
)
|
||||
|
||||
|
||||
# M2 — IssueMCPToken specific errors. Co-located here (rather than in
|
||||
# services/errors/enterprise.py) because services.enterprise.base is part of
|
||||
# the leaf-mounted file set the local dev override applies; the errors module
|
||||
# stays at the EE image's baked-in version.
|
||||
class MCPTokenError(EnterpriseServiceError):
|
||||
"""Generic failure of the IssueMCPToken RPC."""
|
||||
|
||||
|
||||
class MCPNoRefreshTokenError(MCPTokenError):
|
||||
"""The user has no stored SSO refresh_token on the enterprise side.
|
||||
The workflow should ask them to re-authenticate."""
|
||||
|
||||
def __init__(self, description: str = ""):
|
||||
super().__init__(description, status_code=428)
|
||||
|
||||
|
||||
class MCPIdentityRefreshError(MCPTokenError):
|
||||
"""The enterprise side tried to refresh the user's SSO refresh_token
|
||||
against the IdP and failed (revoked/expired/IdP error)."""
|
||||
|
||||
def __init__(self, description: str = ""):
|
||||
super().__init__(description, status_code=401)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
@ -11,16 +11,7 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.enterprise.base import (
|
||||
EnterpriseRequest,
|
||||
MCPIdentityRefreshError,
|
||||
MCPNoRefreshTokenError,
|
||||
MCPTokenError,
|
||||
)
|
||||
from services.errors.enterprise import (
|
||||
EnterpriseAPIError,
|
||||
EnterpriseAPIUnauthorizedError,
|
||||
)
|
||||
from services.enterprise.base import EnterpriseRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from services.feature_service import LicenseStatus
|
||||
@ -130,62 +121,6 @@ class EnterpriseService:
|
||||
def get_workspace_info(cls, tenant_id: str):
|
||||
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
|
||||
|
||||
@classmethod
|
||||
def issue_mcp_token(
|
||||
cls,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
app_id: str | None,
|
||||
audience: str,
|
||||
) -> tuple[str, int]:
|
||||
"""Mint a short-lived SSO id_token (or OAuth2 access_token) representing
|
||||
the calling Dify user, audience-scoped to the given MCP server identifier.
|
||||
|
||||
Used by MCPTool.invoke_remote_mcp_tool to stamp `Authorization: Bearer
|
||||
<token>` on outbound MCP requests when the provider has
|
||||
forward_user_identity=True and identity_mode="idp_token".
|
||||
|
||||
Returns:
|
||||
(token, expires_at_unix_seconds)
|
||||
|
||||
Raises:
|
||||
MCPNoRefreshTokenError: user has no stored SSO refresh_token on the
|
||||
enterprise side; surface to the workflow as "please log in via SSO".
|
||||
MCPIdentityRefreshError: enterprise tried to refresh against the IdP
|
||||
and the IdP rejected (revoked/expired session).
|
||||
MCPTokenError: any other failure of the enterprise endpoint.
|
||||
"""
|
||||
try:
|
||||
response = EnterpriseRequest.send_request(
|
||||
"POST",
|
||||
"/mcp/issue-token",
|
||||
json={
|
||||
"user_id": user_id,
|
||||
"tenant_id": tenant_id,
|
||||
"app_id": app_id or "",
|
||||
"audience": audience,
|
||||
},
|
||||
)
|
||||
except EnterpriseAPIUnauthorizedError as e:
|
||||
# Enterprise side returns 401 when the IdP rejected the refresh.
|
||||
raise MCPIdentityRefreshError(str(e) or "identity refresh failed; please re-authenticate") from e
|
||||
except EnterpriseAPIError as e:
|
||||
# Map the 428 PreconditionRequired we emit on no-stored-refresh-token.
|
||||
if getattr(e, "status_code", None) == 428:
|
||||
raise MCPNoRefreshTokenError(
|
||||
str(e) or "user has no stored SSO refresh token; please re-authenticate"
|
||||
) from e
|
||||
raise MCPTokenError(f"issue_mcp_token failed: {e}") from e
|
||||
|
||||
if not isinstance(response, dict):
|
||||
raise MCPTokenError("invalid response shape from enterprise /mcp/issue-token")
|
||||
|
||||
token = response.get("token")
|
||||
expires_at = response.get("expires_at")
|
||||
if not token or not isinstance(token, str) or not isinstance(expires_at, int):
|
||||
raise MCPTokenError(f"missing token/expires_at in enterprise response: {response}")
|
||||
return token, expires_at
|
||||
|
||||
@classmethod
|
||||
def initiate_device_flow_sso(cls, signed_state: str) -> dict:
|
||||
return EnterpriseRequest.send_request(
|
||||
|
||||
@ -4,7 +4,7 @@ import logging
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@ -136,8 +136,6 @@ class MCPToolManageService:
|
||||
configuration: MCPConfiguration,
|
||||
authentication: MCPAuthentication | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
forward_user_identity: bool = False,
|
||||
identity_mode: Literal["off", "idp_token"] = "off",
|
||||
) -> ToolProviderApiEntity:
|
||||
"""Create a new MCP provider."""
|
||||
# Validate URL format
|
||||
@ -173,8 +171,6 @@ class MCPToolManageService:
|
||||
sse_read_timeout=configuration.sse_read_timeout,
|
||||
encrypted_headers=encrypted_headers,
|
||||
encrypted_credentials=encrypted_credentials,
|
||||
forward_user_identity=forward_user_identity,
|
||||
identity_mode=identity_mode,
|
||||
)
|
||||
|
||||
self._session.add(mcp_tool)
|
||||
@ -198,8 +194,6 @@ class MCPToolManageService:
|
||||
configuration: MCPConfiguration,
|
||||
authentication: MCPAuthentication | None = None,
|
||||
validation_result: ServerUrlValidationResult | None = None,
|
||||
forward_user_identity: bool | None = None,
|
||||
identity_mode: Literal["off", "idp_token"] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Update an MCP provider.
|
||||
@ -261,14 +255,6 @@ class MCPToolManageService:
|
||||
if authentication and authentication.client_id:
|
||||
mcp_provider.encrypted_credentials = self._process_credentials(authentication, mcp_provider, tenant_id)
|
||||
|
||||
# Update user-identity forwarding settings if provided.
|
||||
# None means "leave unchanged" so this stays backwards-compatible
|
||||
# with existing callers that don't know about M2.
|
||||
if forward_user_identity is not None:
|
||||
mcp_provider.forward_user_identity = forward_user_identity
|
||||
if identity_mode is not None:
|
||||
mcp_provider.identity_mode = identity_mode
|
||||
|
||||
# Flush changes to database
|
||||
self._session.flush()
|
||||
|
||||
|
||||
@ -435,6 +435,23 @@ class TestTenantInfoResponse:
|
||||
assert payload["plan"] == "team"
|
||||
assert payload["created_at"] == int(created_at.timestamp())
|
||||
|
||||
def test_tenant_info_response_has_typed_custom_config(self):
|
||||
payload = TenantInfoResponse.model_validate(
|
||||
{
|
||||
"id": "t1",
|
||||
"custom_config": {
|
||||
"remove_webapp_brand": True,
|
||||
"replace_webapp_logo": "logo-file-id",
|
||||
"ignored": "value",
|
||||
},
|
||||
}
|
||||
).model_dump(mode="json")
|
||||
|
||||
assert payload["custom_config"] == {
|
||||
"remove_webapp_brand": True,
|
||||
"replace_webapp_logo": "logo-file-id",
|
||||
}
|
||||
|
||||
|
||||
class TestSwitchWorkspaceApi:
|
||||
def test_switch_success(self, app: Flask):
|
||||
|
||||
@ -4904,11 +4904,6 @@
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"web/app/device/page.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"web/app/education-apply/hooks.ts": {
|
||||
"react/set-state-in-effect": {
|
||||
"count": 5
|
||||
|
||||
@ -4,16 +4,8 @@ import { oc } from '@orpc/contract'
|
||||
|
||||
import { zPostInfoResponse } from './zod.gen'
|
||||
|
||||
/**
|
||||
* Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.
|
||||
*
|
||||
* @deprecated
|
||||
*/
|
||||
export const post = oc
|
||||
.route({
|
||||
deprecated: true,
|
||||
description:
|
||||
'Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.',
|
||||
inputStructure: 'detailed',
|
||||
method: 'POST',
|
||||
operationId: 'postInfo',
|
||||
|
||||
@ -6,9 +6,7 @@ export type ClientOptions = {
|
||||
|
||||
export type TenantInfoResponse = {
|
||||
created_at?: number | null
|
||||
custom_config?: {
|
||||
[key: string]: unknown
|
||||
} | null
|
||||
custom_config?: WorkspaceCustomConfigResponse
|
||||
id: string
|
||||
in_trial?: boolean | null
|
||||
name?: string | null
|
||||
@ -21,6 +19,11 @@ export type TenantInfoResponse = {
|
||||
trial_end_reason?: string | null
|
||||
}
|
||||
|
||||
export type WorkspaceCustomConfigResponse = {
|
||||
remove_webapp_brand?: boolean | null
|
||||
replace_webapp_logo?: string | null
|
||||
}
|
||||
|
||||
export type PostInfoData = {
|
||||
body?: never
|
||||
path?: never
|
||||
|
||||
@ -2,12 +2,20 @@
|
||||
|
||||
import * as z from 'zod'
|
||||
|
||||
/**
|
||||
* WorkspaceCustomConfigResponse
|
||||
*/
|
||||
export const zWorkspaceCustomConfigResponse = z.object({
|
||||
remove_webapp_brand: z.boolean().nullish(),
|
||||
replace_webapp_logo: z.string().nullish(),
|
||||
})
|
||||
|
||||
/**
|
||||
* TenantInfoResponse
|
||||
*/
|
||||
export const zTenantInfoResponse = z.object({
|
||||
created_at: z.int().nullish(),
|
||||
custom_config: z.record(z.string(), z.unknown()).nullish(),
|
||||
custom_config: zWorkspaceCustomConfigResponse.optional(),
|
||||
id: z.string(),
|
||||
in_trial: z.boolean().nullish(),
|
||||
name: z.string().nullish(),
|
||||
|
||||
@ -3682,16 +3682,8 @@ export const triggers = {
|
||||
get: get58,
|
||||
}
|
||||
|
||||
/**
|
||||
* Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.
|
||||
*
|
||||
* @deprecated
|
||||
*/
|
||||
export const post63 = oc
|
||||
.route({
|
||||
deprecated: true,
|
||||
description:
|
||||
'Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.',
|
||||
inputStructure: 'detailed',
|
||||
method: 'POST',
|
||||
operationId: 'postWorkspacesCurrent',
|
||||
|
||||
@ -6,9 +6,7 @@ export type ClientOptions = {
|
||||
|
||||
export type TenantInfoResponse = {
|
||||
created_at?: number | null
|
||||
custom_config?: {
|
||||
[key: string]: unknown
|
||||
} | null
|
||||
custom_config?: WorkspaceCustomConfigResponse
|
||||
id: string
|
||||
in_trial?: boolean | null
|
||||
name?: string | null
|
||||
@ -392,14 +390,12 @@ export type McpProviderCreatePayload = {
|
||||
configuration?: {
|
||||
[key: string]: unknown
|
||||
} | null
|
||||
forward_user_identity?: boolean
|
||||
headers?: {
|
||||
[key: string]: unknown
|
||||
} | null
|
||||
icon: string
|
||||
icon_background?: string
|
||||
icon_type: string
|
||||
identity_mode?: 'idp_token' | 'off'
|
||||
name: string
|
||||
server_identifier: string
|
||||
server_url: string
|
||||
@ -412,14 +408,12 @@ export type McpProviderUpdatePayload = {
|
||||
configuration?: {
|
||||
[key: string]: unknown
|
||||
} | null
|
||||
forward_user_identity?: boolean
|
||||
headers?: {
|
||||
[key: string]: unknown
|
||||
} | null
|
||||
icon: string
|
||||
icon_background?: string
|
||||
icon_type: string
|
||||
identity_mode?: 'idp_token' | 'off'
|
||||
name: string
|
||||
provider_id: string
|
||||
server_identifier: string
|
||||
@ -504,6 +498,11 @@ export type SwitchWorkspacePayload = {
|
||||
tenant_id: string
|
||||
}
|
||||
|
||||
export type WorkspaceCustomConfigResponse = {
|
||||
remove_webapp_brand?: boolean | null
|
||||
replace_webapp_logo?: string | null
|
||||
}
|
||||
|
||||
export type AccountWithRole = {
|
||||
avatar?: string | null
|
||||
created_at?: number | null
|
||||
|
||||
@ -2,24 +2,6 @@
|
||||
|
||||
import * as z from 'zod'
|
||||
|
||||
/**
|
||||
* TenantInfoResponse
|
||||
*/
|
||||
export const zTenantInfoResponse = z.object({
|
||||
created_at: z.int().nullish(),
|
||||
custom_config: z.record(z.string(), z.unknown()).nullish(),
|
||||
id: z.string(),
|
||||
in_trial: z.boolean().nullish(),
|
||||
name: z.string().nullish(),
|
||||
next_credit_reset_date: z.int().nullish(),
|
||||
plan: z.string().nullish(),
|
||||
role: z.string().nullish(),
|
||||
status: z.string().nullish(),
|
||||
trial_credits: z.int().nullish(),
|
||||
trial_credits_used: z.int().nullish(),
|
||||
trial_end_reason: z.string().nullish(),
|
||||
})
|
||||
|
||||
/**
|
||||
* SimpleResultResponse
|
||||
*/
|
||||
@ -361,12 +343,10 @@ export const zMcpProviderDeletePayload = z.object({
|
||||
export const zMcpProviderCreatePayload = z.object({
|
||||
authentication: z.record(z.string(), z.unknown()).nullish(),
|
||||
configuration: z.record(z.string(), z.unknown()).nullish(),
|
||||
forward_user_identity: z.boolean().optional().default(false),
|
||||
headers: z.record(z.string(), z.unknown()).nullish(),
|
||||
icon: z.string(),
|
||||
icon_background: z.string().optional().default(''),
|
||||
icon_type: z.string(),
|
||||
identity_mode: z.enum(['idp_token', 'off']).optional().default('off'),
|
||||
name: z.string(),
|
||||
server_identifier: z.string(),
|
||||
server_url: z.string(),
|
||||
@ -378,12 +358,10 @@ export const zMcpProviderCreatePayload = z.object({
|
||||
export const zMcpProviderUpdatePayload = z.object({
|
||||
authentication: z.record(z.string(), z.unknown()).nullish(),
|
||||
configuration: z.record(z.string(), z.unknown()).nullish(),
|
||||
forward_user_identity: z.boolean().optional().default(false),
|
||||
headers: z.record(z.string(), z.unknown()).nullish(),
|
||||
icon: z.string(),
|
||||
icon_background: z.string().optional().default(''),
|
||||
icon_type: z.string(),
|
||||
identity_mode: z.enum(['idp_token', 'off']).optional().default('off'),
|
||||
name: z.string(),
|
||||
provider_id: z.string(),
|
||||
server_identifier: z.string(),
|
||||
@ -459,6 +437,32 @@ export const zSwitchWorkspacePayload = z.object({
|
||||
tenant_id: z.string(),
|
||||
})
|
||||
|
||||
/**
|
||||
* WorkspaceCustomConfigResponse
|
||||
*/
|
||||
export const zWorkspaceCustomConfigResponse = z.object({
|
||||
remove_webapp_brand: z.boolean().nullish(),
|
||||
replace_webapp_logo: z.string().nullish(),
|
||||
})
|
||||
|
||||
/**
|
||||
* TenantInfoResponse
|
||||
*/
|
||||
export const zTenantInfoResponse = z.object({
|
||||
created_at: z.int().nullish(),
|
||||
custom_config: zWorkspaceCustomConfigResponse.optional(),
|
||||
id: z.string(),
|
||||
in_trial: z.boolean().nullish(),
|
||||
name: z.string().nullish(),
|
||||
next_credit_reset_date: z.int().nullish(),
|
||||
plan: z.string().nullish(),
|
||||
role: z.string().nullish(),
|
||||
status: z.string().nullish(),
|
||||
trial_credits: z.int().nullish(),
|
||||
trial_credits_used: z.int().nullish(),
|
||||
trial_end_reason: z.string().nullish(),
|
||||
})
|
||||
|
||||
/**
|
||||
* AccountWithRole
|
||||
*/
|
||||
|
||||
@ -131,14 +131,6 @@ vi.mock('@/service/use-common', () => ({
|
||||
],
|
||||
},
|
||||
}),
|
||||
useCurrentWorkspace: () => ({
|
||||
data: {
|
||||
trial_credits: 1000,
|
||||
trial_credits_used: 100,
|
||||
next_credit_reset_date: undefined,
|
||||
},
|
||||
isPending: false,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({
|
||||
|
||||
@ -7,11 +7,11 @@ import QuotaPanel from '../quota-panel'
|
||||
let mockWorkspaceData: {
|
||||
trial_credits: number
|
||||
trial_credits_used: number
|
||||
next_credit_reset_date: string
|
||||
next_credit_reset_date: number
|
||||
} | undefined = {
|
||||
trial_credits: 100,
|
||||
trial_credits_used: 30,
|
||||
next_credit_reset_date: '2024-12-31',
|
||||
next_credit_reset_date: 1735603200,
|
||||
}
|
||||
let mockWorkspaceIsPending = false
|
||||
let mockTrialModels: string[] | undefined = ['langgenius/openai/openai']
|
||||
@ -32,11 +32,18 @@ vi.mock('@/app/components/base/icons/src/public/llm', () => {
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/service/use-common', () => ({
|
||||
useCurrentWorkspace: () => ({
|
||||
data: mockWorkspaceData,
|
||||
isPending: mockWorkspaceIsPending,
|
||||
}),
|
||||
vi.mock('../use-trial-credits', () => ({
|
||||
useTrialCredits: () => {
|
||||
const totalCredits = mockWorkspaceData?.trial_credits ?? 0
|
||||
const credits = Math.max(totalCredits - (mockWorkspaceData?.trial_credits_used ?? 0), 0)
|
||||
return {
|
||||
credits,
|
||||
totalCredits,
|
||||
isExhausted: credits <= 0,
|
||||
isLoading: mockWorkspaceIsPending && !mockWorkspaceData,
|
||||
nextCreditResetDate: mockWorkspaceData?.next_credit_reset_date,
|
||||
}
|
||||
},
|
||||
}))
|
||||
|
||||
const renderQuotaPanel = (ui: ReactElement) => renderWithSystemFeatures(ui, {
|
||||
@ -78,7 +85,7 @@ describe('QuotaPanel', () => {
|
||||
mockWorkspaceData = {
|
||||
trial_credits: 100,
|
||||
trial_credits_used: 30,
|
||||
next_credit_reset_date: '2024-12-31',
|
||||
next_credit_reset_date: 1735603200,
|
||||
}
|
||||
mockWorkspaceIsPending = false
|
||||
mockTrialModels = ['langgenius/openai/openai']
|
||||
@ -118,7 +125,7 @@ describe('QuotaPanel', () => {
|
||||
mockWorkspaceData = {
|
||||
trial_credits: 10,
|
||||
trial_credits_used: 999,
|
||||
next_credit_reset_date: '',
|
||||
next_credit_reset_date: 0,
|
||||
}
|
||||
|
||||
renderQuotaPanel(<QuotaPanel providers={mockProviders} />)
|
||||
|
||||
@ -1,20 +1,34 @@
|
||||
import { renderHook } from '@testing-library/react'
|
||||
import { useTrialCredits } from '../use-trial-credits'
|
||||
|
||||
const mockUseCurrentWorkspace = vi.fn()
|
||||
const { mockUseQuery } = vi.hoisted(() => ({
|
||||
mockUseQuery: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-common', () => ({
|
||||
useCurrentWorkspace: () => mockUseCurrentWorkspace(),
|
||||
vi.mock('@tanstack/react-query', () => ({
|
||||
useQuery: () => mockUseQuery(),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/client', () => ({
|
||||
consoleQuery: {
|
||||
workspaces: {
|
||||
current: {
|
||||
post: {
|
||||
queryOptions: () => ({ queryKey: ['console', 'workspaces', 'current', 'post'] }),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
describe('useTrialCredits', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseCurrentWorkspace.mockReturnValue({
|
||||
mockUseQuery.mockReturnValue({
|
||||
data: {
|
||||
trial_credits: 100,
|
||||
trial_credits_used: 40,
|
||||
next_credit_reset_date: '2026-04-01',
|
||||
next_credit_reset_date: 1775001600,
|
||||
},
|
||||
isPending: false,
|
||||
})
|
||||
@ -29,16 +43,16 @@ describe('useTrialCredits', () => {
|
||||
totalCredits: 100,
|
||||
isExhausted: false,
|
||||
isLoading: false,
|
||||
nextCreditResetDate: '2026-04-01',
|
||||
nextCreditResetDate: 1775001600,
|
||||
})
|
||||
})
|
||||
|
||||
it('should keep the hook out of loading state during a background refetch', () => {
|
||||
mockUseCurrentWorkspace.mockReturnValue({
|
||||
mockUseQuery.mockReturnValue({
|
||||
data: {
|
||||
trial_credits: 80,
|
||||
trial_credits_used: 20,
|
||||
next_credit_reset_date: '2026-05-01',
|
||||
next_credit_reset_date: 1777593600,
|
||||
},
|
||||
isPending: true,
|
||||
})
|
||||
@ -53,7 +67,7 @@ describe('useTrialCredits', () => {
|
||||
|
||||
describe('when workspace data is missing or exhausted', () => {
|
||||
it('should report loading while the first workspace request is pending', () => {
|
||||
mockUseCurrentWorkspace.mockReturnValue({
|
||||
mockUseQuery.mockReturnValue({
|
||||
data: undefined,
|
||||
isPending: true,
|
||||
})
|
||||
@ -70,7 +84,7 @@ describe('useTrialCredits', () => {
|
||||
})
|
||||
|
||||
it('should clamp negative remaining credits to zero', () => {
|
||||
mockUseCurrentWorkspace.mockReturnValue({
|
||||
mockUseQuery.mockReturnValue({
|
||||
data: {
|
||||
trial_credits: 10,
|
||||
trial_credits_used: 99,
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import type { ICurrentWorkspace } from '@/models/common'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import CreditsExhaustedAlert from './credits-exhausted-alert'
|
||||
|
||||
const baseWorkspace: ICurrentWorkspace = {
|
||||
@ -20,7 +21,7 @@ function createSeededQueryClient(overrides?: Partial<ICurrentWorkspace>) {
|
||||
const qc = new QueryClient({
|
||||
defaultOptions: { queries: { refetchOnWindowFocus: false, retry: false } },
|
||||
})
|
||||
qc.setQueryData(['common', 'current-workspace'], { ...baseWorkspace, ...overrides })
|
||||
qc.setQueryData(consoleQuery.workspaces.current.post.queryKey(), { ...baseWorkspace, ...overrides })
|
||||
return qc
|
||||
}
|
||||
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import { useCurrentWorkspace } from '@/service/use-common'
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
|
||||
export const useTrialCredits = () => {
|
||||
const { data: currentWorkspace, isPending } = useCurrentWorkspace()
|
||||
const { data: currentWorkspace, isPending } = useQuery(consoleQuery.workspaces.current.post.queryOptions())
|
||||
const totalCredits = currentWorkspace?.trial_credits ?? 0
|
||||
const credits = Math.max(totalCredits - (currentWorkspace?.trial_credits_used ?? 0), 0)
|
||||
|
||||
|
||||
@ -1,16 +1,14 @@
|
||||
'use client'
|
||||
|
||||
import type { ICurrentWorkspace } from '@/models/common'
|
||||
import { Button } from '@langgenius/dify-ui/button'
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { useEffect, useState } from 'react'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import { userProfileQueryOptions } from '@/features/account-profile/client'
|
||||
import { usePathname, useRouter, useSearchParams } from '@/next/navigation'
|
||||
import { post } from '@/service/base'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import { deviceLookup } from '@/service/device-flow'
|
||||
import { systemFeaturesQueryOptions } from '@/service/system-features'
|
||||
import { commonQueryKeys } from '@/service/use-common'
|
||||
import AuthorizeAccount from './components/authorize-account'
|
||||
import AuthorizeSSO from './components/authorize-sso'
|
||||
import Chooser from './components/chooser'
|
||||
@ -52,9 +50,8 @@ export default function DevicePage() {
|
||||
refetchOnMount: false,
|
||||
})
|
||||
const account = userResp?.profile
|
||||
const { data: currentWorkspace } = useQuery<ICurrentWorkspace>({
|
||||
queryKey: commonQueryKeys.currentWorkspace,
|
||||
queryFn: () => post<ICurrentWorkspace>('/workspaces/current'),
|
||||
const { data: currentWorkspace } = useQuery({
|
||||
...consoleQuery.workspaces.current.post.queryOptions(),
|
||||
enabled: !!account && !profileErr,
|
||||
retry: false,
|
||||
refetchOnWindowFocus: false,
|
||||
@ -174,7 +171,7 @@ export default function DevicePage() {
|
||||
accountEmail={account?.email}
|
||||
accountName={account?.name}
|
||||
accountAvatarUrl={account?.avatar_url ?? null}
|
||||
defaultWorkspace={currentWorkspace?.name}
|
||||
defaultWorkspace={currentWorkspace?.name ?? undefined}
|
||||
onApproved={() => setView({ kind: 'success' })}
|
||||
onDenied={() => setView({ kind: 'error_expired' })}
|
||||
onError={e => setErrMsg(e)}
|
||||
|
||||
@ -23,7 +23,7 @@ import {
|
||||
useRouter,
|
||||
useSearchParams,
|
||||
} from '@/next/navigation'
|
||||
import { consoleClient } from '@/service/client'
|
||||
import { consoleClient, consoleQuery } from '@/service/client'
|
||||
import { switchWorkspace } from '@/service/common'
|
||||
import { commonQueryKeys } from '@/service/use-common'
|
||||
import {
|
||||
@ -129,7 +129,7 @@ const EducationApplyAgeContent = () => {
|
||||
try {
|
||||
await switchWorkspace({ url: '/workspaces/switch', body: { tenant_id: tenantId } })
|
||||
await Promise.all([
|
||||
queryClient.invalidateQueries({ queryKey: commonQueryKeys.currentWorkspace }),
|
||||
queryClient.invalidateQueries({ queryKey: consoleQuery.workspaces.current.post.key() }),
|
||||
queryClient.invalidateQueries({ queryKey: commonQueryKeys.workspaces }),
|
||||
])
|
||||
onPlanInfoChanged()
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
'use client'
|
||||
|
||||
import type { PostWorkspacesCurrentResponse } from '@dify/contracts/api/console/workspaces/types.gen'
|
||||
import type { FC, ReactNode } from 'react'
|
||||
import type { ICurrentWorkspace, LangGeniusVersionResponse, UserProfileResponse } from '@/models/common'
|
||||
import { useQueryClient, useSuspenseQuery } from '@tanstack/react-query'
|
||||
import { useQuery, useQueryClient, useSuspenseQuery } from '@tanstack/react-query'
|
||||
import { useCallback, useEffect, useMemo } from 'react'
|
||||
import { setUserId, setUserProperties } from '@/app/components/base/amplitude'
|
||||
import { setZendeskConversationFields } from '@/app/components/base/zendesk/utils'
|
||||
@ -17,9 +18,9 @@ import {
|
||||
} from '@/context/app-context'
|
||||
import { env } from '@/env'
|
||||
import { userProfileQueryOptions } from '@/features/account-profile/client'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import { systemFeaturesQueryOptions } from '@/service/system-features'
|
||||
import {
|
||||
useCurrentWorkspace,
|
||||
useLangGeniusVersion,
|
||||
} from '@/service/use-common'
|
||||
|
||||
@ -27,18 +28,52 @@ type AppContextProviderProps = {
|
||||
children: ReactNode
|
||||
}
|
||||
|
||||
const workspaceRoles = new Set<ICurrentWorkspace['role']>(['owner', 'admin', 'editor', 'dataset_operator', 'normal'])
|
||||
|
||||
const resolveWorkspaceRole = (role: PostWorkspacesCurrentResponse['role']): ICurrentWorkspace['role'] => {
|
||||
if (role && workspaceRoles.has(role as ICurrentWorkspace['role']))
|
||||
return role as ICurrentWorkspace['role']
|
||||
|
||||
return initialWorkspaceInfo.role
|
||||
}
|
||||
|
||||
const normalizeCurrentWorkspace = (workspace?: PostWorkspacesCurrentResponse): ICurrentWorkspace => {
|
||||
if (!workspace)
|
||||
return initialWorkspaceInfo
|
||||
|
||||
return {
|
||||
id: workspace.id,
|
||||
name: workspace.name ?? initialWorkspaceInfo.name,
|
||||
plan: workspace.plan ?? initialWorkspaceInfo.plan,
|
||||
status: workspace.status ?? initialWorkspaceInfo.status,
|
||||
created_at: workspace.created_at ?? initialWorkspaceInfo.created_at,
|
||||
role: resolveWorkspaceRole(workspace.role),
|
||||
providers: initialWorkspaceInfo.providers,
|
||||
trial_credits: workspace.trial_credits ?? initialWorkspaceInfo.trial_credits,
|
||||
trial_credits_used: workspace.trial_credits_used ?? initialWorkspaceInfo.trial_credits_used,
|
||||
next_credit_reset_date: workspace.next_credit_reset_date ?? initialWorkspaceInfo.next_credit_reset_date,
|
||||
trial_end_reason: workspace.trial_end_reason ?? undefined,
|
||||
custom_config: workspace.custom_config
|
||||
? {
|
||||
remove_webapp_brand: workspace.custom_config.remove_webapp_brand ?? undefined,
|
||||
replace_webapp_logo: workspace.custom_config.replace_webapp_logo ?? undefined,
|
||||
}
|
||||
: undefined,
|
||||
}
|
||||
}
|
||||
|
||||
export const AppContextProvider: FC<AppContextProviderProps> = ({ children }) => {
|
||||
const queryClient = useQueryClient()
|
||||
const { data: systemFeatures } = useSuspenseQuery(systemFeaturesQueryOptions())
|
||||
const { data: userProfileResp } = useSuspenseQuery(userProfileQueryOptions())
|
||||
const { data: currentWorkspaceResp, isPending: isLoadingCurrentWorkspace, isFetching: isValidatingCurrentWorkspace } = useCurrentWorkspace()
|
||||
const { data: currentWorkspaceResp, isPending: isLoadingCurrentWorkspace, isFetching: isValidatingCurrentWorkspace } = useQuery(consoleQuery.workspaces.current.post.queryOptions())
|
||||
const langGeniusVersionQuery = useLangGeniusVersion(
|
||||
userProfileResp?.meta.currentVersion,
|
||||
!systemFeatures.branding.enabled,
|
||||
)
|
||||
|
||||
const userProfile = useMemo<UserProfileResponse>(() => userProfileResp?.profile || userProfilePlaceholder, [userProfileResp?.profile])
|
||||
const currentWorkspace = useMemo<ICurrentWorkspace>(() => currentWorkspaceResp || initialWorkspaceInfo, [currentWorkspaceResp])
|
||||
const currentWorkspace = useMemo<ICurrentWorkspace>(() => normalizeCurrentWorkspace(currentWorkspaceResp), [currentWorkspaceResp])
|
||||
const langGeniusVersionInfo = useMemo<LangGeniusVersionResponse>(() => {
|
||||
if (!userProfileResp?.meta?.currentVersion || !langGeniusVersionQuery.data)
|
||||
return initialLangGeniusVersionInfo
|
||||
@ -64,7 +99,7 @@ export const AppContextProvider: FC<AppContextProviderProps> = ({ children }) =>
|
||||
}, [queryClient])
|
||||
|
||||
const mutateCurrentWorkspace = useCallback(() => {
|
||||
queryClient.invalidateQueries({ queryKey: ['common', 'current-workspace'] })
|
||||
queryClient.invalidateQueries({ queryKey: consoleQuery.workspaces.current.post.key() })
|
||||
}, [queryClient])
|
||||
|
||||
// #region Zendesk conversation fields
|
||||
|
||||
@ -10,7 +10,6 @@ import type {
|
||||
CodeBasedExtension,
|
||||
CommonResponse,
|
||||
FileUploadConfigResponse,
|
||||
ICurrentWorkspace,
|
||||
IWorkspace,
|
||||
LangGeniusVersionResponse,
|
||||
Member,
|
||||
@ -26,7 +25,6 @@ const NAME_SPACE = 'common'
|
||||
|
||||
export const commonQueryKeys = {
|
||||
fileUploadConfig: [NAME_SPACE, 'file-upload-config'] as const,
|
||||
currentWorkspace: [NAME_SPACE, 'current-workspace'] as const,
|
||||
workspaces: [NAME_SPACE, 'workspaces'] as const,
|
||||
members: [NAME_SPACE, 'members'] as const,
|
||||
filePreview: (fileID: string) => [NAME_SPACE, 'file-preview', fileID] as const,
|
||||
@ -68,13 +66,6 @@ export const useLangGeniusVersion = (currentVersion?: string | null, enabled?: b
|
||||
})
|
||||
}
|
||||
|
||||
export const useCurrentWorkspace = () => {
|
||||
return useQuery<ICurrentWorkspace>({
|
||||
queryKey: commonQueryKeys.currentWorkspace,
|
||||
queryFn: () => post<ICurrentWorkspace>('/workspaces/current'),
|
||||
})
|
||||
}
|
||||
|
||||
export const useWorkspaces = () => {
|
||||
return useQuery<{ workspaces: IWorkspace[] }>({
|
||||
queryKey: commonQueryKeys.workspaces,
|
||||
|
||||
Reference in New Issue
Block a user