feat(dify-cli): session level tool white list

This commit is contained in:
Harry
2026-01-26 18:09:14 +08:00
parent a9e1394011
commit 89eb7b17db
9 changed files with 152 additions and 38 deletions

View File

@ -2,12 +2,12 @@ from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import current_app, request
from flask import current_app, g, request
from flask_login import user_logged_in
from pydantic import BaseModel
from sqlalchemy.orm import Session
from core.session.cli_api import CliApiSession, CliApiSessionManager
from core.session.cli_api import CliApiSession, CliContext
from extensions.ext_database import db
from libs.login import current_user
from models.account import Tenant
@ -75,22 +75,13 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
def get_cli_user_tenant(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
session_id = request.headers.get("X-Cli-Api-Session-Id")
session: CliApiSession | None = getattr(g, "cli_api_session", None)
if session is None:
raise ValueError("session not found")
if session_id:
session: CliApiSession | None = CliApiSessionManager().get(session_id)
if not session:
raise ValueError("session not found")
user_id = session.user_id
tenant_id = session.tenant_id
else:
payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {})
user_id = payload.user_id
tenant_id = payload.tenant_id
if not tenant_id:
raise ValueError("tenant_id is required")
user_id = session.user_id
tenant_id = session.tenant_id
cli_context = CliContext.model_validate(session.context)
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
@ -110,11 +101,10 @@ def get_cli_user_tenant(view_func: Callable[P, R]):
raise ValueError("tenant not found")
kwargs["tenant_model"] = tenant_model
kwargs["user_model"] = get_user(tenant_id, user_id)
kwargs["cli_context"] = cli_context
user = get_user(tenant_id, user_id)
kwargs["user_model"] = user
current_app.login_manager._update_request_context_with_user(user) # type: ignore
current_app.login_manager._update_request_context_with_user(kwargs["user_model"]) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
return view_func(*args, **kwargs)