mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
Merge branch 'main' into deploy/dev
This commit is contained in:
@ -1,8 +1,9 @@
|
||||
import base64
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import console_ns
|
||||
@ -15,22 +16,8 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class SubscriptionQuery(BaseModel):
|
||||
plan: str = Field(..., description="Subscription plan")
|
||||
interval: str = Field(..., description="Billing interval")
|
||||
|
||||
@field_validator("plan")
|
||||
@classmethod
|
||||
def validate_plan(cls, value: str) -> str:
|
||||
if value not in [CloudPlan.PROFESSIONAL, CloudPlan.TEAM]:
|
||||
raise ValueError("Invalid plan")
|
||||
return value
|
||||
|
||||
@field_validator("interval")
|
||||
@classmethod
|
||||
def validate_interval(cls, value: str) -> str:
|
||||
if value not in {"month", "year"}:
|
||||
raise ValueError("Invalid interval")
|
||||
return value
|
||||
plan: Literal[CloudPlan.PROFESSIONAL, CloudPlan.TEAM] = Field(..., description="Subscription plan")
|
||||
interval: Literal["month", "year"] = Field(..., description="Billing interval")
|
||||
|
||||
|
||||
class PartnerTenantsPayload(BaseModel):
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal_with
|
||||
@ -26,6 +25,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||
from libs import helper
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
@ -44,8 +44,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUID
|
||||
first_id: UUID | None = None
|
||||
conversation_id: UUIDStrOrEmpty
|
||||
first_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
@ -10,19 +8,19 @@ from controllers.console import console_ns
|
||||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from libs.helper import TimestampField
|
||||
from libs.helper import TimestampField, UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
|
||||
class SavedMessageListQuery(BaseModel):
|
||||
last_id: UUID | None = None
|
||||
last_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SavedMessageCreatePayload(BaseModel):
|
||||
message_id: UUID
|
||||
message_id: UUIDStrOrEmpty
|
||||
|
||||
|
||||
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
@ -10,10 +12,20 @@ from models import TenantAccountRole
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
|
||||
|
||||
class LoadBalancingCredentialPayload(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
credentials: dict[str, object]
|
||||
|
||||
|
||||
register_schema_models(console_ns, LoadBalancingCredentialPayload)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate"
|
||||
)
|
||||
class LoadBalancingCredentialsValidateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -24,20 +36,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# validate model load balancing credentials
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
@ -49,9 +48,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
||||
model_load_balancing_service.validate_load_balancing_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
credentials=args["credentials"],
|
||||
model=payload.model,
|
||||
model_type=payload.model_type,
|
||||
credentials=payload.credentials,
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
result = False
|
||||
@ -69,6 +68,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
||||
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate"
|
||||
)
|
||||
class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -79,20 +79,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# validate model load balancing config credentials
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
@ -104,9 +91,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||
model_load_balancing_service.validate_load_balancing_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
credentials=args["credentials"],
|
||||
model=payload.model,
|
||||
model_type=payload.model_type,
|
||||
credentials=payload.credentials,
|
||||
config_id=config_id,
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import io
|
||||
import logging
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from flask import make_response, redirect, request, send_file
|
||||
@ -17,8 +18,8 @@ from controllers.console.wraps import (
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.db.session_factory import session_factory
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
@ -40,6 +41,8 @@ from services.tools.tools_manage_service import ToolCommonService
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_valid_url(url: str) -> bool:
|
||||
if not url:
|
||||
@ -945,8 +948,8 @@ class ToolProviderMCPApi(Resource):
|
||||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||
|
||||
# Create provider in transaction
|
||||
with Session(db.engine) as session, session.begin():
|
||||
# 1) Create provider in a short transaction (no network I/O inside)
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
result = service.create_provider(
|
||||
tenant_id=tenant_id,
|
||||
@ -962,8 +965,26 @@ class ToolProviderMCPApi(Resource):
|
||||
authentication=authentication,
|
||||
)
|
||||
|
||||
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
# 2) Try to fetch tools immediately after creation so they appear without a second save.
|
||||
# Perform network I/O outside any DB session to avoid holding locks.
|
||||
try:
|
||||
reconnect = MCPToolManageService.reconnect_with_url(
|
||||
server_url=args["server_url"],
|
||||
headers=args.get("headers") or {},
|
||||
timeout=configuration.timeout,
|
||||
sse_read_timeout=configuration.sse_read_timeout,
|
||||
)
|
||||
# Update just-created provider with authed/tools in a new short transaction
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
db_provider = service.get_provider(provider_id=result.id, tenant_id=tenant_id)
|
||||
db_provider.authed = reconnect.authed
|
||||
db_provider.tools = reconnect.tools
|
||||
|
||||
result = ToolTransformService.mcp_provider_to_user_provider(db_provider, for_list=True)
|
||||
except Exception:
|
||||
# Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is
|
||||
logger.warning("Failed to fetch MCP tools after creation", exc_info=True)
|
||||
|
||||
return jsonable_encoder(result)
|
||||
|
||||
@ -1011,9 +1032,6 @@ class ToolProviderMCPApi(Resource):
|
||||
validation_result=validation_result,
|
||||
)
|
||||
|
||||
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
|
||||
ToolProviderListCache.invalidate_cache(current_tenant_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@console_ns.expect(parser_mcp_delete)
|
||||
@ -1028,9 +1046,6 @@ class ToolProviderMCPApi(Resource):
|
||||
service = MCPToolManageService(session=session)
|
||||
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
|
||||
|
||||
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
|
||||
ToolProviderListCache.invalidate_cache(current_tenant_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@ -1081,8 +1096,6 @@ class ToolMCPAuthApi(Resource):
|
||||
credentials=provider_entity.credentials,
|
||||
authed=True,
|
||||
)
|
||||
# Invalidate cache after updating credentials
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
return {"result": "success"}
|
||||
except MCPAuthError as e:
|
||||
try:
|
||||
@ -1096,22 +1109,16 @@ class ToolMCPAuthApi(Resource):
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
response = service.execute_auth_actions(auth_result)
|
||||
# Invalidate cache after auth actions may have updated provider state
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
return response
|
||||
except MCPRefreshTokenError as e:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
||||
# Invalidate cache after clearing credentials
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
|
||||
except (MCPError, ValueError) as e:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
||||
# Invalidate cache after clearing credentials
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
||||
|
||||
|
||||
|
||||
@ -13,7 +13,6 @@ from controllers.service_api.dataset.error import DatasetInUseError, DatasetName
|
||||
from controllers.service_api.wraps import (
|
||||
DatasetApiResource,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
validate_dataset_token,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
@ -460,9 +459,8 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@validate_dataset_token
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
def get(self, _, dataset_id):
|
||||
def get(self, _):
|
||||
"""Get all knowledge type tags."""
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
@ -482,8 +480,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
}
|
||||
)
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
def post(self, _):
|
||||
"""Add a knowledge type tag."""
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
@ -506,8 +503,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
}
|
||||
)
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
@validate_dataset_token
|
||||
def patch(self, _, dataset_id):
|
||||
def patch(self, _):
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
@ -533,9 +529,8 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@validate_dataset_token
|
||||
@edit_permission_required
|
||||
def delete(self, _, dataset_id):
|
||||
def delete(self, _):
|
||||
"""Delete a knowledge type tag."""
|
||||
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag(payload.tag_id)
|
||||
@ -555,8 +550,7 @@ class DatasetTagBindingApi(DatasetApiResource):
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
def post(self, _):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
@ -580,8 +574,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
def post(self, _):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
@ -604,7 +597,6 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@validate_dataset_token
|
||||
def get(self, _, *args, **kwargs):
|
||||
"""Get all knowledge type tags."""
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
|
||||
Reference in New Issue
Block a user