Compare commits

..

2 Commits

Author SHA1 Message Date
58da51c1ba [autofix.ci] apply automated fixes 2026-05-30 04:40:38 +00:00
yyh
9d093f71ed fix(web): use generated current workspace query 2026-05-30 12:35:19 +08:00
29 changed files with 184 additions and 394 deletions

View File

@ -209,11 +209,6 @@ class MCPProviderBasePayload(BaseModel):
configuration: dict[str, Any] | None = Field(default_factory=dict)
headers: dict[str, Any] | None = Field(default_factory=dict)
authentication: dict[str, Any] | None = Field(default_factory=dict)
# M3 — user-identity forwarding (M2 backend already supports these on the
# service layer). Defaults preserve pre-M3 behavior for clients that don't
# send the fields yet.
forward_user_identity: bool = False
identity_mode: Literal["off", "idp_token"] = "off"
class MCPProviderCreatePayload(MCPProviderBasePayload):
@ -990,8 +985,6 @@ class ToolProviderMCPApi(Resource):
headers=payload.headers or {},
configuration=configuration,
authentication=authentication,
forward_user_identity=payload.forward_user_identity,
identity_mode=payload.identity_mode,
)
# 2) Try to fetch tools immediately after creation so they appear without a second save.
@ -1059,8 +1052,6 @@ class ToolProviderMCPApi(Resource):
configuration=configuration,
authentication=authentication,
validation_result=validation_result,
forward_user_identity=payload.forward_user_identity,
identity_mode=payload.identity_mode,
)
return {"result": "success"}

View File

@ -29,7 +29,7 @@ from controllers.console.wraps import (
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import TimestampField, to_timestamp
from libs.helper import TimestampField, dump_response, to_timestamp
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
from services.account_service import TenantService
@ -56,6 +56,11 @@ class WorkspaceCustomConfigPayload(BaseModel):
replace_webapp_logo: str | None = None
class WorkspaceCustomConfigResponse(ResponseModel):
remove_webapp_brand: bool | None = None
replace_webapp_logo: str | None = None
class WorkspaceInfoPayload(BaseModel):
name: str
@ -69,7 +74,7 @@ class TenantInfoResponse(ResponseModel):
role: str | None = None
in_trial: bool | None = None
trial_end_reason: str | None = None
custom_config: dict | None = None
custom_config: WorkspaceCustomConfigResponse | None = None
trial_credits: int | None = None
trial_credits_used: int | None = None
next_credit_reset_date: int | None = None
@ -101,9 +106,13 @@ register_schema_models(
SwitchWorkspacePayload,
WorkspaceCustomConfigPayload,
WorkspaceInfoPayload,
TenantInfoResponse,
)
register_response_schema_models(console_ns, WorkspacePermissionResponse)
register_response_schema_models(
console_ns,
TenantInfoResponse,
WorkspaceCustomConfigResponse,
WorkspacePermissionResponse,
)
provider_fields = {
"provider_name": fields.String,
@ -238,13 +247,7 @@ class TenantApi(Resource):
else:
raise Unauthorized("workspace is archived")
return (
TenantInfoResponse.model_validate(
WorkspaceService.get_tenant_info(tenant),
from_attributes=True,
).model_dump(mode="json"),
200,
)
return dump_response(TenantInfoResponse, WorkspaceService.get_tenant_info(tenant)), 200
@console_ns.route("/workspaces/switch")

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import json
from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
from pydantic import BaseModel
@ -76,14 +76,6 @@ class MCPProviderEntity(BaseModel):
created_at: datetime
updated_at: datetime
# M2 — user-identity forwarding. When forward_user_identity is True AND
# identity_mode is "idp_token", the MCP tool runtime asks dify-enterprise
# to mint a fresh SSO id_token for the calling user and stamps it on the
# outbound MCP request as `Authorization: Bearer <token>`. Defaults keep
# pre-M2 providers unchanged (no forwarding).
forward_user_identity: bool = False
identity_mode: Literal["off", "idp_token"] = "off"
@classmethod
def from_db_model(cls, db_provider: MCPToolProvider) -> MCPProviderEntity:
"""Create entity from database model with decryption"""
@ -104,8 +96,6 @@ class MCPProviderEntity(BaseModel):
icon=db_provider.icon or "",
created_at=db_provider.created_at,
updated_at=db_provider.updated_at,
forward_user_identity=db_provider.forward_user_identity,
identity_mode=db_provider.identity_mode, # type: ignore[arg-type]
)
@property
@ -180,8 +170,6 @@ class MCPProviderEntity(BaseModel):
"updated_at": int(self.updated_at.timestamp()),
"label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(),
"description": I18nObject(en_US="", zh_Hans="").to_dict(),
"forward_user_identity": self.forward_user_identity,
"identity_mode": self.identity_mode,
}
# Add configuration

View File

