mirror of
https://github.com/langgenius/dify.git
synced 2026-03-26 00:38:03 +08:00
Made-with: Cursor # Conflicts: # .devcontainer/post_create_command.sh # api/commands.py # api/core/agent/cot_agent_runner.py # api/core/agent/fc_agent_runner.py # api/core/app/apps/workflow_app_runner.py # api/core/app/entities/queue_entities.py # api/core/app/entities/task_entities.py # api/core/workflow/workflow_entry.py # api/dify_graph/enums.py # api/dify_graph/graph/graph.py # api/dify_graph/graph_events/node.py # api/dify_graph/model_runtime/entities/message_entities.py # api/dify_graph/node_events/node.py # api/dify_graph/nodes/agent/agent_node.py # api/dify_graph/nodes/base/__init__.py # api/dify_graph/nodes/base/entities.py # api/dify_graph/nodes/base/node.py # api/dify_graph/nodes/llm/entities.py # api/dify_graph/nodes/llm/node.py # api/dify_graph/nodes/tool/tool_node.py # api/pyproject.toml # api/uv.lock # web/app/components/base/avatar/__tests__/index.spec.tsx # web/app/components/base/avatar/index.tsx # web/app/components/base/date-and-time-picker/time-picker/__tests__/index.spec.tsx # web/app/components/base/file-uploader/file-from-link-or-local/index.tsx # web/app/components/base/prompt-editor/index.tsx # web/app/components/datasets/metadata/edit-metadata-batch/modal.tsx # web/app/components/header/account-dropdown/index.spec.tsx # web/app/components/share/text-generation/index.tsx # web/app/components/workflow/block-selector/tool/action-item.tsx # web/app/components/workflow/block-selector/trigger-plugin/action-item.tsx # web/app/components/workflow/hooks/use-edges-interactions.ts # web/app/components/workflow/hooks/use-nodes-interactions.ts # web/app/components/workflow/index.tsx # web/app/components/workflow/nodes/_base/components/editor/code-editor/index.tsx # web/app/components/workflow/nodes/http/components/key-value/key-value-edit/index.tsx # web/app/components/workflow/nodes/human-input/components/delivery-method/recipient/email-item.tsx # web/app/components/workflow/nodes/loop/use-interactions.ts # web/contract/router.ts # web/env.ts # 
web/eslint-suppressions.json # web/package.json # web/pnpm-lock.yaml
137 lines
3.9 KiB
Python
137 lines
3.9 KiB
Python
from collections.abc import Callable
|
|
from functools import wraps
|
|
from typing import ParamSpec, TypeVar
|
|
|
|
from flask import current_app, request
|
|
from flask_login import user_logged_in
|
|
from pydantic import BaseModel
|
|
from sqlalchemy.orm import Session
|
|
|
|
from extensions.ext_database import db
|
|
from libs.login import current_user
|
|
from models.account import Tenant
|
|
from models.model import DefaultEndUserSessionID, EndUser
|
|
|
|
P = ParamSpec("P")
|
|
R = TypeVar("R")
|
|
|
|
|
|
class TenantUserPayload(BaseModel):
    """Request body read by ``get_user_tenant``: the tenant and end user a call targets."""

    # Id of the tenant the request is scoped to; ``get_user_tenant`` rejects an empty value.
    tenant_id: str
    # Caller-supplied end user id. It is untrusted input; an empty string is replaced
    # with the default anonymous session id by ``get_user_tenant``.
    user_id: str
