mirror of
https://github.com/langgenius/dify.git
synced 2026-05-27 04:16:16 +08:00
Replace the single mutable-context Pipeline with a two-phase, condition-driven system dispatched by token type.

New architecture:
- TokenType(StrEnum) replaces source: str on AuthContext / TokenKind
- AuthPipeline: pure prepare→auth step runner; no guard()
- PipelineRoute: binds AuthPipeline to an optional required_edition gate
- PipelineRouter: single guard() entry point; runs edition/license/token-type pre-gates, then dispatches to the registered pipeline for the token type
- Cond / When: composable predicates for conditional step dispatch
- AuthData: frozen Pydantic model produced by the prepare phase; carries token_id so endpoints don't need to call get_auth_ctx() for identity fields
- Edition enum + current_edition(): CE / EE / SAAS discriminator

Two pipelines in composition.py:
- account_pipeline — OAUTH_ACCOUNT tokens
- external_sso_pipeline — OAUTH_EXTERNAL_SSO tokens (EE enforced at route level)

All /openapi/v1 endpoints migrated to auth_router.guard(). Old context.py, steps.py, strategies.py, surface_gate.py deleted. WORKSPACE_READ scope added; cached_verdicts renamed to membership_cache.
173 lines
6.1 KiB
Python
173 lines
6.1 KiB
Python
"""POST /openapi/v1/apps/<app_id>/run — mode-agnostic runner, plus the matching task-stop endpoint."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from collections.abc import Callable, Iterator
|
|
from contextlib import contextmanager
|
|
from typing import Any
|
|
|
|
from flask import request
|
|
from flask_restx import Resource
|
|
from pydantic import ValidationError
|
|
from werkzeug.exceptions import BadRequest, HTTPException, InternalServerError, NotFound, UnprocessableEntity
|
|
|
|
import services
|
|
from controllers.openapi import openapi_ns
|
|
from controllers.openapi._audit import emit_app_run
|
|
from controllers.openapi._models import AppRunRequest
|
|
from controllers.openapi.auth.composition import auth_router
|
|
from controllers.openapi.auth.data import AuthData
|
|
from controllers.service_api.app.error import (
|
|
AppUnavailableError,
|
|
CompletionRequestError,
|
|
ConversationCompletedError,
|
|
ProviderModelCurrentlyNotSupportError,
|
|
ProviderNotInitializeError,
|
|
ProviderQuotaExceededError,
|
|
)
|
|
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
|
from core.errors.error import (
|
|
ModelCurrentlyNotSupportError,
|
|
ProviderTokenNotInitError,
|
|
QuotaExceededError,
|
|
)
|
|
from extensions.ext_redis import redis_client
|
|
from graphon.graph_engine.manager import GraphEngineManager
|
|
from graphon.model_runtime.errors.invoke import InvokeError
|
|
from libs import helper
|
|
from libs.oauth_bearer import Scope
|
|
from models.model import App, AppMode
|
|
from services.app_generate_service import AppGenerateService
|
|
from services.errors.app import (
|
|
IsDraftWorkflowError,
|
|
WorkflowIdFormatError,
|
|
WorkflowNotFoundError,
|
|
)
|
|
from services.errors.llm import InvokeRateLimitError
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@contextmanager
def _translate_service_errors() -> Iterator[None]:
    """Translate service-layer and model-runtime errors into HTTP errors.

    Use as ``with _translate_service_errors(): ...`` around a call into the
    generation service. Any exception listed below that escapes the ``with``
    body is re-raised as the HTTP-facing error of the matching branch; every
    other exception propagates unchanged.

    Each translated error is raised ``from`` the original, so the underlying
    failure stays attached as ``__cause__`` instead of only being implicit
    context.

    Raises:
        NotFound: workflow or conversation does not exist.
        BadRequest: draft workflow or malformed workflow id.
        ConversationCompletedError: conversation already completed.
        AppUnavailableError: app model config is broken (also logged).
        ProviderNotInitializeError: provider token not initialised.
        ProviderQuotaExceededError: provider quota exceeded.
        ProviderModelCurrentlyNotSupportError: model currently unsupported.
        InvokeRateLimitHttpError: invocation rate limit hit.
        CompletionRequestError: any other model invocation error.
    """
    try:
        yield
    except WorkflowNotFoundError as ex:
        raise NotFound(str(ex)) from ex
    except (IsDraftWorkflowError, WorkflowIdFormatError) as ex:
        raise BadRequest(str(ex)) from ex
    except services.errors.conversation.ConversationNotExistsError as ex:
        raise NotFound("Conversation Not Exists.") from ex
    except services.errors.conversation.ConversationCompletedError as ex:
        # Same class name, different layer: service error in, HTTP error out.
        raise ConversationCompletedError() from ex
    except services.errors.app_model_config.AppModelConfigBrokenError as ex:
        logger.exception("App model config broken.")
        raise AppUnavailableError() from ex
    except ProviderTokenNotInitError as ex:
        raise ProviderNotInitializeError(ex.description) from ex
    except QuotaExceededError as ex:
        raise ProviderQuotaExceededError() from ex
    except ModelCurrentlyNotSupportError as ex:
        raise ProviderModelCurrentlyNotSupportError() from ex
    except InvokeRateLimitError as ex:
        # NOTE(review): kept ahead of the generic InvokeError branch, as in the
        # original; the order matters if this is an InvokeError subclass — confirm.
        raise InvokeRateLimitHttpError(ex.description) from ex
    except InvokeError as e:
        raise CompletionRequestError(e.description) from e
|
|
|
|
|
|
def _generate(app: App, caller: Any, args: dict[str, Any], streaming: bool):
    """Call ``AppGenerateService.generate`` with the OpenAPI invoke source.

    Args:
        app: App to run; forwarded as ``app_model``.
        caller: Identity the run is attributed to; forwarded as ``user``.
        args: Request arguments passed through unchanged.
        streaming: Forwarded as the service's ``streaming`` flag.

    Returns:
        Whatever ``AppGenerateService.generate`` returns.
    """
    service_kwargs = {
        "app_model": app,
        "user": caller,
        "args": args,
        "invoke_from": InvokeFrom.OPENAPI,
        "streaming": streaming,
    }
    return AppGenerateService.generate(**service_kwargs)
|
|
|
|
|
|
def _run_chat(app: App, caller: Any, payload: AppRunRequest):
    """Run a chat-style app; a non-blank ``query`` is mandatory.

    Raises:
        UnprocessableEntity: ``query`` is missing, empty, or whitespace only.
    """
    query = payload.query
    has_query = bool(query and query.strip())
    if not has_query:
        raise UnprocessableEntity("query_required_for_chat")

    # Fields left as None are dropped from the request arguments.
    request_args = payload.model_dump(exclude_none=True)
    with _translate_service_errors():
        return _generate(app, caller, request_args, streaming=True)
|
|
|
|
|
|
def _run_completion(app: App, caller: Any, payload: AppRunRequest):
    """Run a completion app.

    ``auto_generate_name`` is always forced to ``False`` and ``query``
    defaults to an empty string when the payload did not supply one.
    """
    request_args = payload.model_dump(exclude_none=True)
    request_args.update(auto_generate_name=False)
    if "query" not in request_args:
        request_args["query"] = ""

    with _translate_service_errors():
        return _generate(app, caller, request_args, streaming=True)
|
|
|
|
|
|
def _run_workflow(app: App, caller: Any, payload: AppRunRequest):
    """Run a workflow app; workflows do not accept a ``query``.

    Raises:
        UnprocessableEntity: the payload carries a ``query`` (even an empty one).
    """
    if payload.query is not None:
        raise UnprocessableEntity("query_not_supported_for_workflow")

    # These payload fields are never forwarded for a workflow run.
    omitted_fields = {"query", "conversation_id", "auto_generate_name"}
    request_args = payload.model_dump(exclude=omitted_fields, exclude_none=True)

    with _translate_service_errors():
        return _generate(app, caller, request_args, streaming=True)
|
|
|
|
|
|
# Handler lookup keyed by app mode. The three chat-flavoured modes share one
# handler; a mode with no entry here is rejected by the caller of this table.
_DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = {
    **dict.fromkeys((AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT), _run_chat),
    AppMode.COMPLETION: _run_completion,
    AppMode.WORKFLOW: _run_workflow,
}
|
|
|
|
|
|
@openapi_ns.route("/apps/<string:app_id>/run")
class AppRunApi(Resource):
    """Run an app through the handler registered for its mode."""

    @openapi_ns.expect(openapi_ns.models[AppRunRequest.__name__])
    @openapi_ns.response(200, "Run result (SSE stream)")
    @auth_router.guard(scope=Scope.APPS_RUN)
    def post(self, app_id: str, *, auth_data: AuthData):
        """Validate the request body, dispatch by app mode, and return the run.

        Args:
            app_id: Route parameter. It is not read here; the app comes from
                ``auth_data.app`` (presumably resolved by the guard — confirm).
            auth_data: Supplied by ``auth_router.guard``; provides the app,
                the caller, and the caller kind.

        Raises:
            UnprocessableEntity: the body fails ``AppRunRequest`` validation,
                the app mode has no registered handler, or the handler
                rejects the payload.
            HTTPException: any HTTP error raised by the handler is passed
                through untouched.
            InternalServerError: any other exception from the handler (logged).
        """
        app_model = auth_data.app
        caller = auth_data.caller
        caller_kind = auth_data.caller_kind

        # A missing or unparseable JSON body is validated as an empty object.
        body = request.get_json(silent=True) or {}
        try:
            payload = AppRunRequest.model_validate(body)
        except ValidationError as exc:
            raise UnprocessableEntity(exc.json()) from exc

        handler = _DISPATCH.get(app_model.mode)
        if handler is None:
            raise UnprocessableEntity("mode_not_runnable")

        try:
            stream_obj = handler(app_model, caller, payload)
        except HTTPException:
            # Already an HTTP-facing error; let it through as-is.
            raise
        except Exception as exc:
            logger.exception("internal server error.")
            raise InternalServerError() from exc

        # Reached only when the handler returned without raising.
        emit_app_run(
            app_id=app_model.id,
            tenant_id=app_model.tenant_id,
            caller_kind=caller_kind,
            mode=str(app_model.mode),
            surface="apps",
        )

        return helper.compact_generate_response(stream_obj)
|
|
|
|
|
|
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/stop")
class AppRunTaskStopApi(Resource):
    """Request that a running task be stopped."""

    @openapi_ns.response(200, "Task stopped")
    @auth_router.guard(scope=Scope.APPS_RUN)
    def post(self, app_id: str, task_id: str, *, auth_data: AuthData):
        """Signal the task identified by ``task_id`` to stop.

        Args:
            app_id: Route parameter; not read by this handler.
            task_id: Identifier of the task to stop.
            auth_data: Supplied by ``auth_router.guard``; not read by this
                handler.

        Returns:
            ``{"result": "success"}`` once both stop signals have been issued.
        """
        # NOTE(review): the stop flag is set without a user check, and neither
        # app_id nor auth_data is consulted here — confirm that the guard alone
        # is meant to authorise stopping an arbitrary task_id.
        AppQueueManager.set_stop_flag_no_user_check(task_id)
        GraphEngineManager(redis_client).send_stop_command(task_id)
        return {"result": "success"}
|