@ -54,12 +54,6 @@ class ToolProviderApiEntity(BaseModel):
configuration: MCPConfiguration | None = Field(
default=None, description="The timeout and sse_read_timeout of the MCP tool"
)
# M3 — user-identity forwarding flags. Round-tripped through the console
# API so the create/edit modal can hydrate the toggle state.
forward_user_identity: bool = Field(
default=False, description="Whether Dify forwards the calling user's SSO identity to this MCP server"
)
identity_mode: str = Field(default="off", description="Identity-forwarding mechanism: 'off' or 'idp_token'")
# Workflow
workflow_app_id: str | None = Field(default=None, description="The app id of the workflow tool")
@ -98,10 +92,6 @@ class ToolProviderApiEntity(BaseModel):
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
optional_fields.update(self.optional_field("original_headers", self.original_headers))
# M3 — forwarding flags. Always emit (False/"off" are valid
# values that the UI must hydrate, not skip).
optional_fields["forward_user_identity"] = self.forward_user_identity
optional_fields["identity_mode"] = self.identity_mode
case ToolProviderType.WORKFLOW:
optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id))
case _:

View File

@ -1,4 +1,4 @@
from typing import Any, Literal, Self
from typing import Any, Self
from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.types import Tool as RemoteMCPTool
@ -28,8 +28,6 @@ class MCPToolProviderController(ToolProviderController):
headers: dict[str, str] | None = None,
timeout: float | None = None,
sse_read_timeout: float | None = None,
forward_user_identity: bool = False,
identity_mode: Literal["off", "idp_token"] = "off",
):
super().__init__(entity)
self.entity: ToolProviderEntityWithPlugin = entity
@ -39,8 +37,6 @@ class MCPToolProviderController(ToolProviderController):
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.forward_user_identity = forward_user_identity
self.identity_mode: Literal["off", "idp_token"] = identity_mode
@property
def provider_type(self) -> ToolProviderType:
@ -109,8 +105,6 @@ class MCPToolProviderController(ToolProviderController):
headers=entity.headers,
timeout=entity.timeout,
sse_read_timeout=entity.sse_read_timeout,
forward_user_identity=entity.forward_user_identity,
identity_mode=entity.identity_mode,
)
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
@ -140,8 +134,6 @@ class MCPToolProviderController(ToolProviderController):
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
forward_user_identity=self.forward_user_identity,
identity_mode=self.identity_mode,
)
def get_tools(self) -> list[MCPTool]:
@ -159,8 +151,6 @@ class MCPToolProviderController(ToolProviderController):
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
forward_user_identity=self.forward_user_identity,
identity_mode=self.identity_mode,
)
for tool_entity in self.entity.tools
]

View File

@ -4,7 +4,7 @@ import base64
import json
import logging
from collections.abc import Generator, Mapping
from typing import Any, Literal, cast
from typing import Any, cast
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
@ -38,8 +38,6 @@ class MCPTool(Tool):
headers: dict[str, str] | None = None,
timeout: float | None = None,
sse_read_timeout: float | None = None,
forward_user_identity: bool = False,
identity_mode: Literal["off", "idp_token"] = "off",
):
super().__init__(entity, runtime)
self.tenant_id = tenant_id
@ -49,8 +47,6 @@ class MCPTool(Tool):
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.forward_user_identity = forward_user_identity
self.identity_mode: Literal["off", "idp_token"] = identity_mode
self._latest_usage = LLMUsage.empty_usage()
def tool_provider_type(self) -> ToolProviderType:
@ -64,7 +60,7 @@ class MCPTool(Tool):
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
result = self.invoke_remote_mcp_tool(tool_parameters, user_id=user_id, app_id=app_id)
result = self.invoke_remote_mcp_tool(tool_parameters)
# Extract usage metadata from MCP protocol's _meta field
self._latest_usage = self._derive_usage_from_result(result)
@ -238,8 +234,6 @@ class MCPTool(Tool):
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
forward_user_identity=self.forward_user_identity,
identity_mode=self.identity_mode,
)
def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:
@ -252,12 +246,7 @@ class MCPTool(Tool):
if value is not None and not (isinstance(value, str) and value.strip() == "")
}
def invoke_remote_mcp_tool(
self,
tool_parameters: dict[str, Any],
user_id: str | None = None,
app_id: str | None = None,
) -> CallToolResult:
def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any]) -> CallToolResult:
headers = self.headers.copy() if self.headers else {}
tool_parameters = self._handle_none_parameter(tool_parameters)
@ -282,14 +271,6 @@ class MCPTool(Tool):
if tokens and tokens.access_token:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
# User-identity forwarding: if enabled on this provider, ask the
# enterprise side to mint a fresh SSO id_token (audience-scoped to
# the MCP server's URL per RFC 8707) and stamp it as Authorization.
# This OVERRIDES any Authorization already on the request — the
# forwarded identity is what the MCP server should trust.
if self.forward_user_identity and self.identity_mode == "idp_token" and user_id:
self._inject_forwarded_identity(headers, user_id=user_id, app_id=app_id, audience=server_url)
# Step 2: Session is now closed, perform network operations without holding database connection
# MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed
try:
@ -305,31 +286,3 @@ class MCPTool(Tool):
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e:
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
def _inject_forwarded_identity(
self,
headers: dict[str, str],
*,
user_id: str,
app_id: str | None,
audience: str,
) -> None:
"""Call the enterprise IssueMCPToken endpoint and stamp Authorization.
Errors are surfaced as ToolInvokeError so the workflow halts with a
clear message instead of silently dropping identity and hitting the
MCP server unauthenticated.
"""
from services.enterprise.base import MCPTokenError
from services.enterprise.enterprise_service import EnterpriseService
try:
token, _expires_at = EnterpriseService.issue_mcp_token(
user_id=user_id,
tenant_id=self.tenant_id,
app_id=app_id,
audience=audience,
)
except MCPTokenError as e:
raise ToolInvokeError(f"Failed to obtain forwarded identity token: {e}") from e
headers["Authorization"] = f"Bearer {token}"

