mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 02:18:08 +08:00
feat(enterprise): auto-join newly registered accounts to the default workspace (#32308)
Co-authored-by: Yunlu Wen <yunlu.wen@dify.ai>
This commit is contained in:
@ -39,6 +39,9 @@ class BaseRequest:
|
||||
endpoint: str,
|
||||
json: Any | None = None,
|
||||
params: Mapping[str, Any] | None = None,
|
||||
*,
|
||||
timeout: float | httpx.Timeout | None = None,
|
||||
raise_for_status: bool = False,
|
||||
) -> Any:
|
||||
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
|
||||
url = f"{cls.base_url}{endpoint}"
|
||||
@ -53,7 +56,16 @@ class BaseRequest:
|
||||
logger.debug("Failed to generate traceparent header", exc_info=True)
|
||||
|
||||
with httpx.Client(mounts=mounts) as client:
|
||||
response = client.request(method, url, json=json, params=params, headers=headers)
|
||||
# IMPORTANT:
|
||||
# - In httpx, passing timeout=None disables timeouts (infinite) and overrides the library default.
|
||||
# - To preserve httpx's default timeout behavior for existing call sites, only pass the kwarg when set.
|
||||
request_kwargs: dict[str, Any] = {"json": json, "params": params, "headers": headers}
|
||||
if timeout is not None:
|
||||
request_kwargs["timeout"] = timeout
|
||||
|
||||
response = client.request(method, url, **request_kwargs)
|
||||
if raise_for_status:
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
|
||||
@ -1,9 +1,16 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from services.enterprise.base import EnterpriseRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0
|
||||
|
||||
|
||||
class WebAppSettings(BaseModel):
|
||||
access_mode: str = Field(
|
||||
@ -30,6 +37,55 @@ class WorkspacePermission(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class DefaultWorkspaceJoinResult(BaseModel):
|
||||
"""
|
||||
Result of ensuring an account is a member of the enterprise default workspace.
|
||||
|
||||
- joined=True is idempotent (already a member also returns True)
|
||||
- joined=False means enterprise default workspace is not configured or invalid/archived
|
||||
"""
|
||||
|
||||
workspace_id: str = Field(default="", alias="workspaceId")
|
||||
joined: bool
|
||||
message: str
|
||||
|
||||
model_config = ConfigDict(extra="forbid", populate_by_name=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _check_workspace_id_when_joined(self) -> "DefaultWorkspaceJoinResult":
|
||||
if self.joined and not self.workspace_id:
|
||||
raise ValueError("workspace_id must be non-empty when joined is True")
|
||||
return self
|
||||
|
||||
|
||||
def try_join_default_workspace(account_id: str) -> None:
|
||||
"""
|
||||
Enterprise-only side-effect: ensure account is a member of the default workspace.
|
||||
|
||||
This is a best-effort integration. Failures must not block user registration.
|
||||
"""
|
||||
|
||||
if not dify_config.ENTERPRISE_ENABLED:
|
||||
return
|
||||
|
||||
try:
|
||||
result = EnterpriseService.join_default_workspace(account_id=account_id)
|
||||
if result.joined:
|
||||
logger.info(
|
||||
"Joined enterprise default workspace for account %s (workspace_id=%s)",
|
||||
account_id,
|
||||
result.workspace_id,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Skipped joining enterprise default workspace for account %s (message=%s)",
|
||||
account_id,
|
||||
result.message,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to join enterprise default workspace for account %s", account_id, exc_info=True)
|
||||
|
||||
|
||||
class EnterpriseService:
|
||||
@classmethod
|
||||
def get_info(cls):
|
||||
@ -39,6 +95,34 @@ class EnterpriseService:
|
||||
def get_workspace_info(cls, tenant_id: str):
|
||||
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
|
||||
|
||||
@classmethod
|
||||
def join_default_workspace(cls, *, account_id: str) -> DefaultWorkspaceJoinResult:
|
||||
"""
|
||||
Call enterprise inner API to add an account to the default workspace.
|
||||
|
||||
NOTE: EnterpriseRequest.base_url is expected to already include the `/inner/api` prefix,
|
||||
so the endpoint here is `/default-workspace/members`.
|
||||
"""
|
||||
|
||||
# Ensure we are sending a UUID-shaped string (enterprise side validates too).
|
||||
try:
|
||||
uuid.UUID(account_id)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"account_id must be a valid UUID: {account_id}") from e
|
||||
|
||||
data = EnterpriseRequest.send_request(
|
||||
"POST",
|
||||
"/default-workspace/members",
|
||||
json={"account_id": account_id},
|
||||
timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS,
|
||||
raise_for_status=True,
|
||||
)
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("Invalid response format from enterprise default workspace API")
|
||||
if "joined" not in data or "message" not in data:
|
||||
raise ValueError("Invalid response payload from enterprise default workspace API")
|
||||
return DefaultWorkspaceJoinResult.model_validate(data)
|
||||
|
||||
@classmethod
|
||||
def get_app_sso_settings_last_update_time(cls) -> datetime:
|
||||
data = EnterpriseRequest.send_request("GET", "/sso/app/last-update-time")
|
||||
|
||||
Reference in New Issue
Block a user