|
|
|
|
|
|
def get_user(tenant_id: str, user_id: str | None) -> EndUser:
    """
    Get current user, creating the EndUser row if no match exists.

    NOTE: user_id is not trusted, it could be maliciously set to any value.
    As a result, it could only be considered as an end user id.

    Args:
        tenant_id: Tenant the end user must belong to; every lookup is scoped to it.
        user_id: Caller-supplied end user id. When empty or None, the shared
            ``DefaultEndUserSessionID.DEFAULT_SESSION_ID`` session is used and the
            user is treated as anonymous.

    Returns:
        The matching ``EndUser``, or a newly committed one when none was found.
        The object is returned after its database session has been closed.

    Raises:
        ValueError: ``"user not found"`` if any database operation (lookup or
            creation) fails; the original exception is chained as ``__cause__``.
    """
    if not user_id:
        user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
    is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID

    try:
        # The anonymous user is keyed by its well-known session id; any other
        # value is looked up as an EndUser primary key.
        lookup_column = EndUser.session_id if is_anonymous else EndUser.id

        with Session(db.engine) as session:
            user_model = (
                session.query(EndUser)
                .where(
                    lookup_column == user_id,
                    EndUser.tenant_id == tenant_id,
                )
                .first()
            )

            if not user_model:
                # NOTE(review): for a non-anonymous user_id the lookup above is by
                # EndUser.id, but the row created here stores user_id in session_id
                # and does not set id explicitly. A later call with the same user_id
                # may therefore miss this row and create another one. Confirm intent.
                user_model = EndUser(
                    tenant_id=tenant_id,
                    type="service_api",
                    is_anonymous=is_anonymous,
                    session_id=user_id,
                )
                session.add(user_model)
                session.commit()
                # commit() expires loaded attributes by default; reload them so the
                # object stays readable after the session is closed below.
                session.refresh(user_model)
    except Exception as e:
        # Keep the public error type/message but preserve the real cause for debugging.
        raise ValueError("user not found") from e

    return user_model
|
|
|
|
|
|
def get_user_tenant(view_func: Callable[P, R]) -> Callable[P, R]:
    """
    Decorator that resolves the tenant and end user named in the request's JSON body.

    The body is validated as ``TenantUserPayload``. The wrapped view receives two
    extra keyword arguments:

    - ``tenant_model``: the ``Tenant`` row whose id equals ``tenant_id``.
    - ``user_model``: the ``EndUser`` returned by ``get_user`` (created if missing).

    The resolved end user is also pushed into the Flask-Login request context, and
    the ``user_logged_in`` signal is sent, before the view runs.

    Raises:
        ValueError: ``"tenant_id is required"`` for an empty tenant id, and
            ``"tenant not found"`` when the lookup fails or returns nothing.
            The payload validation and ``get_user`` may raise their own errors.
    """

    @wraps(view_func)
    def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
        # silent=True: an unparseable or non-JSON body yields None, so it falls back
        # to {} and is rejected by the payload model's required-field validation.
        payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {})
        user_id = payload.user_id
        tenant_id = payload.tenant_id

        if not tenant_id:
            raise ValueError("tenant_id is required")

        if not user_id:
            user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID

        try:
            tenant_model = db.session.query(Tenant).where(Tenant.id == tenant_id).first()
        except Exception as e:
            # Keep the public error but chain the underlying database failure.
            raise ValueError("tenant not found") from e

        if not tenant_model:
            raise ValueError("tenant not found")

        kwargs["tenant_model"] = tenant_model

        user = get_user(tenant_id, user_id)
        kwargs["user_model"] = user

        # Private Flask-Login helper; presumably binds `user` as current_user for this
        # request so the signal below (and the view) see it. Verify on Flask-Login upgrades.
        current_app.login_manager._update_request_context_with_user(user)  # type: ignore
        user_logged_in.send(current_app._get_current_object(), user=current_user)  # type: ignore

        return view_func(*args, **kwargs)

    return decorated_view
|
|
|
|
|
|
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
    """
    Decorator that parses the JSON request body into ``payload_type``.

    The wrapped view receives the validated model as the ``payload`` keyword argument.
    It can be applied directly (``plugin_data(view, payload_type=M)``) or as a
    decorator factory (``@plugin_data(payload_type=M)``).

    Args:
        view: The view to wrap immediately, or None to return the decorator itself.
        payload_type: Pydantic model class used to validate the request body.

    Raises:
        ValueError: ``"invalid json"`` if ``request.get_json()`` fails, or
            ``"invalid payload: <details>"`` if validation fails. In both cases the
            original exception is chained as ``__cause__``.
    """

    def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
        @wraps(view_func)
        def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
            try:
                data = request.get_json()
            except Exception as e:
                raise ValueError("invalid json") from e

            try:
                payload = payload_type.model_validate(data)
            except Exception as e:
                raise ValueError(f"invalid payload: {e}") from e

            kwargs["payload"] = payload
            return view_func(*args, **kwargs)

        return decorated_view

    # Support both bare application and the keyword-only factory form.
    if view is None:
        return decorator
    return decorator(view)
|