View File

@ -1,56 +0,0 @@
"""add identity mode to mcp tool provider
Revision ID: 3df4dbcc1e21
Revises: 7885bd53f9a9
Create Date: 2026-05-29 15:00:00.000000
Adds two columns to `tool_mcp_providers` that drive the M2 MCP user-identity
forwarding feature:
* `forward_user_identity` (bool, default false) — master switch per provider.
* `identity_mode` (string, default "off") — which forwarding mechanism to use:
"off" — no header forwarded (default; pre-M2 behaviour).
"idp_token" — call dify-enterprise /inner/api/mcp/issue-token, stamp
the returned id_token on the outbound MCP request as
`Authorization: Bearer <token>`.
The columns are filled with safe defaults for existing rows so older providers
keep their current behaviour (no identity forwarding) until an admin opts in.
"""
import sqlalchemy as sa
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = "3df4dbcc1e21"
down_revision = "7885bd53f9a9"
branch_labels = None
depends_on = None
def upgrade():
op.add_column(
"tool_mcp_providers",
sa.Column(
"forward_user_identity",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
)
op.add_column(
"tool_mcp_providers",
sa.Column(
"identity_mode",
sa.String(length=32),
nullable=False,
server_default=sa.text("'off'"),
),
)
def downgrade():
op.drop_column("tool_mcp_providers", "identity_mode")
op.drop_column("tool_mcp_providers", "forward_user_identity")

View File

@ -343,21 +343,6 @@ class MCPToolProvider(TypeBase):
# encrypted headers for MCP server requests
encrypted_headers: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
# M2 (MCP user-identity forwarding) — master switch per provider. When True
# AND identity_mode is "idp_token", workflows that invoke tools on this
# provider will have the caller's SSO id_token stamped on the outbound
# request as `Authorization: Bearer …`. Off by default so existing
# providers retain pre-M2 behaviour.
forward_user_identity: Mapped[bool] = mapped_column(
sa.Boolean, nullable=False, server_default=sa.text("false"), default=False
)
# M2 — which identity-forwarding mechanism to use. Reserved values:
# "off" — no forwarding (default).
# "idp_token" — forward a Bearer id_token minted by dify-enterprise.
identity_mode: Mapped[str] = mapped_column(
sa.String(32), nullable=False, server_default=sa.text("'off'"), default="off"
)
def load_user(self) -> Account | None:
return db.session.scalar(select(Account).where(Account.id == self.user_id))

View File

@ -13761,12 +13761,10 @@ Enum class for large language model mode.
| ---- | ---- | ----------- | -------- |
| authentication | object | | No |
| configuration | object | | No |
| forward_user_identity | boolean | | No |
| headers | object | | No |
| icon | string | | Yes |
| icon_background | string | | No |
| icon_type | string | | Yes |
| identity_mode | string | *Enum:* `"idp_token"`, `"off"` | No |
| name | string | | Yes |
| server_identifier | string | | Yes |
| server_url | string | | Yes |
@ -13783,12 +13781,10 @@ Enum class for large language model mode.
| ---- | ---- | ----------- | -------- |
| authentication | object | | No |
| configuration | object | | No |
| forward_user_identity | boolean | | No |
| headers | object | | No |
| icon | string | | Yes |
| icon_background | string | | No |
| icon_type | string | | Yes |
| identity_mode | string | *Enum:* `"idp_token"`, `"off"` | No |
| name | string | | Yes |
| provider_id | string | | Yes |
| server_identifier | string | | Yes |
@ -15211,7 +15207,7 @@ Tag type
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| created_at | integer | | No |
| custom_config | object | | No |
| custom_config | [WorkspaceCustomConfigResponse](#workspacecustomconfigresponse) | | No |
| id | string | | Yes |
| in_trial | boolean | | No |
| name | string | | No |
@ -16334,6 +16330,13 @@ Workflow tool configuration
| remove_webapp_brand | boolean | | No |
| replace_webapp_logo | string | | No |
#### WorkspaceCustomConfigResponse
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| remove_webapp_brand | boolean | | No |
| replace_webapp_logo | string | | No |
#### WorkspaceInfoPayload
| Name | Type | Description | Required |

View File

@ -12,34 +12,8 @@ from services.errors.enterprise import (
EnterpriseAPIForbiddenError,
EnterpriseAPINotFoundError,
EnterpriseAPIUnauthorizedError,
EnterpriseServiceError,
)
# M2 — IssueMCPToken specific errors. Co-located here (rather than in
# services/errors/enterprise.py) because services.enterprise.base is part of
# the leaf-mounted file set the local dev override applies; the errors module
# stays at the EE image's baked-in version.
class MCPTokenError(EnterpriseServiceError):
"""Generic failure of the IssueMCPToken RPC."""
class MCPNoRefreshTokenError(MCPTokenError):
"""The user has no stored SSO refresh_token on the enterprise side.
The workflow should ask them to re-authenticate."""
def __init__(self, description: str = ""):
super().__init__(description, status_code=428)
class MCPIdentityRefreshError(MCPTokenError):
"""The enterprise side tried to refresh the user's SSO refresh_token
against the IdP and failed (revoked/expired/IdP error)."""
def __init__(self, description: str = ""):
super().__init__(description, status_code=401)
logger = logging.getLogger(__name__)

View File

@ -11,16 +11,7 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator
from configs import dify_config
from extensions.ext_redis import redis_client
from services.enterprise.base import (
EnterpriseRequest,
MCPIdentityRefreshError,
MCPNoRefreshTokenError,
MCPTokenError,
)
from services.errors.enterprise import (
EnterpriseAPIError,
EnterpriseAPIUnauthorizedError,
)
from services.enterprise.base import EnterpriseRequest
if TYPE_CHECKING:
from services.feature_service import LicenseStatus
@ -130,62 +121,6 @@ class EnterpriseService:
def get_workspace_info(cls, tenant_id: str):
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
@classmethod
def issue_mcp_token(
cls,
user_id: str,
tenant_id: str,
app_id: str | None,
audience: str,
) -> tuple[str, int]:
"""Mint a short-lived SSO id_token (or OAuth2 access_token) representing
the calling Dify user, audience-scoped to the given MCP server identifier.
Used by MCPTool.invoke_remote_mcp_tool to stamp `Authorization: Bearer
<token>` on outbound MCP requests when the provider has
forward_user_identity=True and identity_mode="idp_token".
Returns:
(token, expires_at_unix_seconds)
Raises:
MCPNoRefreshTokenError: user has no stored SSO refresh_token on the
enterprise side; surface to the workflow as "please log in via SSO".
MCPIdentityRefreshError: enterprise tried to refresh against the IdP
and the IdP rejected (revoked/expired session).
MCPTokenError: any other failure of the enterprise endpoint.
"""
try:
response = EnterpriseRequest.send_request(
"POST",
"/mcp/issue-token",
json={
"user_id": user_id,
"tenant_id": tenant_id,
"app_id": app_id or "",
"audience": audience,
},
)
except EnterpriseAPIUnauthorizedError as e:
# Enterprise side returns 401 when the IdP rejected the refresh.
raise MCPIdentityRefreshError(str(e) or "identity refresh failed; please re-authenticate") from e
except EnterpriseAPIError as e:
# Map the 428 PreconditionRequired we emit on no-stored-refresh-token.
if getattr(e, "status_code", None) == 428:
raise MCPNoRefreshTokenError(
str(e) or "user has no stored SSO refresh token; please re-authenticate"
) from e
raise MCPTokenError(f"issue_mcp_token failed: {e}") from e
if not isinstance(response, dict):
raise MCPTokenError("invalid response shape from enterprise /mcp/issue-token")
token = response.get("token")
expires_at = response.get("expires_at")
if not token or not isinstance(token, str) or not isinstance(expires_at, int):
raise MCPTokenError(f"missing token/expires_at in enterprise response: {response}")
return token, expires_at
@classmethod
def initiate_device_flow_sso(cls, signed_state: str) -> dict:
return EnterpriseRequest.send_request(

View File

@ -4,7 +4,7 @@ import logging
from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Any, Literal
from typing import Any
from urllib.parse import urlparse
from pydantic import BaseModel, Field
@ -136,8 +136,6 @@ class MCPToolManageService:
configuration: MCPConfiguration,
authentication: MCPAuthentication | None = None,
headers: dict[str, str] | None = None,
forward_user_identity: bool = False,
identity_mode: Literal["off", "idp_token"] = "off",
) -> ToolProviderApiEntity:
"""Create a new MCP provider."""
# Validate URL format
@ -173,8 +171,6 @@ class MCPToolManageService:
sse_read_timeout=configuration.sse_read_timeout,
encrypted_headers=encrypted_headers,
encrypted_credentials=encrypted_credentials,
forward_user_identity=forward_user_identity,
identity_mode=identity_mode,
)
self._session.add(mcp_tool)
@ -198,8 +194,6 @@ class MCPToolManageService:
configuration: MCPConfiguration,
authentication: MCPAuthentication | None = None,
validation_result: ServerUrlValidationResult | None = None,
forward_user_identity: bool | None = None,
identity_mode: Literal["off", "idp_token"] | None = None,
) -> None:
"""
Update an MCP provider.
@ -261,14 +255,6 @@ class MCPToolManageService:
if authentication and authentication.client_id:
mcp_provider.encrypted_credentials = self._process_credentials(authentication, mcp_provider, tenant_id)
# Update user-identity forwarding settings if provided.
# None means "leave unchanged" so this stays backwards-compatible
# with existing callers that don't know about M2.
if forward_user_identity is not None:
mcp_provider.forward_user_identity = forward_user_identity
if identity_mode is not None:
mcp_provider.identity_mode = identity_mode
# Flush changes to database
self._session.flush()

View File

@ -435,6 +435,23 @@ class TestTenantInfoResponse:
assert payload["plan"] == "team"
assert payload["created_at"] == int(created_at.timestamp())
def test_tenant_info_response_has_typed_custom_config(self):
payload = TenantInfoResponse.model_validate(
{
"id": "t1",
"custom_config": {
"remove_webapp_brand": True,
"replace_webapp_logo": "logo-file-id",
"ignored": "value",
},
}
).model_dump(mode="json")
assert payload["custom_config"] == {
"remove_webapp_brand": True,
"replace_webapp_logo": "logo-file-id",
}
class TestSwitchWorkspaceApi:
def test_switch_success(self, app: Flask):

View File

@ -4904,11 +4904,6 @@
"count": 1
}
},
"web/app/device/page.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/education-apply/hooks.ts": {
"react/set-state-in-effect": {
"count": 5

View File

@ -4,16 +4,8 @@ import { oc } from '@orpc/contract'
import { zPostInfoResponse } from './zod.gen'
/**
* Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.
*
* @deprecated
*/
export const post = oc
.route({
deprecated: true,
description:
'Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.',
inputStructure: 'detailed',
method: 'POST',
operationId: 'postInfo',

View File

@ -6,9 +6,7 @@ export type ClientOptions = {
export type TenantInfoResponse = {
created_at?: number | null
custom_config?: {
[key: string]: unknown
} | null
custom_config?: WorkspaceCustomConfigResponse
id: string
in_trial?: boolean | null
name?: string | null
@ -21,6 +19,11 @@ export type TenantInfoResponse = {
trial_end_reason?: string | null
}
export type WorkspaceCustomConfigResponse = {
remove_webapp_brand?: boolean | null
replace_webapp_logo?: string | null
}
export type PostInfoData = {
body?: never
path?: never

View File

@ -2,12 +2,20 @@
import * as z from 'zod'
/**
* WorkspaceCustomConfigResponse
*/
export const zWorkspaceCustomConfigResponse = z.object({
remove_webapp_brand: z.boolean().nullish(),
replace_webapp_logo: z.string().nullish(),
})
/**
* TenantInfoResponse
*/
export const zTenantInfoResponse = z.object({
created_at: z.int().nullish(),
custom_config: z.record(z.string(), z.unknown()).nullish(),
custom_config: zWorkspaceCustomConfigResponse.optional(),
id: z.string(),
in_trial: z.boolean().nullish(),
name: z.string().nullish(),

View File

@ -3682,16 +3682,8 @@ export const triggers = {
get: get58,
}
/**
* Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.
*
* @deprecated
*/
export const post63 = oc
.route({
deprecated: true,
description:
'Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.',
inputStructure: 'detailed',
method: 'POST',
operationId: 'postWorkspacesCurrent',

View File

@ -6,9 +6,7 @@ export type ClientOptions = {
export type TenantInfoResponse = {
created_at?: number | null
custom_config?: {
[key: string]: unknown
} | null
custom_config?: WorkspaceCustomConfigResponse
id: string
in_trial?: boolean | null
name?: string | null
@ -392,14 +390,12 @@ export type McpProviderCreatePayload = {
configuration?: {
[key: string]: unknown
} | null
forward_user_identity?: boolean
headers?: {
[key: string]: unknown
} | null
icon: string
icon_background?: string
icon_type: string
identity_mode?: 'idp_token' | 'off'
name: string
server_identifier: string
server_url: string
@ -412,14 +408,12 @@ export type McpProviderUpdatePayload = {
configuration?: {
[key: string]: unknown
} | null
forward_user_identity?: boolean
headers?: {
[key: string]: unknown
} | null
icon: string
icon_background?: string
icon_type: string
identity_mode?: 'idp_token' | 'off'
name: string
provider_id: string
server_identifier: string
@ -504,6 +498,11 @@ export type SwitchWorkspacePayload = {
tenant_id: string
}
export type WorkspaceCustomConfigResponse = {
remove_webapp_brand?: boolean | null
replace_webapp_logo?: string | null
}
export type AccountWithRole = {
avatar?: string | null
created_at?: number | null

View File

@ -2,24 +2,6 @@
import * as z from 'zod'
/**
* TenantInfoResponse
*/
export const zTenantInfoResponse = z.object({
created_at: z.int().nullish(),
custom_config: z.record(z.string(), z.unknown()).nullish(),
id: z.string(),
in_trial: z.boolean().nullish(),
name: z.string().nullish(),
next_credit_reset_date: z.int().nullish(),
plan: z.string().nullish(),
role: z.string().nullish(),
status: z.string().nullish(),
trial_credits: z.int().nullish(),
trial_credits_used: z.int().nullish(),
trial_end_reason: z.string().nullish(),
})
/**
* SimpleResultResponse
*/
@ -361,12 +343,10 @@ export const zMcpProviderDeletePayload = z.object({
export const zMcpProviderCreatePayload = z.object({
authentication: z.record(z.string(), z.unknown()).nullish(),
configuration: z.record(z.string(), z.unknown()).nullish(),
forward_user_identity: z.boolean().optional().default(false),
headers: z.record(z.string(), z.unknown()).nullish(),
icon: z.string(),
icon_background: z.string().optional().default(''),
icon_type: z.string(),
identity_mode: z.enum(['idp_token', 'off']).optional().default('off'),
name: z.string(),
server_identifier: z.string(),
server_url: z.string(),
@ -378,12 +358,10 @@ export const zMcpProviderCreatePayload = z.object({
export const zMcpProviderUpdatePayload = z.object({
authentication: z.record(z.string(), z.unknown()).nullish(),
configuration: z.record(z.string(), z.unknown()).nullish(),
forward_user_identity: z.boolean().optional().default(false),
headers: z.record(z.string(), z.unknown()).nullish(),
icon: z.string(),
icon_background: z.string().optional().default(''),
icon_type: z.string(),
identity_mode: z.enum(['idp_token', 'off']).optional().default('off'),
name: z.string(),
provider_id: z.string(),
server_identifier: z.string(),
@ -459,6 +437,32 @@ export const zSwitchWorkspacePayload = z.object({
tenant_id: z.string(),
})
/**
* WorkspaceCustomConfigResponse
*/
export const zWorkspaceCustomConfigResponse = z.object({
remove_webapp_brand: z.boolean().nullish(),
replace_webapp_logo: z.string().nullish(),
})
/**
* TenantInfoResponse
*/
export const zTenantInfoResponse = z.object({
created_at: z.int().nullish(),
custom_config: zWorkspaceCustomConfigResponse.optional(),
id: z.string(),
in_trial: z.boolean().nullish(),
name: z.string().nullish(),
next_credit_reset_date: z.int().nullish(),
plan: z.string().nullish(),
role: z.string().nullish(),
status: z.string().nullish(),
trial_credits: z.int().nullish(),
trial_credits_used: z.int().nullish(),
trial_end_reason: z.string().nullish(),
})
/**
* AccountWithRole
*/

View File

@ -131,14 +131,6 @@ vi.mock('@/service/use-common', () => ({
],
},
}),
useCurrentWorkspace: () => ({
data: {
trial_credits: 1000,
trial_credits_used: 100,
next_credit_reset_date: undefined,
},
isPending: false,
}),
}))
vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({

View File

@ -7,11 +7,11 @@ import QuotaPanel from '../quota-panel'
let mockWorkspaceData: {
trial_credits: number
trial_credits_used: number
next_credit_reset_date: string
next_credit_reset_date: number
} | undefined = {
trial_credits: 100,
trial_credits_used: 30,
next_credit_reset_date: '2024-12-31',
next_credit_reset_date: 1735603200,
}
let mockWorkspaceIsPending = false
let mockTrialModels: string[] | undefined = ['langgenius/openai/openai']
@ -32,11 +32,18 @@ vi.mock('@/app/components/base/icons/src/public/llm', () => {
}
})
vi.mock('@/service/use-common', () => ({
useCurrentWorkspace: () => ({
data: mockWorkspaceData,
isPending: mockWorkspaceIsPending,
}),
vi.mock('../use-trial-credits', () => ({
useTrialCredits: () => {
const totalCredits = mockWorkspaceData?.trial_credits ?? 0
const credits = Math.max(totalCredits - (mockWorkspaceData?.trial_credits_used ?? 0), 0)
return {
credits,
totalCredits,
isExhausted: credits <= 0,
isLoading: mockWorkspaceIsPending && !mockWorkspaceData,
nextCreditResetDate: mockWorkspaceData?.next_credit_reset_date,
}
},
}))
const renderQuotaPanel = (ui: ReactElement) => renderWithSystemFeatures(ui, {
@ -78,7 +85,7 @@ describe('QuotaPanel', () => {
mockWorkspaceData = {
trial_credits: 100,
trial_credits_used: 30,
next_credit_reset_date: '2024-12-31',
next_credit_reset_date: 1735603200,
}
mockWorkspaceIsPending = false
mockTrialModels = ['langgenius/openai/openai']
@ -118,7 +125,7 @@ describe('QuotaPanel', () => {
mockWorkspaceData = {
trial_credits: 10,
trial_credits_used: 999,
next_credit_reset_date: '',
next_credit_reset_date: 0,
}
renderQuotaPanel(<QuotaPanel providers={mockProviders} />)

View File

@ -1,20 +1,34 @@
import { renderHook } from '@testing-library/react'
import { useTrialCredits } from '../use-trial-credits'
const mockUseCurrentWorkspace = vi.fn()
const { mockUseQuery } = vi.hoisted(() => ({
mockUseQuery: vi.fn(),
}))
vi.mock('@/service/use-common', () => ({
useCurrentWorkspace: () => mockUseCurrentWorkspace(),
vi.mock('@tanstack/react-query', () => ({
useQuery: () => mockUseQuery(),
}))
vi.mock('@/service/client', () => ({
consoleQuery: {
workspaces: {
current: {
post: {
queryOptions: () => ({ queryKey: ['console', 'workspaces', 'current', 'post'] }),
},
},
},
},
}))
describe('useTrialCredits', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseCurrentWorkspace.mockReturnValue({
mockUseQuery.mockReturnValue({
data: {
trial_credits: 100,
trial_credits_used: 40,
next_credit_reset_date: '2026-04-01',
next_credit_reset_date: 1775001600,
},
isPending: false,
})
@ -29,16 +43,16 @@ describe('useTrialCredits', () => {
totalCredits: 100,
isExhausted: false,
isLoading: false,
nextCreditResetDate: '2026-04-01',
nextCreditResetDate: 1775001600,
})
})
it('should keep the hook out of loading state during a background refetch', () => {
mockUseCurrentWorkspace.mockReturnValue({
mockUseQuery.mockReturnValue({
data: {
trial_credits: 80,
trial_credits_used: 20,
next_credit_reset_date: '2026-05-01',
next_credit_reset_date: 1777593600,
},
isPending: true,
})
@ -53,7 +67,7 @@ describe('useTrialCredits', () => {
describe('when workspace data is missing or exhausted', () => {
it('should report loading while the first workspace request is pending', () => {
mockUseCurrentWorkspace.mockReturnValue({
mockUseQuery.mockReturnValue({
data: undefined,
isPending: true,
})
@ -70,7 +84,7 @@ describe('useTrialCredits', () => {
})
it('should clamp negative remaining credits to zero', () => {
mockUseCurrentWorkspace.mockReturnValue({
mockUseQuery.mockReturnValue({
data: {
trial_credits: 10,
trial_credits_used: 99,

View File

@ -1,6 +1,7 @@
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
import type { ICurrentWorkspace } from '@/models/common'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { consoleQuery } from '@/service/client'
import CreditsExhaustedAlert from './credits-exhausted-alert'
const baseWorkspace: ICurrentWorkspace = {
@ -20,7 +21,7 @@ function createSeededQueryClient(overrides?: Partial<ICurrentWorkspace>) {
const qc = new QueryClient({
defaultOptions: { queries: { refetchOnWindowFocus: false, retry: false } },
})
qc.setQueryData(['common', 'current-workspace'], { ...baseWorkspace, ...overrides })
qc.setQueryData(consoleQuery.workspaces.current.post.queryKey(), { ...baseWorkspace, ...overrides })
return qc
}

View File

@ -1,7 +1,8 @@
import { useCurrentWorkspace } from '@/service/use-common'
import { useQuery } from '@tanstack/react-query'
import { consoleQuery } from '@/service/client'
export const useTrialCredits = () => {
const { data: currentWorkspace, isPending } = useCurrentWorkspace()
const { data: currentWorkspace, isPending } = useQuery(consoleQuery.workspaces.current.post.queryOptions())
const totalCredits = currentWorkspace?.trial_credits ?? 0
const credits = Math.max(totalCredits - (currentWorkspace?.trial_credits_used ?? 0), 0)

View File

@ -1,16 +1,14 @@
'use client'
import type { ICurrentWorkspace } from '@/models/common'
import { Button } from '@langgenius/dify-ui/button'
import { useQuery } from '@tanstack/react-query'
import { useEffect, useState } from 'react'
import Divider from '@/app/components/base/divider'
import { userProfileQueryOptions } from '@/features/account-profile/client'
import { usePathname, useRouter, useSearchParams } from '@/next/navigation'
import { post } from '@/service/base'
import { consoleQuery } from '@/service/client'
import { deviceLookup } from '@/service/device-flow'
import { systemFeaturesQueryOptions } from '@/service/system-features'
import { commonQueryKeys } from '@/service/use-common'
import AuthorizeAccount from './components/authorize-account'
import AuthorizeSSO from './components/authorize-sso'
import Chooser from './components/chooser'
@ -52,9 +50,8 @@ export default function DevicePage() {
refetchOnMount: false,
})
const account = userResp?.profile
const { data: currentWorkspace } = useQuery<ICurrentWorkspace>({
queryKey: commonQueryKeys.currentWorkspace,
queryFn: () => post<ICurrentWorkspace>('/workspaces/current'),
const { data: currentWorkspace } = useQuery({
...consoleQuery.workspaces.current.post.queryOptions(),
enabled: !!account && !profileErr,
retry: false,
refetchOnWindowFocus: false,
@ -174,7 +171,7 @@ export default function DevicePage() {
accountEmail={account?.email}
accountName={account?.name}
accountAvatarUrl={account?.avatar_url ?? null}
defaultWorkspace={currentWorkspace?.name}
defaultWorkspace={currentWorkspace?.name ?? undefined}
onApproved={() => setView({ kind: 'success' })}
onDenied={() => setView({ kind: 'error_expired' })}
onError={e => setErrMsg(e)}

View File

@ -23,7 +23,7 @@ import {
useRouter,
useSearchParams,
} from '@/next/navigation'
import { consoleClient } from '@/service/client'
import { consoleClient, consoleQuery } from '@/service/client'
import { switchWorkspace } from '@/service/common'
import { commonQueryKeys } from '@/service/use-common'
import {
@ -129,7 +129,7 @@ const EducationApplyAgeContent = () => {
try {
await switchWorkspace({ url: '/workspaces/switch', body: { tenant_id: tenantId } })
await Promise.all([
queryClient.invalidateQueries({ queryKey: commonQueryKeys.currentWorkspace }),
queryClient.invalidateQueries({ queryKey: consoleQuery.workspaces.current.post.key() }),
queryClient.invalidateQueries({ queryKey: commonQueryKeys.workspaces }),
])
onPlanInfoChanged()

View File

@ -1,8 +1,9 @@
'use client'
import type { PostWorkspacesCurrentResponse } from '@dify/contracts/api/console/workspaces/types.gen'
import type { FC, ReactNode } from 'react'
import type { ICurrentWorkspace, LangGeniusVersionResponse, UserProfileResponse } from '@/models/common'
import { useQueryClient, useSuspenseQuery } from '@tanstack/react-query'
import { useQuery, useQueryClient, useSuspenseQuery } from '@tanstack/react-query'
import { useCallback, useEffect, useMemo } from 'react'
import { setUserId, setUserProperties } from '@/app/components/base/amplitude'
import { setZendeskConversationFields } from '@/app/components/base/zendesk/utils'
@ -17,9 +18,9 @@ import {
} from '@/context/app-context'
import { env } from '@/env'
import { userProfileQueryOptions } from '@/features/account-profile/client'
import { consoleQuery } from '@/service/client'
import { systemFeaturesQueryOptions } from '@/service/system-features'
import {
useCurrentWorkspace,
useLangGeniusVersion,
} from '@/service/use-common'
@ -27,18 +28,52 @@ type AppContextProviderProps = {
children: ReactNode
}
const workspaceRoles = new Set<ICurrentWorkspace['role']>(['owner', 'admin', 'editor', 'dataset_operator', 'normal'])
const resolveWorkspaceRole = (role: PostWorkspacesCurrentResponse['role']): ICurrentWorkspace['role'] => {
if (role && workspaceRoles.has(role as ICurrentWorkspace['role']))
return role as ICurrentWorkspace['role']
return initialWorkspaceInfo.role
}
const normalizeCurrentWorkspace = (workspace?: PostWorkspacesCurrentResponse): ICurrentWorkspace => {
if (!workspace)
return initialWorkspaceInfo
return {
id: workspace.id,
name: workspace.name ?? initialWorkspaceInfo.name,
plan: workspace.plan ?? initialWorkspaceInfo.plan,
status: workspace.status ?? initialWorkspaceInfo.status,
created_at: workspace.created_at ?? initialWorkspaceInfo.created_at,
role: resolveWorkspaceRole(workspace.role),
providers: initialWorkspaceInfo.providers,
trial_credits: workspace.trial_credits ?? initialWorkspaceInfo.trial_credits,
trial_credits_used: workspace.trial_credits_used ?? initialWorkspaceInfo.trial_credits_used,
next_credit_reset_date: workspace.next_credit_reset_date ?? initialWorkspaceInfo.next_credit_reset_date,
trial_end_reason: workspace.trial_end_reason ?? undefined,
custom_config: workspace.custom_config
? {
remove_webapp_brand: workspace.custom_config.remove_webapp_brand ?? undefined,
replace_webapp_logo: workspace.custom_config.replace_webapp_logo ?? undefined,
}
: undefined,
}
}
export const AppContextProvider: FC<AppContextProviderProps> = ({ children }) => {
const queryClient = useQueryClient()
const { data: systemFeatures } = useSuspenseQuery(systemFeaturesQueryOptions())
const { data: userProfileResp } = useSuspenseQuery(userProfileQueryOptions())
const { data: currentWorkspaceResp, isPending: isLoadingCurrentWorkspace, isFetching: isValidatingCurrentWorkspace } = useCurrentWorkspace()
const { data: currentWorkspaceResp, isPending: isLoadingCurrentWorkspace, isFetching: isValidatingCurrentWorkspace } = useQuery(consoleQuery.workspaces.current.post.queryOptions())
const langGeniusVersionQuery = useLangGeniusVersion(
userProfileResp?.meta.currentVersion,
!systemFeatures.branding.enabled,
)
const userProfile = useMemo<UserProfileResponse>(() => userProfileResp?.profile || userProfilePlaceholder, [userProfileResp?.profile])
const currentWorkspace = useMemo<ICurrentWorkspace>(() => currentWorkspaceResp || initialWorkspaceInfo, [currentWorkspaceResp])
const currentWorkspace = useMemo<ICurrentWorkspace>(() => normalizeCurrentWorkspace(currentWorkspaceResp), [currentWorkspaceResp])
const langGeniusVersionInfo = useMemo<LangGeniusVersionResponse>(() => {
if (!userProfileResp?.meta?.currentVersion || !langGeniusVersionQuery.data)
return initialLangGeniusVersionInfo
@ -64,7 +99,7 @@ export const AppContextProvider: FC<AppContextProviderProps> = ({ children }) =>
}, [queryClient])
const mutateCurrentWorkspace = useCallback(() => {
queryClient.invalidateQueries({ queryKey: ['common', 'current-workspace'] })
queryClient.invalidateQueries({ queryKey: consoleQuery.workspaces.current.post.key() })
}, [queryClient])
// #region Zendesk conversation fields

View File

@ -10,7 +10,6 @@ import type {
CodeBasedExtension,
CommonResponse,
FileUploadConfigResponse,
ICurrentWorkspace,
IWorkspace,
LangGeniusVersionResponse,
Member,
@ -26,7 +25,6 @@ const NAME_SPACE = 'common'
export const commonQueryKeys = {
fileUploadConfig: [NAME_SPACE, 'file-upload-config'] as const,
currentWorkspace: [NAME_SPACE, 'current-workspace'] as const,
workspaces: [NAME_SPACE, 'workspaces'] as const,
members: [NAME_SPACE, 'members'] as const,
filePreview: (fileID: string) => [NAME_SPACE, 'file-preview', fileID] as const,
@ -68,13 +66,6 @@ export const useLangGeniusVersion = (currentVersion?: string | null, enabled?: b
})
}
export const useCurrentWorkspace = () => {
return useQuery<ICurrentWorkspace>({
queryKey: commonQueryKeys.currentWorkspace,
queryFn: () => post<ICurrentWorkspace>('/workspaces/current'),
})
}
export const useWorkspaces = () => {
return useQuery<{ workspaces: IWorkspace[] }>({
queryKey: commonQueryKeys.workspaces,