Compare commits

..

1 Commits

Author SHA1 Message Date
5874b920b2 fix: code owners 2025-11-28 14:36:30 +08:00
188 changed files with 7525 additions and 24731 deletions

8
.github/CODEOWNERS vendored
View File

@ -18,10 +18,10 @@ api/core/workflow/node_events/ @laipz8200 @QuantumGhost
api/core/model_runtime/ @laipz8200 @QuantumGhost
# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
api/core/workflow/nodes/agent/ @Nov1c444
api/core/workflow/nodes/iteration/ @Nov1c444
api/core/workflow/nodes/loop/ @Nov1c444
api/core/workflow/nodes/llm/ @Nov1c444
api/core/workflow/nodes/agent/ Nov1c444
api/core/workflow/nodes/iteration/ Nov1c444
api/core/workflow/nodes/loop/ Nov1c444
api/core/workflow/nodes/llm/ Nov1c444
# Backend - RAG (Retrieval Augmented Generation)
api/core/rag/ @JohnJyong

View File

@ -1,8 +1,6 @@
import logging
import time
from opentelemetry.trace import get_current_span
from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from dify_app import DifyApp
@ -28,25 +26,8 @@ def create_flask_app_with_configs() -> DifyApp:
# add an unique identifier to each request
RecyclableContextVar.increment_thread_recycles()
# add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
@dify_app.after_request
def add_trace_id_header(response):
try:
span = get_current_span()
ctx = span.get_span_context() if span else None
if ctx and ctx.is_valid:
trace_id_hex = format(ctx.trace_id, "032x")
# Avoid duplicates if some middleware added it
if "X-Trace-Id" not in response.headers:
response.headers["X-Trace-Id"] = trace_id_hex
except Exception:
# Never break the response due to tracing header injection
logger.warning("Failed to add trace ID to response header", exc_info=True)
return response
# Capture the decorator's return value to avoid pyright reportUnusedFunction
_ = before_request
_ = add_trace_id_header
return dify_app
@ -70,7 +51,6 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_compress,
ext_database,
ext_forward_refs,
ext_hosting_provider,
ext_import_modules,
ext_logging,
@ -95,7 +75,6 @@ def initialize_extensions(app: DifyApp):
ext_warnings,
ext_import_modules,
ext_orjson,
ext_forward_refs,
ext_set_secretkey,
ext_compress,
ext_code_based_extension,

View File

@ -553,10 +553,7 @@ class LoggingConfig(BaseSettings):
LOG_FORMAT: str = Field(
description="Format string for log messages",
default=(
"%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] "
"[%(filename)s:%(lineno)d] %(trace_id)s - %(message)s"
),
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
)
LOG_DATEFORMAT: str | None = Field(

View File

@ -1,23 +1,16 @@
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from flask_restx import Resource, fields, reqparse
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
class AdvancedPromptTemplateQuery(BaseModel):
app_mode: str = Field(..., description="Application mode")
model_mode: str = Field(..., description="Model mode")
has_context: str = Field(default="true", description="Whether has context")
model_name: str = Field(..., description="Model name")
console_ns.schema_model(
AdvancedPromptTemplateQuery.__name__,
AdvancedPromptTemplateQuery.model_json_schema(ref_template="#/definitions/{model}"),
parser = (
reqparse.RequestParser()
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
.add_argument("has_context", type=str, required=False, default="true", location="args", help="Whether has context")
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
)
@ -25,7 +18,7 @@ console_ns.schema_model(
class AdvancedPromptTemplateList(Resource):
@console_ns.doc("get_advanced_prompt_templates")
@console_ns.doc(description="Get advanced prompt templates based on app mode and model configuration")
@console_ns.expect(console_ns.models[AdvancedPromptTemplateQuery.__name__])
@console_ns.expect(parser)
@console_ns.response(
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
)
@ -34,6 +27,6 @@ class AdvancedPromptTemplateList(Resource):
@login_required
@account_initialization_required
def get(self):
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
return AdvancedPromptTemplateService.get_prompt(args.model_dump())
return AdvancedPromptTemplateService.get_prompt(args)

View File

@ -1,12 +1,9 @@
import uuid
from typing import Literal
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest
from werkzeug.exceptions import BadRequest, abort
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
@ -39,130 +36,6 @@ from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
default="all", description="App mode filter"
)
name: str | None = Field(default=None, description="Filter by app name")
tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
@field_validator("tag_ids", mode="before")
@classmethod
def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
if not value:
return None
if isinstance(value, str):
items = [item.strip() for item in value.split(",") if item.strip()]
elif isinstance(value, list):
items = [str(item).strip() for item in value if item and str(item).strip()]
else:
raise TypeError("Unsupported tag_ids type.")
if not items:
return None
try:
return [str(uuid.UUID(item)) for item in items]
except ValueError as exc:
raise ValueError("Invalid UUID format in tag_ids.") from exc
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)")
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("description")
@classmethod
def validate_description(cls, value: str | None) -> str | None:
if value is None:
return value
return validate_description_length(value)
class UpdateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)")
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
@field_validator("description")
@classmethod
def validate_description(cls, value: str | None) -> str | None:
if value is None:
return value
return validate_description_length(value)
class CopyAppPayload(BaseModel):
name: str | None = Field(default=None, description="Name for the copied app")
description: str | None = Field(default=None, description="Description for the copied app")
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("description")
@classmethod
def validate_description(cls, value: str | None) -> str | None:
if value is None:
return value
return validate_description_length(value)
class AppExportQuery(BaseModel):
include_secret: bool = Field(default=False, description="Include secrets in export")
workflow_id: str | None = Field(default=None, description="Specific workflow ID to export")
class AppNamePayload(BaseModel):
name: str = Field(..., min_length=1, description="Name to check")
class AppIconPayload(BaseModel):
icon: str | None = Field(default=None, description="Icon data")
icon_background: str | None = Field(default=None, description="Icon background color")
class AppSiteStatusPayload(BaseModel):
enable_site: bool = Field(..., description="Enable or disable site")
class AppApiStatusPayload(BaseModel):
enable_api: bool = Field(..., description="Enable or disable API")
class AppTracePayload(BaseModel):
enabled: bool = Field(..., description="Enable or disable tracing")
tracing_provider: str = Field(..., description="Tracing provider")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(AppListQuery)
reg(CreateAppPayload)
reg(UpdateAppPayload)
reg(CopyAppPayload)
reg(AppExportQuery)
reg(AppNamePayload)
reg(AppIconPayload)
reg(AppSiteStatusPayload)
reg(AppApiStatusPayload)
reg(AppTracePayload)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base models first
@ -274,7 +147,22 @@ app_pagination_model = console_ns.model(
class AppListApi(Resource):
@console_ns.doc("list_apps")
@console_ns.doc(description="Get list of applications with pagination and filtering")
@console_ns.expect(console_ns.models[AppListQuery.__name__])
@console_ns.expect(
console_ns.parser()
.add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1)
.add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20)
.add_argument(
"mode",
type=str,
location="args",
choices=["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"],
default="all",
help="App mode filter",
)
.add_argument("name", type=str, location="args", help="Filter by app name")
.add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs")
.add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator")
)
@console_ns.response(200, "Success", app_pagination_model)
@setup_required
@login_required
@ -284,12 +172,42 @@ class AppListApi(Resource):
"""Get app list"""
current_user, current_tenant_id = current_account_with_tenant()
args = AppListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args_dict = args.model_dump()
def uuid_list(value):
try:
return [str(uuid.UUID(v)) for v in value.split(",")]
except ValueError:
abort(400, message="Invalid UUID format in tag_ids.")
parser = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
.add_argument(
"mode",
type=str,
choices=[
"completion",
"chat",
"advanced-chat",
"workflow",
"agent-chat",
"channel",
"all",
],
default="all",
location="args",
required=False,
)
.add_argument("name", type=str, location="args", required=False)
.add_argument("tag_ids", type=uuid_list, location="args", required=False)
.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
)
args = parser.parse_args()
# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args)
if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
@ -324,13 +242,10 @@ class AppListApi(Resource):
NodeType.TRIGGER_PLUGIN,
}
for workflow in draft_workflows:
try:
for _, node_data in workflow.walk_nodes():
if node_data.get("type") in trigger_node_types:
draft_trigger_app_ids.add(str(workflow.app_id))
break
except Exception:
continue
for _, node_data in workflow.walk_nodes():
if node_data.get("type") in trigger_node_types:
draft_trigger_app_ids.add(str(workflow.app_id))
break
for app in app_pagination.items:
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
@ -339,7 +254,19 @@ class AppListApi(Resource):
@console_ns.doc("create_app")
@console_ns.doc(description="Create a new application")
@console_ns.expect(console_ns.models[CreateAppPayload.__name__])
@console_ns.expect(
console_ns.model(
"CreateAppRequest",
{
"name": fields.String(required=True, description="App name"),
"description": fields.String(description="App description (max 400 chars)"),
"mode": fields.String(required=True, enum=ALLOW_CREATE_APP_MODES, description="App mode"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
},
)
)
@console_ns.response(201, "App created successfully", app_detail_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters")
@ -352,10 +279,22 @@ class AppListApi(Resource):
def post(self):
"""Create app"""
current_user, current_tenant_id = current_account_with_tenant()
args = CreateAppPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=True, location="json")
.add_argument("description", type=validate_description_length, location="json")
.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args()
if "mode" not in args or args["mode"] is None:
raise BadRequest("mode is required")
app_service = AppService()
app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)
app = app_service.create_app(current_tenant_id, args, current_user)
return app, 201
@ -387,7 +326,20 @@ class AppApi(Resource):
@console_ns.doc("update_app")
@console_ns.doc(description="Update application details")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[UpdateAppPayload.__name__])
@console_ns.expect(
console_ns.model(
"UpdateAppRequest",
{
"name": fields.String(required=True, description="App name"),
"description": fields.String(description="App description (max 400 chars)"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
"max_active_requests": fields.Integer(description="Maximum active requests"),
},
)
)
@console_ns.response(200, "App updated successfully", app_detail_with_site_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters")
@ -399,18 +351,28 @@ class AppApi(Resource):
@marshal_with(app_detail_with_site_model)
def put(self, app_model):
"""Update app"""
args = UpdateAppPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=validate_description_length, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
.add_argument("use_icon_as_answer_icon", type=bool, location="json")
.add_argument("max_active_requests", type=int, location="json")
)
args = parser.parse_args()
app_service = AppService()
args_dict: AppService.ArgsDict = {
"name": args.name,
"description": args.description or "",
"icon_type": args.icon_type or "",
"icon": args.icon or "",
"icon_background": args.icon_background or "",
"use_icon_as_answer_icon": args.use_icon_as_answer_icon or False,
"max_active_requests": args.max_active_requests or 0,
"name": args["name"],
"description": args.get("description", ""),
"icon_type": args.get("icon_type", ""),
"icon": args.get("icon", ""),
"icon_background": args.get("icon_background", ""),
"use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False),
"max_active_requests": args.get("max_active_requests", 0),
}
app_model = app_service.update_app(app_model, args_dict)
@ -439,7 +401,18 @@ class AppCopyApi(Resource):
@console_ns.doc("copy_app")
@console_ns.doc(description="Create a copy of an existing application")
@console_ns.doc(params={"app_id": "Application ID to copy"})
@console_ns.expect(console_ns.models[CopyAppPayload.__name__])
@console_ns.expect(
console_ns.model(
"CopyAppRequest",
{
"name": fields.String(description="Name for the copied app"),
"description": fields.String(description="Description for the copied app"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
},
)
)
@console_ns.response(201, "App copied successfully", app_detail_with_site_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -453,7 +426,15 @@ class AppCopyApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
args = CopyAppPayload.model_validate(console_ns.payload or {})
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, location="json")
.add_argument("description", type=validate_description_length, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args()
with Session(db.engine) as session:
import_service = AppDslService(session)
@ -462,11 +443,11 @@ class AppCopyApi(Resource):
account=current_user,
import_mode=ImportMode.YAML_CONTENT,
yaml_content=yaml_content,
name=args.name,
description=args.description,
icon_type=args.icon_type,
icon=args.icon,
icon_background=args.icon_background,
name=args.get("name"),
description=args.get("description"),
icon_type=args.get("icon_type"),
icon=args.get("icon"),
icon_background=args.get("icon_background"),
)
session.commit()
@ -481,7 +462,11 @@ class AppExportApi(Resource):
@console_ns.doc("export_app")
@console_ns.doc(description="Export application configuration as DSL")
@console_ns.doc(params={"app_id": "Application ID to export"})
@console_ns.expect(console_ns.models[AppExportQuery.__name__])
@console_ns.expect(
console_ns.parser()
.add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export")
.add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export")
)
@console_ns.response(
200,
"App exported successfully",
@ -495,23 +480,30 @@ class AppExportApi(Resource):
@edit_permission_required
def get(self, app_model):
"""Export app"""
args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
# Add include_secret params
parser = (
reqparse.RequestParser()
.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
.add_argument("workflow_id", type=str, location="args")
)
args = parser.parse_args()
return {
"data": AppDslService.export_dsl(
app_model=app_model,
include_secret=args.include_secret,
workflow_id=args.workflow_id,
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
)
}
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json", help="Name to check")
@console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource):
@console_ns.doc("check_app_name")
@console_ns.doc(description="Check if app name is available")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppNamePayload.__name__])
@console_ns.expect(parser)
@console_ns.response(200, "Name availability checked")
@setup_required
@login_required
@ -520,10 +512,10 @@ class AppNameApi(Resource):
@marshal_with(app_detail_model)
@edit_permission_required
def post(self, app_model):
args = AppNamePayload.model_validate(console_ns.payload)
args = parser.parse_args()
app_service = AppService()
app_model = app_service.update_app_name(app_model, args.name)
app_model = app_service.update_app_name(app_model, args["name"])
return app_model
@ -533,7 +525,16 @@ class AppIconApi(Resource):
@console_ns.doc("update_app_icon")
@console_ns.doc(description="Update application icon")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppIconPayload.__name__])
@console_ns.expect(
console_ns.model(
"AppIconRequest",
{
"icon": fields.String(required=True, description="Icon data"),
"icon_type": fields.String(description="Icon type"),
"icon_background": fields.String(description="Icon background color"),
},
)
)
@console_ns.response(200, "Icon updated successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -543,10 +544,15 @@ class AppIconApi(Resource):
@marshal_with(app_detail_model)
@edit_permission_required
def post(self, app_model):
args = AppIconPayload.model_validate(console_ns.payload or {})
parser = (
reqparse.RequestParser()
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args()
app_service = AppService()
app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")
app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "")
return app_model
@ -556,7 +562,11 @@ class AppSiteStatus(Resource):
@console_ns.doc("update_app_site_status")
@console_ns.doc(description="Enable or disable app site")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__])
@console_ns.expect(
console_ns.model(
"AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")}
)
)
@console_ns.response(200, "Site status updated successfully", app_detail_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -566,10 +576,11 @@ class AppSiteStatus(Resource):
@marshal_with(app_detail_model)
@edit_permission_required
def post(self, app_model):
args = AppSiteStatusPayload.model_validate(console_ns.payload)
parser = reqparse.RequestParser().add_argument("enable_site", type=bool, required=True, location="json")
args = parser.parse_args()
app_service = AppService()
app_model = app_service.update_app_site_status(app_model, args.enable_site)
app_model = app_service.update_app_site_status(app_model, args["enable_site"])
return app_model
@ -579,7 +590,11 @@ class AppApiStatus(Resource):
@console_ns.doc("update_app_api_status")
@console_ns.doc(description="Enable or disable app API")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppApiStatusPayload.__name__])
@console_ns.expect(
console_ns.model(
"AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")}
)
)
@console_ns.response(200, "API status updated successfully", app_detail_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -589,10 +604,11 @@ class AppApiStatus(Resource):
@get_app_model
@marshal_with(app_detail_model)
def post(self, app_model):
args = AppApiStatusPayload.model_validate(console_ns.payload)
parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
args = parser.parse_args()
app_service = AppService()
app_model = app_service.update_app_api_status(app_model, args.enable_api)
app_model = app_service.update_app_api_status(app_model, args["enable_api"])
return app_model
@ -615,7 +631,15 @@ class AppTraceApi(Resource):
@console_ns.doc("update_app_trace")
@console_ns.doc(description="Update app tracing configuration")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppTracePayload.__name__])
@console_ns.expect(
console_ns.model(
"AppTraceRequest",
{
"enabled": fields.Boolean(required=True, description="Enable or disable tracing"),
"tracing_provider": fields.String(required=True, description="Tracing provider"),
},
)
)
@console_ns.response(200, "Trace configuration updated successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -624,12 +648,17 @@ class AppTraceApi(Resource):
@edit_permission_required
def post(self, app_id):
# add app trace
args = AppTracePayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("enabled", type=bool, required=True, location="json")
.add_argument("tracing_provider", type=str, required=True, location="json")
)
args = parser.parse_args()
OpsTraceManager.update_app_tracing_config(
app_id=app_id,
enabled=args.enabled,
tracing_provider=args.tracing_provider,
enabled=args["enabled"],
tracing_provider=args["tracing_provider"],
)
return {"result": "success"}

View File

@ -1,9 +1,7 @@
import logging
from typing import Any, Literal
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import InternalServerError, NotFound
import services
@ -37,41 +35,6 @@ from services.app_task_service import AppTaskService
from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class BaseMessagePayload(BaseModel):
inputs: dict[str, Any]
model_config_data: dict[str, Any] = Field(..., alias="model_config")
files: list[Any] | None = Field(default=None, description="Uploaded files")
response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
retriever_from: str = Field(default="dev", description="Retriever source")
class CompletionMessagePayload(BaseMessagePayload):
query: str = Field(default="", description="Query text")
class ChatMessagePayload(BaseMessagePayload):
query: str = Field(..., description="User query")
conversation_id: str | None = Field(default=None, description="Conversation ID")
parent_message_id: str | None = Field(default=None, description="Parent message ID")
@field_validator("conversation_id", "parent_message_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
console_ns.schema_model(
CompletionMessagePayload.__name__,
CompletionMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChatMessagePayload.__name__, ChatMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
# define completion message api for user
@ -80,7 +43,19 @@ class CompletionMessageApi(Resource):
@console_ns.doc("create_completion_message")
@console_ns.doc(description="Generate completion message for debugging")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
@console_ns.expect(
console_ns.model(
"CompletionMessageRequest",
{
"inputs": fields.Raw(required=True, description="Input variables"),
"query": fields.String(description="Query text", default=""),
"files": fields.List(fields.Raw(), description="Uploaded files"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
"retriever_from": fields.String(default="dev", description="Retriever source"),
},
)
)
@console_ns.response(200, "Completion generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(404, "App not found")
@ -89,10 +64,18 @@ class CompletionMessageApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model):
args_model = CompletionMessagePayload.model_validate(console_ns.payload)
args = args_model.model_dump(exclude_none=True, by_alias=True)
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, location="json", default="")
.add_argument("files", type=list, required=False, location="json")
.add_argument("model_config", type=dict, required=True, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
)
args = parser.parse_args()
streaming = args_model.response_mode != "blocking"
streaming = args["response_mode"] != "blocking"
args["auto_generate_name"] = False
try:
@ -154,7 +137,21 @@ class ChatMessageApi(Resource):
@console_ns.doc("create_chat_message")
@console_ns.doc(description="Generate chat message for debugging")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
@console_ns.expect(
console_ns.model(
"ChatMessageRequest",
{
"inputs": fields.Raw(required=True, description="Input variables"),
"query": fields.String(required=True, description="User query"),
"files": fields.List(fields.Raw(), description="Uploaded files"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"conversation_id": fields.String(description="Conversation ID"),
"parent_message_id": fields.String(description="Parent message ID"),
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
"retriever_from": fields.String(default="dev", description="Retriever source"),
},
)
)
@console_ns.response(200, "Chat message generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(404, "App or conversation not found")
@ -164,10 +161,20 @@ class ChatMessageApi(Resource):
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@edit_permission_required
def post(self, app_model):
args_model = ChatMessagePayload.model_validate(console_ns.payload)
args = args_model.model_dump(exclude_none=True, by_alias=True)
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, required=True, location="json")
.add_argument("files", type=list, required=False, location="json")
.add_argument("model_config", type=dict, required=True, location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
)
args = parser.parse_args()
streaming = args_model.response_mode != "blocking"
streaming = args["response_mode"] != "blocking"
args["auto_generate_name"] = False
external_trace_id = get_external_trace_id(request)

View File

@ -1,9 +1,7 @@
from typing import Literal
import sqlalchemy as sa
from flask import abort, request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask import abort
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload
from werkzeug.exceptions import NotFound
@ -16,53 +14,13 @@ from extensions.ext_database import db
from fields.conversation_fields import MessageTextField
from fields.raws import FilesContainedField
from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.helper import TimestampField
from libs.helper import DatetimeString, TimestampField
from libs.login import current_account_with_tenant, login_required
from models import Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class BaseConversationQuery(BaseModel):
keyword: str | None = Field(default=None, description="Search keyword")
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
annotation_status: Literal["annotated", "not_annotated", "all"] = Field(
default="all", description="Annotation status filter"
)
page: int = Field(default=1, ge=1, le=99999, description="Page number")
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
@field_validator("start", "end", mode="before")
@classmethod
def blank_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
class CompletionConversationQuery(BaseConversationQuery):
pass
class ChatConversationQuery(BaseConversationQuery):
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
default="-updated_at", description="Sort field and direction"
)
console_ns.schema_model(
CompletionConversationQuery.__name__,
CompletionConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChatConversationQuery.__name__,
ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@ -325,7 +283,22 @@ class CompletionConversationApi(Resource):
@console_ns.doc("list_completion_conversations")
@console_ns.doc(description="Get completion conversations with pagination and filtering")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[CompletionConversationQuery.__name__])
@console_ns.expect(
console_ns.parser()
.add_argument("keyword", type=str, location="args", help="Search keyword")
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
.add_argument(
"annotation_status",
type=str,
location="args",
choices=["annotated", "not_annotated", "all"],
default="all",
help="Annotation status filter",
)
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
)
@console_ns.response(200, "Success", conversation_pagination_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -336,17 +309,32 @@ class CompletionConversationApi(Resource):
@edit_permission_required
def get(self, app_model):
current_user, _ = current_account_with_tenant()
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument("keyword", type=str, location="args")
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument(
"annotation_status",
type=str,
choices=["annotated", "not_annotated", "all"],
default="all",
location="args",
)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
)
args = parser.parse_args()
query = sa.select(Conversation).where(
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
)
if args.keyword:
if args["keyword"]:
query = query.join(Message, Message.conversation_id == Conversation.id).where(
or_(
Message.query.ilike(f"%{args.keyword}%"),
Message.answer.ilike(f"%{args.keyword}%"),
Message.query.ilike(f"%{args['keyword']}%"),
Message.answer.ilike(f"%{args['keyword']}%"),
)
)
@ -354,7 +342,7 @@ class CompletionConversationApi(Resource):
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -366,11 +354,11 @@ class CompletionConversationApi(Resource):
query = query.where(Conversation.created_at < end_datetime_utc)
# FIXME, the type ignore in this file
if args.annotation_status == "annotated":
if args["annotation_status"] == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args.annotation_status == "not_annotated":
elif args["annotation_status"] == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
@ -379,7 +367,7 @@ class CompletionConversationApi(Resource):
query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
return conversations
@ -431,7 +419,31 @@ class ChatConversationApi(Resource):
@console_ns.doc("list_chat_conversations")
@console_ns.doc(description="Get chat conversations with pagination, filtering and summary")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ChatConversationQuery.__name__])
@console_ns.expect(
console_ns.parser()
.add_argument("keyword", type=str, location="args", help="Search keyword")
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
.add_argument(
"annotation_status",
type=str,
location="args",
choices=["annotated", "not_annotated", "all"],
default="all",
help="Annotation status filter",
)
.add_argument("message_count_gte", type=int, location="args", help="Minimum message count")
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
.add_argument(
"sort_by",
type=str,
location="args",
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
default="-updated_at",
help="Sort field and direction",
)
)
@console_ns.response(200, "Success", conversation_with_summary_pagination_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -442,7 +454,31 @@ class ChatConversationApi(Resource):
@edit_permission_required
def get(self, app_model):
current_user, _ = current_account_with_tenant()
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument("keyword", type=str, location="args")
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument(
"annotation_status",
type=str,
choices=["annotated", "not_annotated", "all"],
default="all",
location="args",
)
.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument(
"sort_by",
type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
)
)
args = parser.parse_args()
subquery = (
db.session.query(
@ -454,8 +490,8 @@ class ChatConversationApi(Resource):
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
if args.keyword:
keyword_filter = f"%{args.keyword}%"
if args["keyword"]:
keyword_filter = f"%{args['keyword']}%"
query = (
query.join(
Message,
@ -478,12 +514,12 @@ class ChatConversationApi(Resource):
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
match args.sort_by:
match args["sort_by"]:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at >= start_datetime_utc)
case "created_at" | "-created_at" | _:
@ -491,27 +527,35 @@ class ChatConversationApi(Resource):
if end_datetime_utc:
end_datetime_utc = end_datetime_utc.replace(second=59)
match args.sort_by:
match args["sort_by"]:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at <= end_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at <= end_datetime_utc)
if args.annotation_status == "annotated":
if args["annotation_status"] == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args.annotation_status == "not_annotated":
elif args["annotation_status"] == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0)
)
if args["message_count_gte"] and args["message_count_gte"] >= 1:
query = (
query.options(joinedload(Conversation.messages)) # type: ignore
.join(Message, Message.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(Message.id) >= args["message_count_gte"])
)
if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
match args.sort_by:
match args["sort_by"]:
case "created_at":
query = query.order_by(Conversation.created_at.asc())
case "-created_at":
@ -523,7 +567,7 @@ class ChatConversationApi(Resource):
case _:
query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
return conversations

View File

@ -1,6 +1,4 @@
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource, fields, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -16,18 +14,6 @@ from libs.login import login_required
from models import ConversationVariable
from models.model import AppMode
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ConversationVariablesQuery(BaseModel):
conversation_id: str = Field(..., description="Conversation ID to filter variables")
console_ns.schema_model(
ConversationVariablesQuery.__name__,
ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base model first
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
@ -47,7 +33,11 @@ class ConversationVariablesApi(Resource):
@console_ns.doc("get_conversation_variables")
@console_ns.doc(description="Get conversation variables for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
@console_ns.expect(
console_ns.parser().add_argument(
"conversation_id", type=str, location="args", help="Conversation ID to filter variables"
)
)
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
@setup_required
@login_required
@ -55,14 +45,18 @@ class ConversationVariablesApi(Resource):
@get_app_model(mode=AppMode.ADVANCED_CHAT)
@marshal_with(paginated_conversation_variable_model)
def get(self, app_model):
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = reqparse.RequestParser().add_argument("conversation_id", type=str, location="args")
args = parser.parse_args()
stmt = (
select(ConversationVariable)
.where(ConversationVariable.app_id == app_model.id)
.order_by(ConversationVariable.created_at)
)
stmt = stmt.where(ConversationVariable.conversation_id == args.conversation_id)
if args["conversation_id"]:
stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"])
else:
raise ValueError("conversation_id is required")
# NOTE: This is a temporary solution to avoid performance issues.
page = 1

View File

@ -1,8 +1,6 @@
from collections.abc import Sequence
from typing import Any
from flask_restx import Resource
from pydantic import BaseModel, Field
from flask_restx import Resource, fields, reqparse
from controllers.console import console_ns
from controllers.console.app.error import (
@ -23,54 +21,21 @@ from libs.login import current_account_with_tenant, login_required
from models import App
from services.workflow_service import WorkflowService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class RuleGeneratePayload(BaseModel):
instruction: str = Field(..., description="Rule generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
no_variable: bool = Field(default=False, description="Whether to exclude variables")
class RuleCodeGeneratePayload(RuleGeneratePayload):
code_language: str = Field(default="javascript", description="Programming language for code generation")
class RuleStructuredOutputPayload(BaseModel):
instruction: str = Field(..., description="Structured output generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
class InstructionGeneratePayload(BaseModel):
flow_id: str = Field(..., description="Workflow/Flow ID")
node_id: str = Field(default="", description="Node ID for workflow context")
current: str = Field(default="", description="Current instruction text")
language: str = Field(default="javascript", description="Programming language (javascript/python)")
instruction: str = Field(..., description="Instruction for generation")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
ideal_output: str = Field(default="", description="Expected ideal output")
class InstructionTemplatePayload(BaseModel):
type: str = Field(..., description="Instruction template type")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(RuleGeneratePayload)
reg(RuleCodeGeneratePayload)
reg(RuleStructuredOutputPayload)
reg(InstructionGeneratePayload)
reg(InstructionTemplatePayload)
@console_ns.route("/rule-generate")
class RuleGenerateApi(Resource):
@console_ns.doc("generate_rule_config")
@console_ns.doc(description="Generate rule configuration using LLM")
@console_ns.expect(console_ns.models[RuleGeneratePayload.__name__])
@console_ns.expect(
console_ns.model(
"RuleGenerateRequest",
{
"instruction": fields.String(required=True, description="Rule generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
},
)
)
@console_ns.response(200, "Rule configuration generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(402, "Provider quota exceeded")
@ -78,15 +43,21 @@ class RuleGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
args = RuleGeneratePayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
try:
rules = LLMGenerator.generate_rule_config(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=args.no_variable,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=args["no_variable"],
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -104,7 +75,19 @@ class RuleGenerateApi(Resource):
class RuleCodeGenerateApi(Resource):
@console_ns.doc("generate_rule_code")
@console_ns.doc(description="Generate code rules using LLM")
@console_ns.expect(console_ns.models[RuleCodeGeneratePayload.__name__])
@console_ns.expect(
console_ns.model(
"RuleCodeGenerateRequest",
{
"instruction": fields.String(required=True, description="Code generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
"code_language": fields.String(
default="javascript", description="Programming language for code generation"
),
},
)
)
@console_ns.response(200, "Code rules generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(402, "Provider quota exceeded")
@ -112,15 +95,22 @@ class RuleCodeGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
.add_argument("code_language", type=str, required=False, default="javascript", location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
try:
code_result = LLMGenerator.generate_code(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.code_language,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["code_language"],
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -138,7 +128,15 @@ class RuleCodeGenerateApi(Resource):
class RuleStructuredOutputGenerateApi(Resource):
@console_ns.doc("generate_structured_output")
@console_ns.doc(description="Generate structured output rules using LLM")
@console_ns.expect(console_ns.models[RuleStructuredOutputPayload.__name__])
@console_ns.expect(
console_ns.model(
"StructuredOutputGenerateRequest",
{
"instruction": fields.String(required=True, description="Structured output generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
},
)
)
@console_ns.response(200, "Structured output generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(402, "Provider quota exceeded")
@ -146,14 +144,19 @@ class RuleStructuredOutputGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
try:
structured_output = LLMGenerator.generate_structured_output(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
instruction=args["instruction"],
model_config=args["model_config"],
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -171,7 +174,20 @@ class RuleStructuredOutputGenerateApi(Resource):
class InstructionGenerateApi(Resource):
@console_ns.doc("generate_instruction")
@console_ns.doc(description="Generate instruction for workflow nodes or general use")
@console_ns.expect(console_ns.models[InstructionGeneratePayload.__name__])
@console_ns.expect(
console_ns.model(
"InstructionGenerateRequest",
{
"flow_id": fields.String(required=True, description="Workflow/Flow ID"),
"node_id": fields.String(description="Node ID for workflow context"),
"current": fields.String(description="Current instruction text"),
"language": fields.String(default="javascript", description="Programming language (javascript/python)"),
"instruction": fields.String(required=True, description="Instruction for generation"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"ideal_output": fields.String(description="Expected ideal output"),
},
)
)
@console_ns.response(200, "Instruction generated successfully")
@console_ns.response(400, "Invalid request parameters or flow/workflow not found")
@console_ns.response(402, "Provider quota exceeded")
@ -179,69 +195,79 @@ class InstructionGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
args = InstructionGeneratePayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("flow_id", type=str, required=True, default="", location="json")
.add_argument("node_id", type=str, required=False, default="", location="json")
.add_argument("current", type=str, required=False, default="", location="json")
.add_argument("language", type=str, required=False, default="javascript", location="json")
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("ideal_output", type=str, required=False, default="", location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] | None = next(
(p for p in providers if p.is_accept_language(args.language)), None
(p for p in providers if p.is_accept_language(args["language"])), None
)
code_template = code_provider.get_default_code() if code_provider else ""
try:
# Generate from nothing for a workflow node
if (args.current in (code_template, "")) and args.node_id != "":
app = db.session.query(App).where(App.id == args.flow_id).first()
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
app = db.session.query(App).where(App.id == args["flow_id"]).first()
if not app:
return {"error": f"app {args.flow_id} not found"}, 400
return {"error": f"app {args['flow_id']} not found"}, 400
workflow = WorkflowService().get_draft_workflow(app_model=app)
if not workflow:
return {"error": f"workflow {args.flow_id} not found"}, 400
return {"error": f"workflow {args['flow_id']} not found"}, 400
nodes: Sequence = workflow.graph_dict["nodes"]
node = [node for node in nodes if node["id"] == args.node_id]
node = [node for node in nodes if node["id"] == args["node_id"]]
if len(node) == 0:
return {"error": f"node {args.node_id} not found"}, 400
return {"error": f"node {args['node_id']} not found"}, 400
node_type = node[0]["data"]["type"]
match node_type:
case "llm":
return LLMGenerator.generate_rule_config(
current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=True,
)
case "agent":
return LLMGenerator.generate_rule_config(
current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=True,
)
case "code":
return LLMGenerator.generate_code(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.language,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["language"],
)
case _:
return {"error": f"invalid node type: {node_type}"}
if args.node_id == "" and args.current != "": # For legacy app without a workflow
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
return LLMGenerator.instruction_modify_legacy(
tenant_id=current_tenant_id,
flow_id=args.flow_id,
current=args.current,
instruction=args.instruction,
model_config=args.model_config_data,
ideal_output=args.ideal_output,
flow_id=args["flow_id"],
current=args["current"],
instruction=args["instruction"],
model_config=args["model_config"],
ideal_output=args["ideal_output"],
)
if args.node_id != "" and args.current != "": # For workflow node
if args["node_id"] != "" and args["current"] != "": # For workflow node
return LLMGenerator.instruction_modify_workflow(
tenant_id=current_tenant_id,
flow_id=args.flow_id,
node_id=args.node_id,
current=args.current,
instruction=args.instruction,
model_config=args.model_config_data,
ideal_output=args.ideal_output,
flow_id=args["flow_id"],
node_id=args["node_id"],
current=args["current"],
instruction=args["instruction"],
model_config=args["model_config"],
ideal_output=args["ideal_output"],
workflow_service=WorkflowService(),
)
return {"error": "incompatible parameters"}, 400
@ -259,15 +285,24 @@ class InstructionGenerateApi(Resource):
class InstructionGenerationTemplateApi(Resource):
@console_ns.doc("get_instruction_template")
@console_ns.doc(description="Get instruction generation template")
@console_ns.expect(console_ns.models[InstructionTemplatePayload.__name__])
@console_ns.expect(
console_ns.model(
"InstructionTemplateRequest",
{
"instruction": fields.String(required=True, description="Template instruction"),
"ideal_output": fields.String(description="Expected ideal output"),
},
)
)
@console_ns.response(200, "Template retrieved successfully")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
def post(self):
args = InstructionTemplatePayload.model_validate(console_ns.payload)
match args.type:
parser = reqparse.RequestParser().add_argument("type", type=str, required=True, default=False, location="json")
args = parser.parse_args()
match args["type"]:
case "prompt":
from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT
@ -277,4 +312,4 @@ class InstructionGenerationTemplateApi(Resource):
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
case _:
raise ValueError(f"Invalid type: {args.type}")
raise ValueError(f"Invalid type: {args['type']}")

View File

@ -1,9 +1,7 @@
import logging
from typing import Literal
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy import exists, select
from werkzeug.exceptions import InternalServerError, NotFound
@ -35,67 +33,6 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft
from services.message_service import MessageService
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ChatMessagesQuery(BaseModel):
conversation_id: str = Field(..., description="Conversation ID")
first_id: str | None = Field(default=None, description="First message ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
@field_validator("first_id", mode="before")
@classmethod
def empty_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
@field_validator("conversation_id", "first_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class MessageFeedbackPayload(BaseModel):
message_id: str = Field(..., description="Message ID")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str) -> str:
return uuid_value(value)
class FeedbackExportQuery(BaseModel):
from_source: Literal["user", "admin"] | None = Field(default=None, description="Filter by feedback source")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Filter by rating")
has_comment: bool | None = Field(default=None, description="Only include feedback with comments")
start_date: str | None = Field(default=None, description="Start date (YYYY-MM-DD)")
end_date: str | None = Field(default=None, description="End date (YYYY-MM-DD)")
format: Literal["csv", "json"] = Field(default="csv", description="Export format")
@field_validator("has_comment", mode="before")
@classmethod
def parse_bool(cls, value: bool | str | None) -> bool | None:
if isinstance(value, bool) or value is None:
return value
lowered = value.lower()
if lowered in {"true", "1", "yes", "on"}:
return True
if lowered in {"false", "0", "no", "off"}:
return False
raise ValueError("has_comment must be a boolean value")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(ChatMessagesQuery)
reg(MessageFeedbackPayload)
reg(FeedbackExportQuery)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@ -220,7 +157,12 @@ class ChatMessageListApi(Resource):
@console_ns.doc("list_chat_messages")
@console_ns.doc(description="Get chat messages for a conversation with pagination")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
@console_ns.expect(
console_ns.parser()
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID")
.add_argument("first_id", type=str, location="args", help="First message ID for pagination")
.add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)")
)
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
@console_ns.response(404, "Conversation not found")
@login_required
@ -230,21 +172,27 @@ class ChatMessageListApi(Resource):
@marshal_with(message_infinite_scroll_pagination_model)
@edit_permission_required
def get(self, app_model):
args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
.add_argument("first_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
conversation = (
db.session.query(Conversation)
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
.where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
.first()
)
if not conversation:
raise NotFound("Conversation Not Exists.")
if args.first_id:
if args["first_id"]:
first_message = (
db.session.query(Message)
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
.where(Message.conversation_id == conversation.id, Message.id == args["first_id"])
.first()
)
@ -259,7 +207,7 @@ class ChatMessageListApi(Resource):
Message.id != first_message.id,
)
.order_by(Message.created_at.desc())
.limit(args.limit)
.limit(args["limit"])
.all()
)
else:
@ -267,12 +215,12 @@ class ChatMessageListApi(Resource):
db.session.query(Message)
.where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(args.limit)
.limit(args["limit"])
.all()
)
# Initialize has_more based on whether we have a full page
if len(history_messages) == args.limit:
if len(history_messages) == args["limit"]:
current_page_first_message = history_messages[-1]
# Check if there are more messages before the current page
has_more = db.session.scalar(
@ -290,7 +238,7 @@ class ChatMessageListApi(Resource):
history_messages = list(reversed(history_messages))
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
@ -298,7 +246,15 @@ class MessageFeedbackApi(Resource):
@console_ns.doc("create_message_feedback")
@console_ns.doc(description="Create or update message feedback (like/dislike)")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
@console_ns.expect(
console_ns.model(
"MessageFeedbackRequest",
{
"message_id": fields.String(required=True, description="Message ID"),
"rating": fields.String(enum=["like", "dislike"], description="Feedback rating"),
},
)
)
@console_ns.response(200, "Feedback updated successfully")
@console_ns.response(404, "Message not found")
@console_ns.response(403, "Insufficient permissions")
@ -309,9 +265,14 @@ class MessageFeedbackApi(Resource):
def post(self, app_model):
current_user, _ = current_account_with_tenant()
args = MessageFeedbackPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("message_id", required=True, type=uuid_value, location="json")
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
)
args = parser.parse_args()
message_id = str(args.message_id)
message_id = str(args["message_id"])
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
@ -320,21 +281,18 @@ class MessageFeedbackApi(Resource):
feedback = message.admin_feedback
if not args.rating and feedback:
if not args["rating"] and feedback:
db.session.delete(feedback)
elif args.rating and feedback:
feedback.rating = args.rating
elif not args.rating and not feedback:
elif args["rating"] and feedback:
feedback.rating = args["rating"]
elif not args["rating"] and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
else:
rating_value = args.rating
if rating_value is None:
raise ValueError("rating is required to create feedback")
feedback = MessageFeedback(
app_id=app_model.id,
conversation_id=message.conversation_id,
message_id=message.id,
rating=rating_value,
rating=args["rating"],
from_source="admin",
from_account_id=current_user.id,
)
@ -411,12 +369,24 @@ class MessageSuggestedQuestionApi(Resource):
return {"data": questions}
# Shared parser for feedback export (used for both documentation and runtime parsing)
feedback_export_parser = (
console_ns.parser()
.add_argument("from_source", type=str, choices=["user", "admin"], location="args", help="Filter by feedback source")
.add_argument("rating", type=str, choices=["like", "dislike"], location="args", help="Filter by rating")
.add_argument("has_comment", type=bool, location="args", help="Only include feedback with comments")
.add_argument("start_date", type=str, location="args", help="Start date (YYYY-MM-DD)")
.add_argument("end_date", type=str, location="args", help="End date (YYYY-MM-DD)")
.add_argument("format", type=str, choices=["csv", "json"], default="csv", location="args", help="Export format")
)
@console_ns.route("/apps/<uuid:app_id>/feedbacks/export")
class MessageFeedbackExportApi(Resource):
@console_ns.doc("export_feedbacks")
@console_ns.doc(description="Export user feedback data for Google Sheets")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[FeedbackExportQuery.__name__])
@console_ns.expect(feedback_export_parser)
@console_ns.response(200, "Feedback data exported successfully")
@console_ns.response(400, "Invalid parameters")
@console_ns.response(500, "Internal server error")
@ -425,7 +395,7 @@ class MessageFeedbackExportApi(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = feedback_export_parser.parse_args()
# Import the service function
from services.feedback_service import FeedbackService
@ -433,12 +403,12 @@ class MessageFeedbackExportApi(Resource):
try:
export_data = FeedbackService.export_feedbacks(
app_id=app_model.id,
from_source=args.from_source,
rating=args.rating,
has_comment=args.has_comment,
start_date=args.start_date,
end_date=args.end_date,
format_type=args.format,
from_source=args.get("from_source"),
rating=args.get("rating"),
has_comment=args.get("has_comment"),
start_date=args.get("start_date"),
end_date=args.get("end_date"),
format_type=args.get("format", "csv"),
)
return export_data

View File

@ -1,9 +1,8 @@
from decimal import Decimal
import sqlalchemy as sa
from flask import abort, jsonify, request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from flask import abort, jsonify
from flask_restx import Resource, fields, reqparse
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
@ -11,37 +10,21 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import convert_datetime_to_date
from libs.helper import DatetimeString, convert_datetime_to_date
from libs.login import current_account_with_tenant, login_required
from models import AppMode
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class StatisticTimeRangeQuery(BaseModel):
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
@field_validator("start", "end", mode="before")
@classmethod
def empty_string_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
console_ns.schema_model(
StatisticTimeRangeQuery.__name__,
StatisticTimeRangeQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
class DailyMessageStatistic(Resource):
@console_ns.doc("get_daily_message_statistics")
@console_ns.doc(description="Get daily message statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.expect(
console_ns.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@console_ns.response(
200,
"Daily message statistics retrieved successfully",
@ -54,7 +37,12 @@ class DailyMessageStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@ -69,7 +57,7 @@ WHERE
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -93,12 +81,19 @@ WHERE
return jsonify({"data": response_data})
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
class DailyConversationStatistic(Resource):
@console_ns.doc("get_daily_conversation_statistics")
@console_ns.doc(description="Get daily conversation statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.expect(parser)
@console_ns.response(
200,
"Daily conversation statistics retrieved successfully",
@ -111,7 +106,7 @@ class DailyConversationStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@ -126,7 +121,7 @@ WHERE
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -154,7 +149,7 @@ class DailyTerminalsStatistic(Resource):
@console_ns.doc("get_daily_terminals_statistics")
@console_ns.doc(description="Get daily terminal/end-user statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.expect(parser)
@console_ns.response(
200,
"Daily terminal statistics retrieved successfully",
@ -167,7 +162,7 @@ class DailyTerminalsStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@ -182,7 +177,7 @@ WHERE
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -211,7 +206,7 @@ class DailyTokenCostStatistic(Resource):
@console_ns.doc("get_daily_token_cost_statistics")
@console_ns.doc(description="Get daily token cost statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.expect(parser)
@console_ns.response(
200,
"Daily token cost statistics retrieved successfully",
@ -224,7 +219,7 @@ class DailyTokenCostStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@ -240,7 +235,7 @@ WHERE
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -271,7 +266,7 @@ class AverageSessionInteractionStatistic(Resource):
@console_ns.doc("get_average_session_interaction_statistics")
@console_ns.doc(description="Get average session interaction statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.expect(parser)
@console_ns.response(
200,
"Average session interaction statistics retrieved successfully",
@ -284,7 +279,7 @@ class AverageSessionInteractionStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
converted_created_at = convert_datetime_to_date("c.created_at")
sql_query = f"""SELECT
@ -307,7 +302,7 @@ FROM
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -347,7 +342,7 @@ class UserSatisfactionRateStatistic(Resource):
@console_ns.doc("get_user_satisfaction_rate_statistics")
@console_ns.doc(description="Get user satisfaction rate statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.expect(parser)
@console_ns.response(
200,
"User satisfaction rate statistics retrieved successfully",
@ -360,7 +355,7 @@ class UserSatisfactionRateStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
converted_created_at = convert_datetime_to_date("m.created_at")
sql_query = f"""SELECT
@ -379,7 +374,7 @@ WHERE
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -413,7 +408,7 @@ class AverageResponseTimeStatistic(Resource):
@console_ns.doc("get_average_response_time_statistics")
@console_ns.doc(description="Get average response time statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.expect(parser)
@console_ns.response(
200,
"Average response time statistics retrieved successfully",
@ -426,7 +421,7 @@ class AverageResponseTimeStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@ -441,7 +436,7 @@ WHERE
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -470,7 +465,7 @@ class TokensPerSecondStatistic(Resource):
@console_ns.doc("get_tokens_per_second_statistics")
@console_ns.doc(description="Get tokens per second statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.expect(parser)
@console_ns.response(
200,
"Tokens per second statistics retrieved successfully",
@ -482,7 +477,7 @@ class TokensPerSecondStatistic(Resource):
@account_initialization_required
def get(self, app_model):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@ -500,7 +495,7 @@ WHERE
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))

View File

@ -1,11 +1,10 @@
import json
import logging
from collections.abc import Sequence
from typing import Any
from typing import cast
from flask import abort, request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, fields, inputs, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -50,7 +49,6 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
logger = logging.getLogger(__name__)
LISTENING_RETRY_IN = 2000
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@ -109,104 +107,6 @@ if workflow_run_node_execution_model is None:
workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields)
class SyncDraftWorkflowPayload(BaseModel):
graph: dict[str, Any]
features: dict[str, Any]
hash: str | None = None
environment_variables: list[dict[str, Any]] = Field(default_factory=list)
conversation_variables: list[dict[str, Any]] = Field(default_factory=list)
class BaseWorkflowRunPayload(BaseModel):
files: list[dict[str, Any]] | None = None
class AdvancedChatWorkflowRunPayload(BaseWorkflowRunPayload):
inputs: dict[str, Any] | None = None
query: str = ""
conversation_id: str | None = None
parent_message_id: str | None = None
@field_validator("conversation_id", "parent_message_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class IterationNodeRunPayload(BaseModel):
inputs: dict[str, Any] | None = None
class LoopNodeRunPayload(BaseModel):
inputs: dict[str, Any] | None = None
class DraftWorkflowRunPayload(BaseWorkflowRunPayload):
inputs: dict[str, Any]
class DraftWorkflowNodeRunPayload(BaseWorkflowRunPayload):
inputs: dict[str, Any]
query: str = ""
class PublishWorkflowPayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class DefaultBlockConfigQuery(BaseModel):
q: str | None = None
class ConvertToWorkflowPayload(BaseModel):
name: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class DraftWorkflowTriggerRunPayload(BaseModel):
node_id: str
class DraftWorkflowTriggerRunAllPayload(BaseModel):
node_ids: list[str]
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(SyncDraftWorkflowPayload)
reg(AdvancedChatWorkflowRunPayload)
reg(IterationNodeRunPayload)
reg(LoopNodeRunPayload)
reg(DraftWorkflowRunPayload)
reg(DraftWorkflowNodeRunPayload)
reg(PublishWorkflowPayload)
reg(DefaultBlockConfigQuery)
reg(ConvertToWorkflowPayload)
reg(WorkflowListQuery)
reg(WorkflowUpdatePayload)
reg(DraftWorkflowTriggerRunPayload)
reg(DraftWorkflowTriggerRunAllPayload)
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
# at the controller level rather than in the workflow logic. This would improve separation
# of concerns and make the code more maintainable.
@ -258,7 +158,18 @@ class DraftWorkflowApi(Resource):
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@console_ns.doc("sync_draft_workflow")
@console_ns.doc(description="Sync draft workflow configuration")
@console_ns.expect(console_ns.models[SyncDraftWorkflowPayload.__name__])
@console_ns.expect(
console_ns.model(
"SyncDraftWorkflowRequest",
{
"graph": fields.Raw(required=True, description="Workflow graph configuration"),
"features": fields.Raw(required=True, description="Workflow features configuration"),
"hash": fields.String(description="Workflow hash for validation"),
"environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"),
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
},
)
)
@console_ns.response(
200,
"Draft workflow synced successfully",
@ -282,23 +193,36 @@ class DraftWorkflowApi(Resource):
content_type = request.headers.get("Content-Type", "")
payload_data: dict[str, Any] | None = None
if "application/json" in content_type:
payload_data = request.get_json(silent=True)
if not isinstance(payload_data, dict):
return {"message": "Invalid JSON data"}, 400
parser = (
reqparse.RequestParser()
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
.add_argument("features", type=dict, required=True, nullable=False, location="json")
.add_argument("hash", type=str, required=False, location="json")
.add_argument("environment_variables", type=list, required=True, location="json")
.add_argument("conversation_variables", type=list, required=False, location="json")
)
args = parser.parse_args()
elif "text/plain" in content_type:
try:
payload_data = json.loads(request.data.decode("utf-8"))
data = json.loads(request.data.decode("utf-8"))
if "graph" not in data or "features" not in data:
raise ValueError("graph or features not found in data")
if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
raise ValueError("graph or features is not a dict")
args = {
"graph": data.get("graph"),
"features": data.get("features"),
"hash": data.get("hash"),
"environment_variables": data.get("environment_variables"),
"conversation_variables": data.get("conversation_variables"),
}
except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400
if not isinstance(payload_data, dict):
return {"message": "Invalid JSON data"}, 400
else:
abort(415)
args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
args = args_model.model_dump()
workflow_service = WorkflowService()
try:
@ -334,7 +258,17 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
@console_ns.doc("run_advanced_chat_draft_workflow")
@console_ns.doc(description="Run draft workflow for advanced chat application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AdvancedChatWorkflowRunPayload.__name__])
@console_ns.expect(
console_ns.model(
"AdvancedChatWorkflowRunRequest",
{
"query": fields.String(required=True, description="User query"),
"inputs": fields.Raw(description="Input variables"),
"files": fields.List(fields.Raw, description="File uploads"),
"conversation_id": fields.String(description="Conversation ID"),
},
)
)
@console_ns.response(200, "Workflow run started successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(403, "Permission denied")
@ -349,8 +283,16 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
"""
current_user, _ = current_account_with_tenant()
args_model = AdvancedChatWorkflowRunPayload.model_validate(console_ns.payload or {})
args = args_model.model_dump(exclude_none=True)
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, location="json")
.add_argument("query", type=str, required=True, location="json", default="")
.add_argument("files", type=list, location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
)
args = parser.parse_args()
external_trace_id = get_external_trace_id(request)
if external_trace_id:
@ -380,7 +322,15 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
@console_ns.doc("run_advanced_chat_draft_iteration_node")
@console_ns.doc(description="Run draft workflow iteration node for advanced chat")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__])
@console_ns.expect(
console_ns.model(
"IterationNodeRunRequest",
{
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.response(200, "Iteration node run started successfully")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found")
@ -394,7 +344,8 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
Run draft workflow iteration node
"""
current_user, _ = current_account_with_tenant()
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try:
response = AppGenerateService.generate_single_iteration(
@ -418,7 +369,15 @@ class WorkflowDraftRunIterationNodeApi(Resource):
@console_ns.doc("run_workflow_draft_iteration_node")
@console_ns.doc(description="Run draft workflow iteration node")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__])
@console_ns.expect(
console_ns.model(
"WorkflowIterationNodeRunRequest",
{
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.response(200, "Workflow iteration node run started successfully")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found")
@ -432,7 +391,8 @@ class WorkflowDraftRunIterationNodeApi(Resource):
Run draft workflow iteration node
"""
current_user, _ = current_account_with_tenant()
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try:
response = AppGenerateService.generate_single_iteration(
@ -456,7 +416,15 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
@console_ns.doc("run_advanced_chat_draft_loop_node")
@console_ns.doc(description="Run draft workflow loop node for advanced chat")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__])
@console_ns.expect(
console_ns.model(
"LoopNodeRunRequest",
{
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.response(200, "Loop node run started successfully")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found")
@ -470,7 +438,8 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
Run draft workflow loop node
"""
current_user, _ = current_account_with_tenant()
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try:
response = AppGenerateService.generate_single_loop(
@ -494,7 +463,15 @@ class WorkflowDraftRunLoopNodeApi(Resource):
@console_ns.doc("run_workflow_draft_loop_node")
@console_ns.doc(description="Run draft workflow loop node")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__])
@console_ns.expect(
console_ns.model(
"WorkflowLoopNodeRunRequest",
{
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.response(200, "Workflow loop node run started successfully")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found")
@ -508,7 +485,8 @@ class WorkflowDraftRunLoopNodeApi(Resource):
Run draft workflow loop node
"""
current_user, _ = current_account_with_tenant()
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try:
response = AppGenerateService.generate_single_loop(
@ -532,7 +510,15 @@ class DraftWorkflowRunApi(Resource):
@console_ns.doc("run_draft_workflow")
@console_ns.doc(description="Run draft workflow")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__])
@console_ns.expect(
console_ns.model(
"DraftWorkflowRunRequest",
{
"inputs": fields.Raw(required=True, description="Input variables"),
"files": fields.List(fields.Raw, description="File uploads"),
},
)
)
@console_ns.response(200, "Draft workflow run started successfully")
@console_ns.response(403, "Permission denied")
@setup_required
@ -545,7 +531,12 @@ class DraftWorkflowRunApi(Resource):
Run draft workflow
"""
current_user, _ = current_account_with_tenant()
args = DraftWorkflowRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("files", type=list, required=False, location="json")
)
args = parser.parse_args()
external_trace_id = get_external_trace_id(request)
if external_trace_id:
@ -597,7 +588,14 @@ class DraftWorkflowNodeRunApi(Resource):
@console_ns.doc("run_draft_workflow_node")
@console_ns.doc(description="Run draft workflow node")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[DraftWorkflowNodeRunPayload.__name__])
@console_ns.expect(
console_ns.model(
"DraftWorkflowNodeRunRequest",
{
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.response(200, "Node run started successfully", workflow_run_node_execution_model)
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found")
@ -612,10 +610,15 @@ class DraftWorkflowNodeRunApi(Resource):
Run draft workflow node
"""
current_user, _ = current_account_with_tenant()
args_model = DraftWorkflowNodeRunPayload.model_validate(console_ns.payload or {})
args = args_model.model_dump(exclude_none=True)
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("query", type=str, required=False, location="json", default="")
.add_argument("files", type=list, location="json", default=[])
)
args = parser.parse_args()
user_inputs = args_model.inputs
user_inputs = args.get("inputs")
if user_inputs is None:
raise ValueError("missing inputs")
@ -640,6 +643,13 @@ class DraftWorkflowNodeRunApi(Resource):
return workflow_node_execution
parser_publish = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, default="", location="json")
.add_argument("marked_comment", type=str, required=False, default="", location="json")
)
@console_ns.route("/apps/<uuid:app_id>/workflows/publish")
class PublishedWorkflowApi(Resource):
@console_ns.doc("get_published_workflow")
@ -664,7 +674,7 @@ class PublishedWorkflowApi(Resource):
# return workflow, if not found, return None
return workflow
@console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__])
@console_ns.expect(parser_publish)
@setup_required
@login_required
@account_initialization_required
@ -676,7 +686,13 @@ class PublishedWorkflowApi(Resource):
"""
current_user, _ = current_account_with_tenant()
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
args = parser_publish.parse_args()
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
workflow_service = WorkflowService()
with Session(db.engine) as session:
@ -725,6 +741,9 @@ class DefaultBlockConfigsApi(Resource):
return workflow_service.get_default_block_configs()
parser_block = reqparse.RequestParser().add_argument("q", type=str, location="args")
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultBlockConfigApi(Resource):
@console_ns.doc("get_default_block_config")
@ -732,7 +751,7 @@ class DefaultBlockConfigApi(Resource):
@console_ns.doc(params={"app_id": "Application ID", "block_type": "Block type"})
@console_ns.response(200, "Default block configuration retrieved successfully")
@console_ns.response(404, "Block type not found")
@console_ns.expect(console_ns.models[DefaultBlockConfigQuery.__name__])
@console_ns.expect(parser_block)
@setup_required
@login_required
@account_initialization_required
@ -742,12 +761,14 @@ class DefaultBlockConfigApi(Resource):
"""
Get default block config
"""
args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser_block.parse_args()
q = args.get("q")
filters = None
if args.q:
if q:
try:
filters = json.loads(args.q)
filters = json.loads(args.get("q", ""))
except json.JSONDecodeError:
raise ValueError("Invalid filters")
@ -756,9 +777,18 @@ class DefaultBlockConfigApi(Resource):
return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
parser_convert = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, nullable=True, location="json")
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
.add_argument("icon", type=str, required=False, nullable=True, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
)
@console_ns.route("/apps/<uuid:app_id>/convert-to-workflow")
class ConvertToWorkflowApi(Resource):
@console_ns.expect(console_ns.models[ConvertToWorkflowPayload.__name__])
@console_ns.expect(parser_convert)
@console_ns.doc("convert_to_workflow")
@console_ns.doc(description="Convert application to workflow mode")
@console_ns.doc(params={"app_id": "Application ID"})
@ -778,8 +808,10 @@ class ConvertToWorkflowApi(Resource):
"""
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
args = ConvertToWorkflowPayload.model_validate(payload).model_dump(exclude_none=True)
if request.data:
args = parser_convert.parse_args()
else:
args = {}
# convert to workflow mode
workflow_service = WorkflowService()
@ -791,9 +823,18 @@ class ConvertToWorkflowApi(Resource):
}
parser_workflows = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
@console_ns.route("/apps/<uuid:app_id>/workflows")
class PublishedAllWorkflowApi(Resource):
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
@console_ns.expect(parser_workflows)
@console_ns.doc("get_all_published_workflows")
@console_ns.doc(description="Get all published workflows for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@ -810,15 +851,16 @@ class PublishedAllWorkflowApi(Resource):
"""
current_user, _ = current_account_with_tenant()
args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
page = args.page
limit = args.limit
user_id = args.user_id
named_only = args.named_only
args = parser_workflows.parse_args()
page = args["page"]
limit = args["limit"]
user_id = args.get("user_id")
named_only = args.get("named_only", False)
if user_id:
if user_id != current_user.id:
raise Forbidden()
user_id = cast(str, user_id)
workflow_service = WorkflowService()
with Session(db.engine) as session:
@ -844,7 +886,15 @@ class WorkflowByIdApi(Resource):
@console_ns.doc("update_workflow_by_id")
@console_ns.doc(description="Update workflow by ID")
@console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"})
@console_ns.expect(console_ns.models[WorkflowUpdatePayload.__name__])
@console_ns.expect(
console_ns.model(
"UpdateWorkflowRequest",
{
"environment_variables": fields.List(fields.Raw, description="Environment variables"),
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
},
)
)
@console_ns.response(200, "Workflow updated successfully", workflow_model)
@console_ns.response(404, "Workflow not found")
@console_ns.response(403, "Permission denied")
@ -859,14 +909,25 @@ class WorkflowByIdApi(Resource):
Update workflow attributes
"""
current_user, _ = current_account_with_tenant()
args = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
parser = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, location="json")
.add_argument("marked_comment", type=str, required=False, location="json")
)
args = parser.parse_args()
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
# Prepare update data
update_data = {}
if args.marked_name is not None:
update_data["marked_name"] = args.marked_name
if args.marked_comment is not None:
update_data["marked_comment"] = args.marked_comment
if args.get("marked_name") is not None:
update_data["marked_name"] = args["marked_name"]
if args.get("marked_comment") is not None:
update_data["marked_comment"] = args["marked_comment"]
if not update_data:
return {"message": "No valid fields to update"}, 400
@ -979,8 +1040,11 @@ class DraftWorkflowTriggerRunApi(Resource):
Poll for trigger events and execute full workflow when event arrives
"""
current_user, _ = current_account_with_tenant()
args = DraftWorkflowTriggerRunPayload.model_validate(console_ns.payload or {})
node_id = args.node_id
parser = reqparse.RequestParser().add_argument(
"node_id", type=str, required=True, location="json", nullable=False
)
args = parser.parse_args()
node_id = args["node_id"]
workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model)
if not draft_workflow:
@ -1108,7 +1172,14 @@ class DraftWorkflowTriggerRunAllApi(Resource):
@console_ns.doc("draft_workflow_trigger_run_all")
@console_ns.doc(description="Full workflow debug when the start node is a trigger")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[DraftWorkflowTriggerRunAllPayload.__name__])
@console_ns.expect(
console_ns.model(
"DraftWorkflowTriggerRunAllRequest",
{
"node_ids": fields.List(fields.String, required=True, description="Node IDs"),
},
)
)
@console_ns.response(200, "Workflow executed successfully")
@console_ns.response(403, "Permission denied")
@console_ns.response(500, "Internal server error")
@ -1123,8 +1194,11 @@ class DraftWorkflowTriggerRunAllApi(Resource):
"""
current_user, _ = current_account_with_tenant()
args = DraftWorkflowTriggerRunAllPayload.model_validate(console_ns.payload or {})
node_ids = args.node_ids
parser = reqparse.RequestParser().add_argument(
"node_ids", type=list, required=True, location="json", nullable=False
)
args = parser.parse_args()
node_ids = args["node_ids"]
workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model)
if not draft_workflow:

View File

@ -1,9 +1,6 @@
from datetime import datetime
from dateutil.parser import isoparse
from flask import request
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy.orm import Session
from controllers.console import console_ns
@ -17,48 +14,6 @@ from models import App
from models.model import AppMode
from services.workflow_app_service import WorkflowAppService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowAppLogQuery(BaseModel):
keyword: str | None = Field(default=None, description="Search keyword for filtering logs")
status: WorkflowExecutionStatus | None = Field(
default=None, description="Execution status filter (succeeded, failed, stopped, partial-succeeded)"
)
created_at__before: datetime | None = Field(default=None, description="Filter logs created before this timestamp")
created_at__after: datetime | None = Field(default=None, description="Filter logs created after this timestamp")
created_by_end_user_session_id: str | None = Field(default=None, description="Filter by end user session ID")
created_by_account: str | None = Field(default=None, description="Filter by account")
detail: bool = Field(default=False, description="Whether to return detailed logs")
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
@field_validator("created_at__before", "created_at__after", mode="before")
@classmethod
def parse_datetime(cls, value: str | None) -> datetime | None:
if value in (None, ""):
return None
return isoparse(value) # type: ignore
@field_validator("detail", mode="before")
@classmethod
def parse_bool(cls, value: bool | str | None) -> bool:
if isinstance(value, bool):
return value
if value is None:
return False
lowered = value.lower()
if lowered in {"1", "true", "yes", "on"}:
return True
if lowered in {"0", "false", "no", "off"}:
return False
raise ValueError("Invalid boolean value for detail")
console_ns.schema_model(
WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
# Register model for flask_restx to avoid dict type issues in Swagger
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
@ -68,7 +23,19 @@ class WorkflowAppLogApi(Resource):
@console_ns.doc("get_workflow_app_logs")
@console_ns.doc(description="Get workflow application execution logs")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
@console_ns.doc(
params={
"keyword": "Search keyword for filtering logs",
"status": "Filter by execution status (succeeded, failed, stopped, partial-succeeded)",
"created_at__before": "Filter logs created before this timestamp",
"created_at__after": "Filter logs created after this timestamp",
"created_by_end_user_session_id": "Filter by end user session ID",
"created_by_account": "Filter by account",
"detail": "Whether to return detailed logs",
"page": "Page number (1-99999)",
"limit": "Number of items per page (1-100)",
}
)
@console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
@setup_required
@login_required
@ -79,7 +46,44 @@ class WorkflowAppLogApi(Resource):
"""
Get workflow app logs
"""
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument("keyword", type=str, location="args")
.add_argument(
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
)
.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
)
.add_argument(
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
)
.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
.add_argument("detail", type=bool, location="args", required=False, default=False)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
)
args = parser.parse_args()
args.status = WorkflowExecutionStatus(args.status) if args.status else None
if args.created_at__before:
args.created_at__before = isoparse(args.created_at__before)
if args.created_at__after:
args.created_at__after = isoparse(args.created_at__after)
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()

View File

@ -1,8 +1,7 @@
from typing import Literal, cast
from typing import cast
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
@ -93,51 +92,70 @@ workflow_run_node_execution_list_model = console_ns.model(
"WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def _parse_workflow_run_list_args():
"""
Parse common arguments for workflow run list endpoints.
class WorkflowRunListQuery(BaseModel):
last_id: str | None = Field(default=None, description="Last run ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
default=None, description="Workflow run status filter"
Returns:
Parsed arguments containing last_id, limit, status, and triggered_from filters
"""
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
)
triggered_from: Literal["debugging", "app-run"] | None = Field(
default=None, description="Filter by trigger source: debugging or app-run"
return parser.parse_args()
def _parse_workflow_run_count_args():
"""
Parse common arguments for workflow run count endpoints.
Returns:
Parsed arguments containing status, time_range, and triggered_from filters
"""
parser = (
reqparse.RequestParser()
.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
.add_argument(
"time_range",
type=time_duration,
location="args",
required=False,
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
)
.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
)
@field_validator("last_id")
@classmethod
def validate_last_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class WorkflowRunCountQuery(BaseModel):
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
default=None, description="Workflow run status filter"
)
time_range: str | None = Field(default=None, description="Time range filter (e.g., 7d, 4h, 30m, 30s)")
triggered_from: Literal["debugging", "app-run"] | None = Field(
default=None, description="Filter by trigger source: debugging or app-run"
)
@field_validator("time_range")
@classmethod
def validate_time_range(cls, value: str | None) -> str | None:
if value is None:
return value
return time_duration(value)
console_ns.schema_model(
WorkflowRunListQuery.__name__, WorkflowRunListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
WorkflowRunCountQuery.__name__,
WorkflowRunCountQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
return parser.parse_args()
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
@ -152,7 +170,6 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
@console_ns.doc(
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
)
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
@console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model)
@setup_required
@login_required
@ -163,13 +180,12 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
"""
Get advanced chat app workflow run list
"""
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
args = _parse_workflow_run_list_args()
# Default to DEBUGGING if not specified
triggered_from = (
WorkflowRunTriggeredFrom(args_model.triggered_from)
if args_model.triggered_from
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
@ -201,7 +217,6 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
)
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@ -211,13 +226,12 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
"""
Get advanced chat workflow runs count statistics
"""
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
args = _parse_workflow_run_count_args()
# Default to DEBUGGING if not specified
triggered_from = (
WorkflowRunTriggeredFrom(args_model.triggered_from)
if args_model.triggered_from
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
@ -245,7 +259,6 @@ class WorkflowRunListApi(Resource):
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
)
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@ -255,13 +268,12 @@ class WorkflowRunListApi(Resource):
"""
Get workflow run list
"""
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
args = _parse_workflow_run_list_args()
# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (
WorkflowRunTriggeredFrom(args_model.triggered_from)
if args_model.triggered_from
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
@ -293,7 +305,6 @@ class WorkflowRunCountApi(Resource):
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
)
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@ -303,13 +314,12 @@ class WorkflowRunCountApi(Resource):
"""
Get workflow runs count statistics
"""
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
args = _parse_workflow_run_count_args()
# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (
WorkflowRunTriggeredFrom(args_model.triggered_from)
if args_model.triggered_from
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)

View File

@ -1,6 +1,5 @@
from flask import abort, jsonify, request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from flask import abort, jsonify
from flask_restx import Resource, reqparse
from sqlalchemy.orm import sessionmaker
from controllers.console import console_ns
@ -8,31 +7,12 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from repositories.factory import DifyAPIRepositoryFactory
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowStatisticQuery(BaseModel):
start: str | None = Field(default=None, description="Start date and time (YYYY-MM-DD HH:MM)")
end: str | None = Field(default=None, description="End date and time (YYYY-MM-DD HH:MM)")
@field_validator("start", "end", mode="before")
@classmethod
def blank_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
console_ns.schema_model(
WorkflowStatisticQuery.__name__,
WorkflowStatisticQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
class WorkflowDailyRunsStatistic(Resource):
@ -44,7 +24,9 @@ class WorkflowDailyRunsStatistic(Resource):
@console_ns.doc("get_workflow_daily_runs_statistic")
@console_ns.doc(description="Get workflow daily runs statistics")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
@console_ns.doc(
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
)
@console_ns.response(200, "Daily runs statistics retrieved successfully")
@get_app_model
@setup_required
@ -53,12 +35,17 @@ class WorkflowDailyRunsStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
assert account.timezone is not None
try:
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -84,7 +71,9 @@ class WorkflowDailyTerminalsStatistic(Resource):
@console_ns.doc("get_workflow_daily_terminals_statistic")
@console_ns.doc(description="Get workflow daily terminals statistics")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
@console_ns.doc(
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
)
@console_ns.response(200, "Daily terminals statistics retrieved successfully")
@get_app_model
@setup_required
@ -93,12 +82,17 @@ class WorkflowDailyTerminalsStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
assert account.timezone is not None
try:
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -124,7 +118,9 @@ class WorkflowDailyTokenCostStatistic(Resource):
@console_ns.doc("get_workflow_daily_token_cost_statistic")
@console_ns.doc(description="Get workflow daily token cost statistics")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
@console_ns.doc(
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
)
@console_ns.response(200, "Daily token cost statistics retrieved successfully")
@get_app_model
@setup_required
@ -133,12 +129,17 @@ class WorkflowDailyTokenCostStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
assert account.timezone is not None
try:
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -164,7 +165,9 @@ class WorkflowAverageAppInteractionStatistic(Resource):
@console_ns.doc("get_workflow_average_app_interaction_statistic")
@console_ns.doc(description="Get workflow average app interaction statistics")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
@console_ns.doc(
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
)
@console_ns.response(200, "Average app interaction statistics retrieved successfully")
@setup_required
@login_required
@ -173,12 +176,17 @@ class WorkflowAverageAppInteractionStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
assert account.timezone is not None
try:
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))

View File

@ -58,7 +58,7 @@ class VersionApi(Resource):
response = httpx.get(
check_update_url,
params={"current_version": args["current_version"]},
timeout=httpx.Timeout(timeout=10.0, connect=3.0),
timeout=httpx.Timeout(connect=3, read=10),
)
except Exception as error:
logger.warning("Check update version error: %s.", str(error))

View File

@ -174,25 +174,63 @@ class CheckEmailUniquePayload(BaseModel):
return email(value)
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(AccountInitPayload)
reg(AccountNamePayload)
reg(AccountAvatarPayload)
reg(AccountInterfaceLanguagePayload)
reg(AccountInterfaceThemePayload)
reg(AccountTimezonePayload)
reg(AccountPasswordPayload)
reg(AccountDeletePayload)
reg(AccountDeletionFeedbackPayload)
reg(EducationActivatePayload)
reg(EducationAutocompleteQuery)
reg(ChangeEmailSendPayload)
reg(ChangeEmailValidityPayload)
reg(ChangeEmailResetPayload)
reg(CheckEmailUniquePayload)
console_ns.schema_model(
AccountInitPayload.__name__, AccountInitPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
AccountNamePayload.__name__, AccountNamePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
AccountAvatarPayload.__name__, AccountAvatarPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
AccountInterfaceLanguagePayload.__name__,
AccountInterfaceLanguagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountInterfaceThemePayload.__name__,
AccountInterfaceThemePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountTimezonePayload.__name__,
AccountTimezonePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountPasswordPayload.__name__,
AccountPasswordPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountDeletePayload.__name__,
AccountDeletePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
AccountDeletionFeedbackPayload.__name__,
AccountDeletionFeedbackPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
EducationActivatePayload.__name__,
EducationActivatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
EducationAutocompleteQuery.__name__,
EducationAutocompleteQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChangeEmailSendPayload.__name__,
ChangeEmailSendPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChangeEmailValidityPayload.__name__,
ChangeEmailValidityPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChangeEmailResetPayload.__name__,
ChangeEmailResetPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
CheckEmailUniquePayload.__name__,
CheckEmailUniquePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/account/init")

View File

@ -1,8 +1,4 @@
from typing import Any
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from flask_restx import Resource, fields, reqparse
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
@ -11,49 +7,21 @@ from core.plugin.impl.exc import PluginPermissionDeniedError
from libs.login import current_account_with_tenant, login_required
from services.plugin.endpoint_service import EndpointService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class EndpointCreatePayload(BaseModel):
plugin_unique_identifier: str
settings: dict[str, Any]
name: str = Field(min_length=1)
class EndpointIdPayload(BaseModel):
endpoint_id: str
class EndpointUpdatePayload(EndpointIdPayload):
settings: dict[str, Any]
name: str = Field(min_length=1)
class EndpointListQuery(BaseModel):
page: int = Field(ge=1)
page_size: int = Field(gt=0)
class EndpointListForPluginQuery(EndpointListQuery):
plugin_id: str
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(EndpointCreatePayload)
reg(EndpointIdPayload)
reg(EndpointUpdatePayload)
reg(EndpointListQuery)
reg(EndpointListForPluginQuery)
@console_ns.route("/workspaces/current/endpoints/create")
class EndpointCreateApi(Resource):
@console_ns.doc("create_endpoint")
@console_ns.doc(description="Create a new plugin endpoint")
@console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
@console_ns.expect(
console_ns.model(
"EndpointCreateRequest",
{
"plugin_unique_identifier": fields.String(required=True, description="Plugin unique identifier"),
"settings": fields.Raw(required=True, description="Endpoint settings"),
"name": fields.String(required=True, description="Endpoint name"),
},
)
)
@console_ns.response(
200,
"Endpoint created successfully",
@ -67,16 +35,26 @@ class EndpointCreateApi(Resource):
def post(self):
user, tenant_id = current_account_with_tenant()
args = EndpointCreatePayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("plugin_unique_identifier", type=str, required=True)
.add_argument("settings", type=dict, required=True)
.add_argument("name", type=str, required=True)
)
args = parser.parse_args()
plugin_unique_identifier = args["plugin_unique_identifier"]
settings = args["settings"]
name = args["name"]
try:
return {
"success": EndpointService.create_endpoint(
tenant_id=tenant_id,
user_id=user.id,
plugin_unique_identifier=args.plugin_unique_identifier,
name=args.name,
settings=args.settings,
plugin_unique_identifier=plugin_unique_identifier,
name=name,
settings=settings,
)
}
except PluginPermissionDeniedError as e:
@ -87,7 +65,11 @@ class EndpointCreateApi(Resource):
class EndpointListApi(Resource):
@console_ns.doc("list_endpoints")
@console_ns.doc(description="List plugin endpoints with pagination")
@console_ns.expect(console_ns.models[EndpointListQuery.__name__])
@console_ns.expect(
console_ns.parser()
.add_argument("page", type=int, required=True, location="args", help="Page number")
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
)
@console_ns.response(
200,
"Success",
@ -101,10 +83,15 @@ class EndpointListApi(Resource):
def get(self):
user, tenant_id = current_account_with_tenant()
args = EndpointListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
)
args = parser.parse_args()
page = args.page
page_size = args.page_size
page = args["page"]
page_size = args["page_size"]
return jsonable_encoder(
{
@ -122,7 +109,12 @@ class EndpointListApi(Resource):
class EndpointListForSinglePluginApi(Resource):
@console_ns.doc("list_plugin_endpoints")
@console_ns.doc(description="List endpoints for a specific plugin")
@console_ns.expect(console_ns.models[EndpointListForPluginQuery.__name__])
@console_ns.expect(
console_ns.parser()
.add_argument("page", type=int, required=True, location="args", help="Page number")
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
.add_argument("plugin_id", type=str, required=True, location="args", help="Plugin ID")
)
@console_ns.response(
200,
"Success",
@ -136,11 +128,17 @@ class EndpointListForSinglePluginApi(Resource):
def get(self):
user, tenant_id = current_account_with_tenant()
args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
.add_argument("plugin_id", type=str, required=True, location="args")
)
args = parser.parse_args()
page = args.page
page_size = args.page_size
plugin_id = args.plugin_id
page = args["page"]
page_size = args["page_size"]
plugin_id = args["plugin_id"]
return jsonable_encoder(
{
@ -159,7 +157,11 @@ class EndpointListForSinglePluginApi(Resource):
class EndpointDeleteApi(Resource):
@console_ns.doc("delete_endpoint")
@console_ns.doc(description="Delete a plugin endpoint")
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
@console_ns.expect(
console_ns.model(
"EndpointDeleteRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
)
)
@console_ns.response(
200,
"Endpoint deleted successfully",
@ -173,12 +175,13 @@ class EndpointDeleteApi(Resource):
def post(self):
user, tenant_id = current_account_with_tenant()
args = EndpointIdPayload.model_validate(console_ns.payload)
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
return {
"success": EndpointService.delete_endpoint(
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
)
"success": EndpointService.delete_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
}
@ -186,7 +189,16 @@ class EndpointDeleteApi(Resource):
class EndpointUpdateApi(Resource):
@console_ns.doc("update_endpoint")
@console_ns.doc(description="Update a plugin endpoint")
@console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
@console_ns.expect(
console_ns.model(
"EndpointUpdateRequest",
{
"endpoint_id": fields.String(required=True, description="Endpoint ID"),
"settings": fields.Raw(required=True, description="Updated settings"),
"name": fields.String(required=True, description="Updated name"),
},
)
)
@console_ns.response(
200,
"Endpoint updated successfully",
@ -200,15 +212,25 @@ class EndpointUpdateApi(Resource):
def post(self):
user, tenant_id = current_account_with_tenant()
args = EndpointUpdatePayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("endpoint_id", type=str, required=True)
.add_argument("settings", type=dict, required=True)
.add_argument("name", type=str, required=True)
)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
settings = args["settings"]
name = args["name"]
return {
"success": EndpointService.update_endpoint(
tenant_id=tenant_id,
user_id=user.id,
endpoint_id=args.endpoint_id,
name=args.name,
settings=args.settings,
endpoint_id=endpoint_id,
name=name,
settings=settings,
)
}
@ -217,7 +239,11 @@ class EndpointUpdateApi(Resource):
class EndpointEnableApi(Resource):
@console_ns.doc("enable_endpoint")
@console_ns.doc(description="Enable a plugin endpoint")
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
@console_ns.expect(
console_ns.model(
"EndpointEnableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
)
)
@console_ns.response(
200,
"Endpoint enabled successfully",
@ -231,12 +257,13 @@ class EndpointEnableApi(Resource):
def post(self):
user, tenant_id = current_account_with_tenant()
args = EndpointIdPayload.model_validate(console_ns.payload)
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
return {
"success": EndpointService.enable_endpoint(
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
)
"success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
}
@ -244,7 +271,11 @@ class EndpointEnableApi(Resource):
class EndpointDisableApi(Resource):
@console_ns.doc("disable_endpoint")
@console_ns.doc(description="Disable a plugin endpoint")
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
@console_ns.expect(
console_ns.model(
"EndpointDisableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
)
)
@console_ns.response(
200,
"Endpoint disabled successfully",
@ -258,10 +289,11 @@ class EndpointDisableApi(Resource):
def post(self):
user, tenant_id = current_account_with_tenant()
args = EndpointIdPayload.model_validate(console_ns.payload)
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
return {
"success": EndpointService.disable_endpoint(
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
)
"success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
}

View File

@ -58,15 +58,26 @@ class OwnerTransferPayload(BaseModel):
token: str
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(MemberInvitePayload)
reg(MemberRoleUpdatePayload)
reg(OwnerTransferEmailPayload)
reg(OwnerTransferCheckPayload)
reg(OwnerTransferPayload)
console_ns.schema_model(
MemberInvitePayload.__name__,
MemberInvitePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
MemberRoleUpdatePayload.__name__,
MemberRoleUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
OwnerTransferEmailPayload.__name__,
OwnerTransferEmailPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
OwnerTransferCheckPayload.__name__,
OwnerTransferCheckPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
OwnerTransferPayload.__name__,
OwnerTransferPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/workspaces/current/members")

View File

@ -75,18 +75,44 @@ class ParserPreferredProviderType(BaseModel):
preferred_provider_type: Literal["system", "custom"]
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
console_ns.schema_model(
ParserModelList.__name__, ParserModelList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserCredentialId.__name__,
ParserCredentialId.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
reg(ParserModelList)
reg(ParserCredentialId)
reg(ParserCredentialCreate)
reg(ParserCredentialUpdate)
reg(ParserCredentialDelete)
reg(ParserCredentialSwitch)
reg(ParserCredentialValidate)
reg(ParserPreferredProviderType)
console_ns.schema_model(
ParserCredentialCreate.__name__,
ParserCredentialCreate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCredentialUpdate.__name__,
ParserCredentialUpdate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCredentialDelete.__name__,
ParserCredentialDelete.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCredentialSwitch.__name__,
ParserCredentialSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserCredentialValidate.__name__,
ParserCredentialValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserPreferredProviderType.__name__,
ParserPreferredProviderType.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/workspaces/current/model-providers")

View File

@ -32,11 +32,25 @@ class ParserPostDefault(BaseModel):
model_settings: list[Inner]
console_ns.schema_model(
ParserGetDefault.__name__, ParserGetDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserPostDefault.__name__, ParserPostDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class ParserDeleteModels(BaseModel):
model: str
model_type: ModelType
console_ns.schema_model(
ParserDeleteModels.__name__, ParserDeleteModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class LoadBalancingPayload(BaseModel):
configs: list[dict[str, Any]] | None = None
enabled: bool | None = None
@ -105,19 +119,33 @@ class ParserParameter(BaseModel):
model: str
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
console_ns.schema_model(
ParserPostModels.__name__, ParserPostModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserGetCredentials.__name__,
ParserGetCredentials.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
reg(ParserGetDefault)
reg(ParserPostDefault)
reg(ParserDeleteModels)
reg(ParserPostModels)
reg(ParserGetCredentials)
reg(ParserCreateCredential)
reg(ParserUpdateCredential)
reg(ParserDeleteCredential)
reg(ParserParameter)
console_ns.schema_model(
ParserCreateCredential.__name__,
ParserCreateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserUpdateCredential.__name__,
ParserUpdateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserDeleteCredential.__name__,
ParserDeleteCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserParameter.__name__, ParserParameter.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/workspaces/current/default-model")

View File

@ -22,10 +22,6 @@ from services.plugin.plugin_service import PluginService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/workspaces/current/plugin/debugging-key")
class PluginDebuggingKeyApi(Resource):
@setup_required
@ -50,7 +46,9 @@ class ParserList(BaseModel):
page_size: int = Field(default=256)
reg(ParserList)
console_ns.schema_model(
ParserList.__name__, ParserList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/workspaces/current/plugin/list")
@ -74,6 +72,11 @@ class ParserLatest(BaseModel):
plugin_ids: list[str]
console_ns.schema_model(
ParserLatest.__name__, ParserLatest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class ParserIcon(BaseModel):
tenant_id: str
filename: str
@ -170,22 +173,72 @@ class ParserReadme(BaseModel):
language: str = Field(default="en-US")
reg(ParserLatest)
reg(ParserIcon)
reg(ParserAsset)
reg(ParserGithubUpload)
reg(ParserPluginIdentifiers)
reg(ParserGithubInstall)
reg(ParserPluginIdentifierQuery)
reg(ParserTasks)
reg(ParserMarketplaceUpgrade)
reg(ParserGithubUpgrade)
reg(ParserUninstall)
reg(ParserPermissionChange)
reg(ParserDynamicOptions)
reg(ParserPreferencesChange)
reg(ParserExcludePlugin)
reg(ParserReadme)
console_ns.schema_model(
ParserIcon.__name__, ParserIcon.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserAsset.__name__, ParserAsset.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserGithubUpload.__name__, ParserGithubUpload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserPluginIdentifiers.__name__,
ParserPluginIdentifiers.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserGithubInstall.__name__, ParserGithubInstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserPluginIdentifierQuery.__name__,
ParserPluginIdentifierQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserTasks.__name__, ParserTasks.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserMarketplaceUpgrade.__name__,
ParserMarketplaceUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserGithubUpgrade.__name__, ParserGithubUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserUninstall.__name__, ParserUninstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ParserPermissionChange.__name__,
ParserPermissionChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserDynamicOptions.__name__,
ParserDynamicOptions.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserPreferencesChange.__name__,
ParserPreferencesChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserExcludePlugin.__name__,
ParserExcludePlugin.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ParserReadme.__name__, ParserReadme.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/workspaces/current/plugin/list/latest-versions")

View File

@ -54,14 +54,25 @@ class WorkspaceInfoPayload(BaseModel):
name: str
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
console_ns.schema_model(
WorkspaceListQuery.__name__, WorkspaceListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
SwitchWorkspacePayload.__name__,
SwitchWorkspacePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
WorkspaceCustomConfigPayload.__name__,
WorkspaceCustomConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
WorkspaceInfoPayload.__name__,
WorkspaceInfoPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
reg(WorkspaceListQuery)
reg(SwitchWorkspacePayload)
reg(WorkspaceCustomConfigPayload)
reg(WorkspaceInfoPayload)
provider_fields = {
"provider_name": fields.String,

View File

@ -316,16 +316,18 @@ def validate_and_get_api_token(scope: str | None = None):
ApiToken.type == scope,
)
.values(last_used_at=current_time)
.returning(ApiToken)
)
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
result = session.execute(update_stmt)
api_token = session.scalar(stmt)
if hasattr(result, "rowcount") and result.rowcount > 0:
session.commit()
api_token = result.scalar_one_or_none()
if not api_token:
raise Unauthorized("Access token is invalid")
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
api_token = session.scalar(stmt)
if not api_token:
raise Unauthorized("Access token is invalid")
else:
session.commit()
return api_token

View File

@ -1,4 +1,3 @@
import logging
import time
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
@ -56,7 +55,6 @@ from models import Account, EndUser
from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator
NodeExecutionId = NewType("NodeExecutionId", str)
logger = logging.getLogger(__name__)
@dataclass(slots=True)
@ -291,30 +289,26 @@ class WorkflowResponseConverter:
),
)
try:
if event.node_type == NodeType.TOOL:
response.data.extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
provider_type=ToolProviderType(event.provider_type),
provider_id=event.provider_id,
)
elif event.node_type == NodeType.DATASOURCE:
manager = PluginDatasourceManager()
provider_entity = manager.fetch_datasource_provider(
self._application_generate_entity.app_config.tenant_id,
event.provider_id,
)
response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
self._application_generate_entity.app_config.tenant_id
)
elif event.node_type == NodeType.TRIGGER_PLUGIN:
response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon(
self._application_generate_entity.app_config.tenant_id,
event.provider_id,
)
except Exception:
# metadata fetch may fail, for example, the plugin daemon is down or plugin is uninstalled.
logger.warning("failed to fetch icon for %s", event.provider_id)
if event.node_type == NodeType.TOOL:
response.data.extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
provider_type=ToolProviderType(event.provider_type),
provider_id=event.provider_id,
)
elif event.node_type == NodeType.DATASOURCE:
manager = PluginDatasourceManager()
provider_entity = manager.fetch_datasource_provider(
self._application_generate_entity.app_config.tenant_id,
event.provider_id,
)
response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
self._application_generate_entity.app_config.tenant_id
)
elif event.node_type == NodeType.TRIGGER_PLUGIN:
response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon(
self._application_generate_entity.app_config.tenant_id,
event.provider_id,
)
return response

View File

@ -156,82 +156,78 @@ class MessageBasedAppGenerator(BaseAppGenerator):
query = application_generate_entity.query or "New conversation"
conversation_name = (query[:20] + "") if len(query) > 20 else query
with db.session.begin():
if not conversation:
conversation = Conversation(
app_id=app_config.app_id,
app_model_config_id=app_model_config_id,
model_provider=model_provider,
model_id=model_id,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=app_config.app_mode.value,
name=conversation_name,
inputs=application_generate_entity.inputs,
introduction=introduction,
system_instruction="",
system_instruction_tokens=0,
status="normal",
invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
)
db.session.add(conversation)
db.session.flush()
db.session.refresh(conversation)
else:
conversation.updated_at = naive_utc_now()
message = Message(
if not conversation:
conversation = Conversation(
app_id=app_config.app_id,
app_model_config_id=app_model_config_id,
model_provider=model_provider,
model_id=model_id,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
conversation_id=conversation.id,
mode=app_config.app_mode.value,
name=conversation_name,
inputs=application_generate_entity.inputs,
query=application_generate_entity.query,
message="",
message_tokens=0,
message_unit_price=0,
message_price_unit=0,
answer="",
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
parent_message_id=getattr(application_generate_entity, "parent_message_id", None),
provider_response_latency=0,
total_price=0,
currency="USD",
introduction=introduction,
system_instruction="",
system_instruction_tokens=0,
status="normal",
invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
app_mode=app_config.app_mode,
)
db.session.add(message)
db.session.flush()
db.session.refresh(message)
message_files = []
for file in application_generate_entity.files:
message_file = MessageFile(
message_id=message.id,
type=file.type,
transfer_method=file.transfer_method,
belongs_to="user",
url=file.remote_url,
upload_file_id=file.related_id,
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
created_by=account_id or end_user_id or "",
)
message_files.append(message_file)
if message_files:
db.session.add_all(message_files)
db.session.add(conversation)
db.session.commit()
db.session.refresh(conversation)
else:
conversation.updated_at = naive_utc_now()
db.session.commit()
message = Message(
app_id=app_config.app_id,
model_provider=model_provider,
model_id=model_id,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
conversation_id=conversation.id,
inputs=application_generate_entity.inputs,
query=application_generate_entity.query,
message="",
message_tokens=0,
message_unit_price=0,
message_price_unit=0,
answer="",
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
parent_message_id=getattr(application_generate_entity, "parent_message_id", None),
provider_response_latency=0,
total_price=0,
currency="USD",
invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
app_mode=app_config.app_mode,
)
db.session.add(message)
db.session.commit()
db.session.refresh(message)
for file in application_generate_entity.files:
message_file = MessageFile(
message_id=message.id,
type=file.type,
transfer_method=file.transfer_method,
belongs_to="user",
url=file.remote_url,
upload_file_id=file.related_id,
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
created_by=account_id or end_user_id or "",
)
db.session.add(message_file)
db.session.commit()
return conversation, message
def _get_conversation_introduction(self, application_generate_entity: AppGenerateEntity) -> str:

View File

@ -4,15 +4,15 @@ from typing import TYPE_CHECKING, Any, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
from constants import UUID_NIL
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
from core.file import File, FileUploadConfig
from core.model_runtime.entities.model_entities import AIModelEntity
if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
class InvokeFrom(StrEnum):
"""
@ -275,8 +275,10 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
start_node_id: str | None = None
# Import TraceQueueManager at runtime to resolve forward references
from core.ops.ops_trace_manager import TraceQueueManager
# Rebuild models that use forward references
AppGenerateEntity.model_rebuild()
EasyUIBasedAppGenerateEntity.model_rebuild()
ConversationAppGenerateEntity.model_rebuild()

View File

@ -29,7 +29,6 @@ class SimpleModelProviderEntity(BaseModel):
provider: str
label: I18nObject
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
icon_large: I18nObject | None = None
supported_model_types: list[ModelType]
@ -43,7 +42,6 @@ class SimpleModelProviderEntity(BaseModel):
provider=provider_entity.provider,
label=provider_entity.label,
icon_small=provider_entity.icon_small,
icon_small_dark=provider_entity.icon_small_dark,
icon_large=provider_entity.icon_large,
supported_model_types=provider_entity.supported_model_types,
)

View File

@ -99,7 +99,6 @@ class SimpleProviderEntity(BaseModel):
provider: str
label: I18nObject
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
icon_large: I18nObject | None = None
supported_model_types: Sequence[ModelType]
models: list[AIModelEntity] = []
@ -125,6 +124,7 @@ class ProviderEntity(BaseModel):
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
icon_small_dark: I18nObject | None = None
icon_large_dark: I18nObject | None = None
background: str | None = None
help: ProviderHelpEntity | None = None
supported_model_types: Sequence[ModelType]

View File

@ -300,14 +300,6 @@ class ModelProviderFactory:
file_name = provider_schema.icon_small.zh_Hans
else:
file_name = provider_schema.icon_small.en_US
elif icon_type.lower() == "icon_small_dark":
if not provider_schema.icon_small_dark:
raise ValueError(f"Provider {provider} does not have small dark icon.")
if lang.lower() == "zh_hans":
file_name = provider_schema.icon_small_dark.zh_Hans
else:
file_name = provider_schema.icon_small_dark.en_US
else:
if not provider_schema.icon_large:
raise ValueError(f"Provider {provider} does not have large icon.")

View File

@ -58,39 +58,11 @@ class OceanBaseVector(BaseVector):
password=self._config.password,
db_name=self._config.database,
)
self._fields: list[str] = [] # List of fields in the collection
if self._client.check_table_exists(collection_name):
self._load_collection_fields()
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
def get_type(self) -> str:
return VectorType.OCEANBASE
def _load_collection_fields(self):
"""
Load collection fields from the database table.
This method populates the _fields list with column names from the table.
"""
try:
if self._collection_name in self._client.metadata_obj.tables:
table = self._client.metadata_obj.tables[self._collection_name]
# Store all column names except 'id' (primary key)
self._fields = [column.name for column in table.columns if column.name != "id"]
logger.debug("Loaded fields for collection '%s': %s", self._collection_name, self._fields)
else:
logger.warning("Collection '%s' not found in metadata", self._collection_name)
except Exception as e:
logger.warning("Failed to load collection fields for '%s': %s", self._collection_name, str(e))
def field_exists(self, field: str) -> bool:
"""
Check if a field exists in the collection.
:param field: Field name to check
:return: True if field exists, False otherwise
"""
return field in self._fields
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self._vec_dim = len(embeddings[0])
self._create_collection()
@ -179,7 +151,6 @@ class OceanBaseVector(BaseVector):
logger.debug("DEBUG: Hybrid search is NOT enabled for '%s'", self._collection_name)
self._client.refresh_metadata([self._collection_name])
self._load_collection_fields()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _check_hybrid_search_support(self) -> bool:
@ -206,134 +177,42 @@ class OceanBaseVector(BaseVector):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
ids = self._get_uuids(documents)
for id, doc, emb in zip(ids, documents, embeddings):
try:
self._client.insert(
table_name=self._collection_name,
data={
"id": id,
"vector": emb,
"text": doc.page_content,
"metadata": doc.metadata,
},
)
except Exception as e:
logger.exception(
"Failed to insert document with id '%s' in collection '%s'",
id,
self._collection_name,
)
raise Exception(f"Failed to insert document with id '{id}'") from e
self._client.insert(
table_name=self._collection_name,
data={
"id": id,
"vector": emb,
"text": doc.page_content,
"metadata": doc.metadata,
},
)
def text_exists(self, id: str) -> bool:
try:
cur = self._client.get(table_name=self._collection_name, ids=id)
return bool(cur.rowcount != 0)
except Exception as e:
logger.exception(
"Failed to check if text exists with id '%s' in collection '%s'",
id,
self._collection_name,
)
raise Exception(f"Failed to check text existence for id '{id}'") from e
cur = self._client.get(table_name=self._collection_name, ids=id)
return bool(cur.rowcount != 0)
def delete_by_ids(self, ids: list[str]):
if not ids:
return
try:
self._client.delete(table_name=self._collection_name, ids=ids)
logger.debug("Deleted %d documents from collection '%s'", len(ids), self._collection_name)
except Exception as e:
logger.exception(
"Failed to delete %d documents from collection '%s'",
len(ids),
self._collection_name,
)
raise Exception(f"Failed to delete documents from collection '{self._collection_name}'") from e
self._client.delete(table_name=self._collection_name, ids=ids)
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
try:
import re
from sqlalchemy import text
from sqlalchemy import text
# Validate key to prevent injection in JSON path
if not re.match(r"^[a-zA-Z0-9_.]+$", key):
raise ValueError(f"Invalid characters in metadata key: {key}")
# Use parameterized query to prevent SQL injection
sql = text(f"SELECT id FROM `{self._collection_name}` WHERE metadata->>'$.{key}' = :value")
with self._client.engine.connect() as conn:
result = conn.execute(sql, {"value": value})
ids = [row[0] for row in result]
logger.debug(
"Found %d documents with metadata field '%s'='%s' in collection '%s'",
len(ids),
key,
value,
self._collection_name,
)
return ids
except Exception as e:
logger.exception(
"Failed to get IDs by metadata field '%s'='%s' in collection '%s'",
key,
value,
self._collection_name,
)
raise Exception(f"Failed to query documents by metadata field '{key}'") from e
cur = self._client.get(
table_name=self._collection_name,
ids=None,
where_clause=[text(f"metadata->>'$.{key}' = '{value}'")],
output_column_name=["id"],
)
return [row[0] for row in cur]
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self.delete_by_ids(ids)
else:
logger.debug("No documents found to delete with metadata field '%s'='%s'", key, value)
def _process_search_results(
self, results: list[tuple], score_threshold: float = 0.0, score_key: str = "score"
) -> list[Document]:
"""
Common method to process search results
:param results: Search results as list of tuples (text, metadata, score)
:param score_threshold: Score threshold for filtering
:param score_key: Key name for score in metadata
:return: List of documents
"""
docs = []
for row in results:
text, metadata_str, score = row[0], row[1], row[2]
# Parse metadata JSON
try:
metadata = json.loads(metadata_str) if isinstance(metadata_str, str) else metadata_str
except json.JSONDecodeError:
logger.warning("Invalid JSON metadata: %s", metadata_str)
metadata = {}
# Add score to metadata
metadata[score_key] = score
# Filter by score threshold
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs
self.delete_by_ids(ids)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
if not self._hybrid_search_enabled:
logger.warning(
"Full-text search is disabled: set OCEANBASE_ENABLE_HYBRID_SEARCH=true (requires OceanBase >= 4.3.5.1)."
)
return []
if not self.field_exists("text"):
logger.warning(
"Full-text search unavailable: collection '%s' missing 'text' field; "
"recreate the collection after enabling OCEANBASE_ENABLE_HYBRID_SEARCH to add fulltext index.",
self._collection_name,
)
return []
try:
@ -341,24 +220,13 @@ class OceanBaseVector(BaseVector):
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
score_threshold = float(kwargs.get("score_threshold") or 0.0)
# Build parameterized query to prevent SQL injection
from sqlalchemy import text
document_ids_filter = kwargs.get("document_ids_filter")
params = {"query": query}
where_clause = ""
if document_ids_filter:
# Create parameterized placeholders for document IDs
placeholders = ", ".join(f":doc_id_{i}" for i in range(len(document_ids_filter)))
where_clause = f" AND metadata->>'$.document_id' IN ({placeholders})"
# Add document IDs to parameters
for i, doc_id in enumerate(document_ids_filter):
params[f"doc_id_{i}"] = doc_id
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" AND metadata->>'$.document_id' IN ({document_ids})"
full_sql = f"""SELECT text, metadata, MATCH (text) AGAINST (:query) AS score
full_sql = f"""SELECT metadata, text, MATCH (text) AGAINST (:query) AS score
FROM {self._collection_name}
WHERE MATCH (text) AGAINST (:query) > 0
{where_clause}
@ -367,45 +235,41 @@ class OceanBaseVector(BaseVector):
with self._client.engine.connect() as conn:
with conn.begin():
result = conn.execute(text(full_sql), params)
from sqlalchemy import text
result = conn.execute(text(full_sql), {"query": query})
rows = result.fetchall()
return self._process_search_results(rows, score_threshold=score_threshold)
docs = []
for row in rows:
metadata_str, _text, score = row
try:
metadata = json.loads(metadata_str)
except json.JSONDecodeError:
logger.warning("Invalid JSON metadata: %s", metadata_str)
metadata = {}
metadata["score"] = score
docs.append(Document(page_content=_text, metadata=metadata))
return docs
except Exception as e:
logger.exception(
"Failed to perform full-text search on collection '%s' with query '%s'",
self._collection_name,
query,
)
raise Exception(f"Full-text search failed for collection '{self._collection_name}'") from e
logger.warning("Failed to fulltext search: %s.", str(e))
return []
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from sqlalchemy import text
document_ids_filter = kwargs.get("document_ids_filter")
_where_clause = None
if document_ids_filter:
# Validate document IDs to prevent SQL injection
# Document IDs should be alphanumeric with hyphens and underscores
import re
for doc_id in document_ids_filter:
if not isinstance(doc_id, str) or not re.match(r"^[a-zA-Z0-9_-]+$", doc_id):
raise ValueError(f"Invalid document ID format: {doc_id}")
# Safe to use in query after validation
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f"metadata->>'$.document_id' in ({document_ids})"
from sqlalchemy import text
_where_clause = [text(where_clause)]
ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
if ef_search != self._hnsw_ef_search:
self._client.set_ob_hnsw_ef_search(ef_search)
self._hnsw_ef_search = ef_search
topk = kwargs.get("top_k", 10)
try:
score_threshold = float(val) if (val := kwargs.get("score_threshold")) is not None else 0.0
except (ValueError, TypeError) as e:
raise ValueError(f"Invalid score_threshold parameter: {e}") from e
try:
cur = self._client.ann_search(
table_name=self._collection_name,
@ -418,27 +282,21 @@ class OceanBaseVector(BaseVector):
where_clause=_where_clause,
)
except Exception as e:
logger.exception(
"Failed to perform vector search on collection '%s'",
self._collection_name,
raise Exception("Failed to search by vector. ", e)
docs = []
for _text, metadata, distance in cur:
metadata = json.loads(metadata)
metadata["score"] = 1 - distance / math.sqrt(2)
docs.append(
Document(
page_content=_text,
metadata=metadata,
)
)
raise Exception(f"Vector search failed for collection '{self._collection_name}'") from e
# Convert distance to score and prepare results for processing
results = []
for _text, metadata_str, distance in cur:
score = 1 - distance / math.sqrt(2)
results.append((_text, metadata_str, score))
return self._process_search_results(results, score_threshold=score_threshold)
return docs
def delete(self):
try:
self._client.drop_table_if_exist(self._collection_name)
logger.debug("Dropped collection '%s'", self._collection_name)
except Exception as e:
logger.exception("Failed to delete collection '%s'", self._collection_name)
raise Exception(f"Failed to delete collection '{self._collection_name}'") from e
self._client.drop_table_if_exist(self._collection_name)
class OceanBaseVectorFactory(AbstractVectorFactory):

View File

@ -54,8 +54,6 @@ class ToolProviderApiEntity(BaseModel):
configuration: MCPConfiguration | None = Field(
default=None, description="The timeout and sse_read_timeout of the MCP tool"
)
# Workflow
workflow_app_id: str | None = Field(default=None, description="The app id of the workflow tool")
@field_validator("tools", mode="before")
@classmethod
@ -89,8 +87,6 @@ class ToolProviderApiEntity(BaseModel):
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
optional_fields.update(self.optional_field("original_headers", self.original_headers))
elif self.type == ToolProviderType.WORKFLOW:
optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id))
return {
"id": self.id,
"author": self.author,

View File

@ -203,7 +203,7 @@ class WorkflowTool(Tool):
Resolve user object in both HTTP and worker contexts.
In HTTP context: dereference the current_user LocalProxy (can return Account or EndUser).
In worker context: load Account(knowledge pipeline) or EndUser(trigger) from database by user_id.
In worker context: load Account from database by user_id (only returns Account, never EndUser).
Returns:
Account | EndUser | None: The resolved user object, or None if resolution fails.
@ -224,28 +224,24 @@ class WorkflowTool(Tool):
logger.warning("Failed to resolve user from request context: %s", e)
return None
def _resolve_user_from_database(self, user_id: str) -> Account | EndUser | None:
def _resolve_user_from_database(self, user_id: str) -> Account | None:
"""
Resolve user from database (worker/Celery context).
"""
user_stmt = select(Account).where(Account.id == user_id)
user = db.session.scalar(user_stmt)
if not user:
return None
tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
tenant = db.session.scalar(tenant_stmt)
if not tenant:
return None
user_stmt = select(Account).where(Account.id == user_id)
user = db.session.scalar(user_stmt)
if user:
user.current_tenant = tenant
return user
user.current_tenant = tenant
end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
end_user = db.session.scalar(end_user_stmt)
if end_user:
return end_user
return None
return user
def _get_workflow(self, app_id: str, version: str) -> Workflow:
"""

View File

@ -1,11 +1,7 @@
import importlib
import logging
import operator
import pkgutil
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from functools import singledispatchmethod
from types import MappingProxyType
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
from uuid import uuid4
@ -138,34 +134,6 @@ class Node(Generic[NodeDataT]):
cls._node_data_type = node_data_type
# Skip base class itself
if cls is Node:
return
# Only register production node implementations defined under core.workflow.nodes.*
# This prevents test helper subclasses from polluting the global registry and
# accidentally overriding real node types (e.g., a test Answer node).
module_name = getattr(cls, "__module__", "")
# Only register concrete subclasses that define node_type and version()
node_type = cls.node_type
version = cls.version()
bucket = Node._registry.setdefault(node_type, {})
if module_name.startswith("core.workflow.nodes."):
# Production node definitions take precedence and may override
bucket[version] = cls # type: ignore[index]
else:
# External/test subclasses may register but must not override production
bucket.setdefault(version, cls) # type: ignore[index]
# Maintain a "latest" pointer preferring numeric versions; fallback to lexicographic
version_keys = [v for v in bucket if v != "latest"]
numeric_pairs: list[tuple[str, int]] = []
for v in version_keys:
numeric_pairs.append((v, int(v)))
if numeric_pairs:
latest_key = max(numeric_pairs, key=operator.itemgetter(1))[0]
else:
latest_key = max(version_keys) if version_keys else version
bucket["latest"] = bucket[latest_key]
@classmethod
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
"""
@ -197,9 +165,6 @@ class Node(Generic[NodeDataT]):
return None
# Global registry populated via __init_subclass__
_registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}
def __init__(
self,
id: str,
@ -275,23 +240,23 @@ class Node(Generic[NodeDataT]):
from core.workflow.nodes.tool.tool_node import ToolNode
if isinstance(self, ToolNode):
start_event.provider_id = getattr(self.node_data, "provider_id", "")
start_event.provider_type = getattr(self.node_data, "provider_type", "")
start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "")
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
if isinstance(self, DatasourceNode):
plugin_id = getattr(self.node_data, "plugin_id", "")
provider_name = getattr(self.node_data, "provider_name", "")
plugin_id = getattr(self.get_base_node_data(), "plugin_id", "")
provider_name = getattr(self.get_base_node_data(), "provider_name", "")
start_event.provider_id = f"{plugin_id}/{provider_name}"
start_event.provider_type = getattr(self.node_data, "provider_type", "")
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
if isinstance(self, TriggerEventNode):
start_event.provider_id = getattr(self.node_data, "provider_id", "")
start_event.provider_type = getattr(self.node_data, "provider_type", "")
start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "")
start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")
from typing import cast
@ -300,7 +265,7 @@ class Node(Generic[NodeDataT]):
if isinstance(self, AgentNode):
start_event.agent_strategy = AgentNodeStrategyInit(
name=cast(AgentNodeData, self.node_data).agent_strategy_name,
name=cast(AgentNodeData, self.get_base_node_data()).agent_strategy_name,
icon=self.agent_strategy_icon,
)
@ -430,29 +395,6 @@ class Node(Generic[NodeDataT]):
# in `api/core/workflow/nodes/__init__.py`.
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
@classmethod
def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]:
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
Import all modules under core.workflow.nodes so subclasses register themselves on import.
Then we return a readonly view of the registry to avoid accidental mutation.
"""
# Import all node modules to ensure they are loaded (thus registered)
import core.workflow.nodes as _nodes_pkg
for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."):
# Avoid importing modules that depend on the registry to prevent circular imports
# e.g. node_factory imports node_mapping which builds the mapping here.
if _modname in {
"core.workflow.nodes.node_factory",
"core.workflow.nodes.node_mapping",
}:
continue
importlib.import_module(_modname)
# Return a readonly view so callers can't mutate the registry by accident
return {nt: MappingProxyType(ver_map) for nt, ver_map in cls._registry.items()}
@property
def retry(self) -> bool:
return False
@ -477,6 +419,10 @@ class Node(Generic[NodeDataT]):
"""Get the default values dictionary for this node."""
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
"""Get the BaseNodeData object for this node."""
return self._node_data
# Public interface properties that delegate to abstract methods
@property
def error_strategy(self) -> ErrorStrategy | None:
@ -602,7 +548,7 @@ class Node(Generic[NodeDataT]):
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
node_title=self.get_base_node_data().title,
start_at=event.start_at,
inputs=event.inputs,
metadata=event.metadata,
@ -615,7 +561,7 @@ class Node(Generic[NodeDataT]):
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
node_title=self.get_base_node_data().title,
index=event.index,
pre_loop_output=event.pre_loop_output,
)
@ -626,7 +572,7 @@ class Node(Generic[NodeDataT]):
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
node_title=self.get_base_node_data().title,
start_at=event.start_at,
inputs=event.inputs,
outputs=event.outputs,
@ -640,7 +586,7 @@ class Node(Generic[NodeDataT]):
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
node_title=self.get_base_node_data().title,
start_at=event.start_at,
inputs=event.inputs,
outputs=event.outputs,
@ -655,7 +601,7 @@ class Node(Generic[NodeDataT]):
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
node_title=self.get_base_node_data().title,
start_at=event.start_at,
inputs=event.inputs,
metadata=event.metadata,
@ -668,7 +614,7 @@ class Node(Generic[NodeDataT]):
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
node_title=self.get_base_node_data().title,
index=event.index,
pre_iteration_output=event.pre_iteration_output,
)
@ -679,7 +625,7 @@ class Node(Generic[NodeDataT]):
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
node_title=self.get_base_node_data().title,
start_at=event.start_at,
inputs=event.inputs,
outputs=event.outputs,
@ -693,7 +639,7 @@ class Node(Generic[NodeDataT]):
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
node_title=self.get_base_node_data().title,
start_at=event.start_at,
inputs=event.inputs,
outputs=event.outputs,

View File

@ -1,9 +1,165 @@
from collections.abc import Mapping
from core.workflow.enums import NodeType
from core.workflow.nodes.agent.agent_node import AgentNode
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code import CodeNode
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request import HttpRequestNode
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.if_else import IfElseNode
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
from core.workflow.nodes.list_operator import ListOperatorNode
from core.workflow.nodes.llm import LLMNode
from core.workflow.nodes.loop import LoopEndNode, LoopNode, LoopStartNode
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from core.workflow.nodes.question_classifier import QuestionClassifierNode
from core.workflow.nodes.start import StartNode
from core.workflow.nodes.template_transform import TemplateTransformNode
from core.workflow.nodes.tool import ToolNode
from core.workflow.nodes.trigger_plugin import TriggerEventNode
from core.workflow.nodes.trigger_schedule import TriggerScheduleNode
from core.workflow.nodes.trigger_webhook import TriggerWebhookNode
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2
LATEST_VERSION = "latest"
# Mapping is built by Node.get_node_type_classes_mapping(), which imports and walks core.workflow.nodes
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode.
# Specifically, if you have introduced new node types, you should add them here.
#
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
# hook. Try to avoid duplication of node information.
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
NodeType.START: {
LATEST_VERSION: StartNode,
"1": StartNode,
},
NodeType.END: {
LATEST_VERSION: EndNode,
"1": EndNode,
},
NodeType.ANSWER: {
LATEST_VERSION: AnswerNode,
"1": AnswerNode,
},
NodeType.LLM: {
LATEST_VERSION: LLMNode,
"1": LLMNode,
},
NodeType.KNOWLEDGE_RETRIEVAL: {
LATEST_VERSION: KnowledgeRetrievalNode,
"1": KnowledgeRetrievalNode,
},
NodeType.IF_ELSE: {
LATEST_VERSION: IfElseNode,
"1": IfElseNode,
},
NodeType.CODE: {
LATEST_VERSION: CodeNode,
"1": CodeNode,
},
NodeType.TEMPLATE_TRANSFORM: {
LATEST_VERSION: TemplateTransformNode,
"1": TemplateTransformNode,
},
NodeType.QUESTION_CLASSIFIER: {
LATEST_VERSION: QuestionClassifierNode,
"1": QuestionClassifierNode,
},
NodeType.HTTP_REQUEST: {
LATEST_VERSION: HttpRequestNode,
"1": HttpRequestNode,
},
NodeType.TOOL: {
LATEST_VERSION: ToolNode,
# This is an issue that caused problems before.
# Logically, we shouldn't use two different versions to point to the same class here,
# but in order to maintain compatibility with historical data, this approach has been retained.
"2": ToolNode,
"1": ToolNode,
},
NodeType.VARIABLE_AGGREGATOR: {
LATEST_VERSION: VariableAggregatorNode,
"1": VariableAggregatorNode,
},
NodeType.LEGACY_VARIABLE_AGGREGATOR: {
LATEST_VERSION: VariableAggregatorNode,
"1": VariableAggregatorNode,
}, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: {
LATEST_VERSION: IterationNode,
"1": IterationNode,
},
NodeType.ITERATION_START: {
LATEST_VERSION: IterationStartNode,
"1": IterationStartNode,
},
NodeType.LOOP: {
LATEST_VERSION: LoopNode,
"1": LoopNode,
},
NodeType.LOOP_START: {
LATEST_VERSION: LoopStartNode,
"1": LoopStartNode,
},
NodeType.LOOP_END: {
LATEST_VERSION: LoopEndNode,
"1": LoopEndNode,
},
NodeType.PARAMETER_EXTRACTOR: {
LATEST_VERSION: ParameterExtractorNode,
"1": ParameterExtractorNode,
},
NodeType.VARIABLE_ASSIGNER: {
LATEST_VERSION: VariableAssignerNodeV2,
"1": VariableAssignerNodeV1,
"2": VariableAssignerNodeV2,
},
NodeType.DOCUMENT_EXTRACTOR: {
LATEST_VERSION: DocumentExtractorNode,
"1": DocumentExtractorNode,
},
NodeType.LIST_OPERATOR: {
LATEST_VERSION: ListOperatorNode,
"1": ListOperatorNode,
},
NodeType.AGENT: {
LATEST_VERSION: AgentNode,
# This is an issue that caused problems before.
# Logically, we shouldn't use two different versions to point to the same class here,
# but in order to maintain compatibility with historical data, this approach has been retained.
"2": AgentNode,
"1": AgentNode,
},
NodeType.HUMAN_INPUT: {
LATEST_VERSION: HumanInputNode,
"1": HumanInputNode,
},
NodeType.DATASOURCE: {
LATEST_VERSION: DatasourceNode,
"1": DatasourceNode,
},
NodeType.KNOWLEDGE_INDEX: {
LATEST_VERSION: KnowledgeIndexNode,
"1": KnowledgeIndexNode,
},
NodeType.TRIGGER_WEBHOOK: {
LATEST_VERSION: TriggerWebhookNode,
"1": TriggerWebhookNode,
},
NodeType.TRIGGER_PLUGIN: {
LATEST_VERSION: TriggerEventNode,
"1": TriggerEventNode,
},
NodeType.TRIGGER_SCHEDULE: {
LATEST_VERSION: TriggerScheduleNode,
"1": TriggerScheduleNode,
},
}

View File

@ -12,6 +12,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.tools.workflow_as_tool.tool import WorkflowTool
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.enums import (
@ -429,7 +430,7 @@ class ToolNode(Node[ToolNodeData]):
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
}
if isinstance(usage.total_tokens, int) and usage.total_tokens > 0:
if usage.total_tokens > 0:
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
@ -448,17 +449,8 @@ class ToolNode(Node[ToolNodeData]):
@staticmethod
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
# Avoid importing WorkflowTool at module import time; rely on duck typing
# Some runtimes expose `latest_usage`; mocks may synthesize arbitrary attributes.
latest = getattr(tool_runtime, "latest_usage", None)
# Normalize into a concrete LLMUsage. MagicMock returns truthy attribute objects
# for any name, so we must type-check here.
if isinstance(latest, LLMUsage):
return latest
if isinstance(latest, dict):
# Allow dict payloads from external runtimes
return LLMUsage.model_validate(latest)
# Fallback to empty usage when attribute is missing or not a valid payload
if isinstance(tool_runtime, WorkflowTool):
return tool_runtime.latest_usage
return LLMUsage.empty_usage()
@classmethod

View File

@ -6,7 +6,6 @@ BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEAD
SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization")
AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN)
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")
def init_app(app: DifyApp):
@ -26,7 +25,6 @@ def init_app(app: DifyApp):
service_api_bp,
allow_headers=list(SERVICE_API_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=list(EXPOSED_HEADERS),
)
app.register_blueprint(service_api_bp)
@ -36,7 +34,7 @@ def init_app(app: DifyApp):
supports_credentials=True,
allow_headers=list(AUTHENTICATED_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=list(EXPOSED_HEADERS),
expose_headers=["X-Version", "X-Env"],
)
app.register_blueprint(web_bp)
@ -46,7 +44,7 @@ def init_app(app: DifyApp):
supports_credentials=True,
allow_headers=list(AUTHENTICATED_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=list(EXPOSED_HEADERS),
expose_headers=["X-Version", "X-Env"],
)
app.register_blueprint(console_app_bp)
@ -54,7 +52,6 @@ def init_app(app: DifyApp):
files_bp,
allow_headers=list(FILES_HEADERS),
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=list(EXPOSED_HEADERS),
)
app.register_blueprint(files_bp)
@ -66,6 +63,5 @@ def init_app(app: DifyApp):
trigger_bp,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH", "HEAD"],
expose_headers=list(EXPOSED_HEADERS),
)
app.register_blueprint(trigger_bp)

View File

@ -1,49 +0,0 @@
import logging
from dify_app import DifyApp
def is_enabled() -> bool:
return True
def init_app(app: DifyApp):
"""Resolve Pydantic forward refs that would otherwise cause circular imports.
Rebuilds models in core.app.entities.app_invoke_entities with the real TraceQueueManager type.
Safe to run multiple times.
"""
logger = logging.getLogger(__name__)
try:
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity,
AppGenerateEntity,
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
ConversationAppGenerateEntity,
EasyUIBasedAppGenerateEntity,
RagPipelineGenerateEntity,
WorkflowAppGenerateEntity,
)
from core.ops.ops_trace_manager import TraceQueueManager # heavy import, do it at startup only
ns = {"TraceQueueManager": TraceQueueManager}
for Model in (
AppGenerateEntity,
EasyUIBasedAppGenerateEntity,
ConversationAppGenerateEntity,
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
WorkflowAppGenerateEntity,
RagPipelineGenerateEntity,
):
try:
Model.model_rebuild(_types_namespace=ns)
except Exception as e:
logger.debug("model_rebuild skipped for %s: %s", Model.__name__, e)
except Exception as e:
# Don't block app startup; just log at debug level.
logger.debug("ext_forward_refs init skipped: %s", e)

View File

@ -7,7 +7,6 @@ from logging.handlers import RotatingFileHandler
import flask
from configs import dify_config
from core.helper.trace_id_helper import get_trace_id_from_otel_context
from dify_app import DifyApp
@ -77,9 +76,7 @@ class RequestIdFilter(logging.Filter):
# the logging format. Note that we're checking if we're in a request
# context, as we may want to log things before Flask is fully loaded.
def filter(self, record):
trace_id = get_trace_id_from_otel_context() or ""
record.req_id = get_request_id() if flask.has_request_context() else ""
record.trace_id = trace_id
return True
@ -87,8 +84,6 @@ class RequestIdFormatter(logging.Formatter):
def format(self, record):
if not hasattr(record, "req_id"):
record.req_id = ""
if not hasattr(record, "trace_id"):
record.trace_id = ""
return super().format(record)

View File

@ -1,14 +1,12 @@
import json
import logging
import time
import flask
import werkzeug.http
from flask import Flask, g
from flask import Flask
from flask.signals import request_finished, request_started
from configs import dify_config
from core.helper.trace_id_helper import get_trace_id_from_otel_context
logger = logging.getLogger(__name__)
@ -22,9 +20,6 @@ def _is_content_type_json(content_type: str) -> bool:
def _log_request_started(_sender, **_extra):
"""Log the start of a request."""
# Record start time for access logging
g.__request_started_ts = time.perf_counter()
if not logger.isEnabledFor(logging.DEBUG):
return
@ -47,39 +42,8 @@ def _log_request_started(_sender, **_extra):
def _log_request_finished(_sender, response, **_extra):
"""Log the end of a request.
Safe to call with or without an active Flask request context.
"""
if response is None:
return
# Always emit a compact access line at INFO with trace_id so it can be grepped
has_ctx = flask.has_request_context()
start_ts = getattr(g, "__request_started_ts", None) if has_ctx else None
duration_ms = None
if start_ts is not None:
duration_ms = round((time.perf_counter() - start_ts) * 1000, 3)
# Request attributes are available only when a request context exists
if has_ctx:
req_method = flask.request.method
req_path = flask.request.path
else:
req_method = "-"
req_path = "-"
trace_id = get_trace_id_from_otel_context() or response.headers.get("X-Trace-Id") or ""
logger.info(
"%s %s %s %s %s",
req_method,
req_path,
getattr(response, "status_code", "-"),
duration_ms if duration_ms is not None else "-",
trace_id,
)
if not logger.isEnabledFor(logging.DEBUG):
"""Log the end of a request."""
if not logger.isEnabledFor(logging.DEBUG) or response is None:
return
if not _is_content_type_json(response.content_type):

View File

@ -19,7 +19,7 @@ class StringUUID(TypeDecorator[uuid.UUID | str | None]):
def process_bind_param(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None:
if value is None:
return value
elif dialect.name in ["postgresql", "mysql"]:
elif dialect.name == "postgresql":
return str(value)
else:
if isinstance(value, uuid.UUID):

View File

@ -111,7 +111,7 @@ package = false
dev = [
"coverage~=7.2.4",
"dotenv-linter~=0.5.0",
"faker~=38.2.0",
"faker~=32.1.0",
"lxml-stubs~=0.5.1",
"ty~=0.0.1a19",
"basedpyright~=1.31.0",

View File

@ -10,7 +10,6 @@ from collections.abc import Sequence
from typing import Any, Literal
import sqlalchemy as sa
from redis.exceptions import LockNotOwnedError
from sqlalchemy import exists, func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@ -1594,176 +1593,173 @@ class DocumentService:
db.session.add(dataset_process_rule)
db.session.flush()
lock_name = f"add_document_lock_dataset_id_{dataset.id}"
try:
with redis_client.lock(lock_name, timeout=600):
assert dataset_process_rule
position = DocumentService.get_documents_position(dataset.id)
document_ids = []
duplicate_document_ids = []
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
if not knowledge_config.data_source.info_list.file_info_list:
raise ValueError("File source info is required")
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
with redis_client.lock(lock_name, timeout=600):
assert dataset_process_rule
position = DocumentService.get_documents_position(dataset.id)
document_ids = []
duplicate_document_ids = []
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
if not knowledge_config.data_source.info_list.file_info_list:
raise ValueError("File source info is required")
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
)
# raise error if file not found
if not file:
raise FileNotExistsError()
file_name = file.name
data_source_info: dict[str, str | bool] = {
"upload_file_id": file_id,
}
# check duplicate
if knowledge_config.duplicate:
document = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="upload_file",
enabled=True,
name=file_name,
)
.first()
)
# raise error if file not found
if not file:
raise FileNotExistsError()
file_name = file.name
data_source_info: dict[str, str | bool] = {
"upload_file_id": file_id,
}
# check duplicate
if knowledge_config.duplicate:
document = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="upload_file",
enabled=True,
name=file_name,
)
.first()
)
if document:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
file_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
if not notion_info_list:
raise ValueError("No notion info list found.")
exist_page_ids = []
exist_document = {}
documents = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
)
.all()
if document:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
file_name,
batch,
)
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
exist_document[data_source_info["notion_page_id"]] = document.id
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
for page in notion_info.pages:
if page.page_id not in exist_page_ids:
data_source_info = {
"credential_id": notion_info.credential_id,
"notion_workspace_id": workspace_id,
"notion_page_id": page.page_id,
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore
"type": page.type,
}
# Truncate page name to 255 characters to prevent DB field length errors
truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
truncated_page_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
else:
exist_document.pop(page.page_id)
# delete not selected documents
if len(exist_document) > 0:
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
website_info = knowledge_config.data_source.info_list.website_info_list
if not website_info:
raise ValueError("No website info list found.")
urls = website_info.urls
for url in urls:
data_source_info = {
"url": url,
"provider": website_info.provider,
"job_id": website_info.job_id,
"only_main_content": website_info.only_main_content,
"mode": "crawl",
}
if len(url) > 255:
document_name = url[:200] + "..."
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
if not notion_info_list:
raise ValueError("No notion info list found.")
exist_page_ids = []
exist_document = {}
documents = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
)
.all()
)
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
exist_document[data_source_info["notion_page_id"]] = document.id
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
for page in notion_info.pages:
if page.page_id not in exist_page_ids:
data_source_info = {
"credential_id": notion_info.credential_id,
"notion_workspace_id": workspace_id,
"notion_page_id": page.page_id,
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore
"type": page.type,
}
# Truncate page name to 255 characters to prevent DB field length errors
truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
truncated_page_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
else:
document_name = url
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
document_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
db.session.commit()
exist_document.pop(page.page_id)
# delete not selected documents
if len(exist_document) > 0:
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
website_info = knowledge_config.data_source.info_list.website_info_list
if not website_info:
raise ValueError("No website info list found.")
urls = website_info.urls
for url in urls:
data_source_info = {
"url": url,
"provider": website_info.provider,
"job_id": website_info.job_id,
"only_main_content": website_info.only_main_content,
"mode": "crawl",
}
if len(url) > 255:
document_name = url[:200] + "..."
else:
document_name = url
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
document_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
db.session.commit()
# trigger async task
if document_ids:
DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay()
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
except LockNotOwnedError:
pass
# trigger async task
if document_ids:
DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay()
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
return documents, batch
@ -2703,55 +2699,50 @@ class SegmentService:
# calc embedding use tokens
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
lock_name = f"add_segment_lock_document_id_{document.id}"
try:
with redis_client.lock(lock_name, timeout=600):
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document.id)
.scalar()
)
segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
status="completed",
indexing_at=naive_utc_now(),
completed_at=naive_utc_now(),
created_by=current_user.id,
)
if document.doc_form == "qa_model":
segment_document.word_count += len(args["answer"])
segment_document.answer = args["answer"]
with redis_client.lock(lock_name, timeout=600):
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document.id)
.scalar()
)
segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
status="completed",
indexing_at=naive_utc_now(),
completed_at=naive_utc_now(),
created_by=current_user.id,
)
if document.doc_form == "qa_model":
segment_document.word_count += len(args["answer"])
segment_document.answer = args["answer"]
db.session.add(segment_document)
# update document word count
assert document.word_count is not None
document.word_count += segment_document.word_count
db.session.add(document)
db.session.add(segment_document)
# update document word count
assert document.word_count is not None
document.word_count += segment_document.word_count
db.session.add(document)
db.session.commit()
# save vector index
try:
VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset, document.doc_form)
except Exception as e:
logger.exception("create segment index failed")
segment_document.enabled = False
segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
segment_document.error = str(e)
db.session.commit()
# save vector index
try:
VectorService.create_segments_vector(
[args["keywords"]], [segment_document], dataset, document.doc_form
)
except Exception as e:
logger.exception("create segment index failed")
segment_document.enabled = False
segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
segment_document.error = str(e)
db.session.commit()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
return segment
except LockNotOwnedError:
pass
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
return segment
@classmethod
def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
@ -2760,89 +2751,84 @@ class SegmentService:
lock_name = f"multi_add_segment_lock_document_id_{document.id}"
increment_word_count = 0
try:
with redis_client.lock(lock_name, timeout=600):
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document.id)
.scalar()
with redis_client.lock(lock_name, timeout=600):
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
pre_segment_data_list = []
segment_data_list = []
keywords_list = []
position = max_position + 1 if max_position else 1
for segment_item in segments:
content = segment_item["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == "high_quality" and embedding_model:
# calc embedding use tokens
if document.doc_form == "qa_model":
tokens = embedding_model.get_text_embedding_num_tokens(
texts=[content + segment_item["answer"]]
)[0]
else:
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=position,
content=content,
word_count=len(content),
tokens=tokens,
keywords=segment_item.get("keywords", []),
status="completed",
indexing_at=naive_utc_now(),
completed_at=naive_utc_now(),
created_by=current_user.id,
)
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document.id)
.scalar()
)
pre_segment_data_list = []
segment_data_list = []
keywords_list = []
position = max_position + 1 if max_position else 1
for segment_item in segments:
content = segment_item["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == "high_quality" and embedding_model:
# calc embedding use tokens
if document.doc_form == "qa_model":
segment_document.answer = segment_item["answer"]
segment_document.word_count += len(segment_item["answer"])
increment_word_count += segment_document.word_count
db.session.add(segment_document)
segment_data_list.append(segment_document)
position += 1
pre_segment_data_list.append(segment_document)
if "keywords" in segment_item:
keywords_list.append(segment_item["keywords"])
tokens = embedding_model.get_text_embedding_num_tokens(
texts=[content + segment_item["answer"]]
)[0]
else:
keywords_list.append(None)
# update document word count
assert document.word_count is not None
document.word_count += increment_word_count
db.session.add(document)
try:
# save vector index
VectorService.create_segments_vector(
keywords_list, pre_segment_data_list, dataset, document.doc_form
)
except Exception as e:
logger.exception("create segment index failed")
for segment_document in segment_data_list:
segment_document.enabled = False
segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
segment_document.error = str(e)
db.session.commit()
return segment_data_list
except LockNotOwnedError:
pass
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=position,
content=content,
word_count=len(content),
tokens=tokens,
keywords=segment_item.get("keywords", []),
status="completed",
indexing_at=naive_utc_now(),
completed_at=naive_utc_now(),
created_by=current_user.id,
)
if document.doc_form == "qa_model":
segment_document.answer = segment_item["answer"]
segment_document.word_count += len(segment_item["answer"])
increment_word_count += segment_document.word_count
db.session.add(segment_document)
segment_data_list.append(segment_document)
position += 1
pre_segment_data_list.append(segment_document)
if "keywords" in segment_item:
keywords_list.append(segment_item["keywords"])
else:
keywords_list.append(None)
# update document word count
assert document.word_count is not None
document.word_count += increment_word_count
db.session.add(document)
try:
# save vector index
VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset, document.doc_form)
except Exception as e:
logger.exception("create segment index failed")
for segment_document in segment_data_list:
segment_document.enabled = False
segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
segment_document.error = str(e)
db.session.commit()
return segment_data_list
@classmethod
def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset):

View File

@ -69,7 +69,6 @@ class ProviderResponse(BaseModel):
label: I18nObject
description: I18nObject | None = None
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
icon_large: I18nObject | None = None
background: str | None = None
help: ProviderHelpEntity | None = None
@ -93,11 +92,6 @@ class ProviderResponse(BaseModel):
self.icon_small = I18nObject(
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
)
if self.icon_small_dark is not None:
self.icon_small_dark = I18nObject(
en_US=f"{url_prefix}/icon_small_dark/en_US",
zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans",
)
if self.icon_large is not None:
self.icon_large = I18nObject(
@ -115,7 +109,6 @@ class ProviderWithModelsResponse(BaseModel):
provider: str
label: I18nObject
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
icon_large: I18nObject | None = None
status: CustomConfigurationStatus
models: list[ProviderModelWithStatusEntity]
@ -130,11 +123,6 @@ class ProviderWithModelsResponse(BaseModel):
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
)
if self.icon_small_dark is not None:
self.icon_small_dark = I18nObject(
en_US=f"{url_prefix}/icon_small_dark/en_US", zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans"
)
if self.icon_large is not None:
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
@ -159,11 +147,6 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
)
if self.icon_small_dark is not None:
self.icon_small_dark = I18nObject(
en_US=f"{url_prefix}/icon_small_dark/en_US", zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans"
)
if self.icon_large is not None:
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"

View File

@ -79,7 +79,6 @@ class ModelProviderService:
label=provider_configuration.provider.label,
description=provider_configuration.provider.description,
icon_small=provider_configuration.provider.icon_small,
icon_small_dark=provider_configuration.provider.icon_small_dark,
icon_large=provider_configuration.provider.icon_large,
background=provider_configuration.provider.background,
help=provider_configuration.provider.help,
@ -403,7 +402,6 @@ class ModelProviderService:
provider=provider,
label=first_model.provider.label,
icon_small=first_model.provider.icon_small,
icon_small_dark=first_model.provider.icon_small_dark,
icon_large=first_model.provider.icon_large,
status=CustomConfigurationStatus.ACTIVE,
models=[

View File

@ -201,9 +201,7 @@ class ToolTransformService:
@staticmethod
def workflow_provider_to_user_provider(
provider_controller: WorkflowToolProviderController,
labels: list[str] | None = None,
workflow_app_id: str | None = None,
provider_controller: WorkflowToolProviderController, labels: list[str] | None = None
):
"""
convert provider controller to user provider
@ -223,7 +221,6 @@ class ToolTransformService:
plugin_unique_identifier=None,
tools=[],
labels=labels or [],
workflow_app_id=workflow_app_id,
)
@staticmethod

View File

@ -189,9 +189,6 @@ class WorkflowToolManageService:
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
).all()
# Create a mapping from provider_id to app_id
provider_id_to_app_id = {provider.id: provider.app_id for provider in db_tools}
tools: list[WorkflowToolProviderController] = []
for provider in db_tools:
try:
@ -205,11 +202,8 @@ class WorkflowToolManageService:
result = []
for tool in tools:
workflow_app_id = provider_id_to_app_id.get(tool.provider_id)
user_tool_provider = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=tool,
labels=labels.get(tool.provider_id, []),
workflow_app_id=workflow_app_id,
provider_controller=tool, labels=labels.get(tool.provider_id, [])
)
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_tool_provider)
user_tool_provider.tools = [

View File

@ -233,7 +233,7 @@ workflow:
- value_selector:
- iteration_node
- output
value_type: array[number]
value_type: array[array[number]]
variable: output
selected: false
title: End

View File

@ -227,7 +227,6 @@ class TestModelProviderService:
mock_provider_entity.label = {"en_US": "OpenAI", "zh_Hans": "OpenAI"}
mock_provider_entity.description = {"en_US": "OpenAI provider", "zh_Hans": "OpenAI 提供商"}
mock_provider_entity.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"}
mock_provider_entity.icon_small_dark = None
mock_provider_entity.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"}
mock_provider_entity.background = "#FF6B6B"
mock_provider_entity.help = None
@ -301,7 +300,6 @@ class TestModelProviderService:
mock_provider_entity_llm.label = {"en_US": "OpenAI", "zh_Hans": "OpenAI"}
mock_provider_entity_llm.description = {"en_US": "OpenAI provider", "zh_Hans": "OpenAI 提供商"}
mock_provider_entity_llm.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"}
mock_provider_entity_llm.icon_small_dark = None
mock_provider_entity_llm.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"}
mock_provider_entity_llm.background = "#FF6B6B"
mock_provider_entity_llm.help = None
@ -315,7 +313,6 @@ class TestModelProviderService:
mock_provider_entity_embedding.label = {"en_US": "Cohere", "zh_Hans": "Cohere"}
mock_provider_entity_embedding.description = {"en_US": "Cohere provider", "zh_Hans": "Cohere 提供商"}
mock_provider_entity_embedding.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"}
mock_provider_entity_embedding.icon_small_dark = None
mock_provider_entity_embedding.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"}
mock_provider_entity_embedding.background = "#4ECDC4"
mock_provider_entity_embedding.help = None
@ -1026,7 +1023,6 @@ class TestModelProviderService:
provider="openai",
label={"en_US": "OpenAI", "zh_Hans": "OpenAI"},
icon_small={"en_US": "icon_small.png", "zh_Hans": "icon_small.png"},
icon_small_dark=None,
icon_large={"en_US": "icon_large.png", "zh_Hans": "icon_large.png"},
),
model="gpt-3.5-turbo",
@ -1044,7 +1040,6 @@ class TestModelProviderService:
provider="openai",
label={"en_US": "OpenAI", "zh_Hans": "OpenAI"},
icon_small={"en_US": "icon_small.png", "zh_Hans": "icon_small.png"},
icon_small_dark=None,
icon_large={"en_US": "icon_large.png", "zh_Hans": "icon_large.png"},
),
model="gpt-4",

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,100 +0,0 @@
"""
Unit tests for ToolProviderApiEntity workflow_app_id field.
This test suite covers:
- ToolProviderApiEntity workflow_app_id field creation and default value
- ToolProviderApiEntity.to_dict() method behavior with workflow_app_id
"""
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
class TestToolProviderApiEntityWorkflowAppId:
"""Test suite for ToolProviderApiEntity workflow_app_id field."""
def test_workflow_app_id_field_default_none(self):
"""Test that workflow_app_id defaults to None when not provided."""
entity = ToolProviderApiEntity(
id="test_id",
author="test_author",
name="test_name",
description=I18nObject(en_US="Test description"),
icon="test_icon",
label=I18nObject(en_US="Test label"),
type=ToolProviderType.WORKFLOW,
)
assert entity.workflow_app_id is None
def test_to_dict_includes_workflow_app_id_when_workflow_type_and_has_value(self):
"""Test that to_dict() includes workflow_app_id when type is WORKFLOW and value is set."""
workflow_app_id = "app_123"
entity = ToolProviderApiEntity(
id="test_id",
author="test_author",
name="test_name",
description=I18nObject(en_US="Test description"),
icon="test_icon",
label=I18nObject(en_US="Test label"),
type=ToolProviderType.WORKFLOW,
workflow_app_id=workflow_app_id,
)
result = entity.to_dict()
assert "workflow_app_id" in result
assert result["workflow_app_id"] == workflow_app_id
def test_to_dict_excludes_workflow_app_id_when_workflow_type_and_none(self):
"""Test that to_dict() excludes workflow_app_id when type is WORKFLOW but value is None."""
entity = ToolProviderApiEntity(
id="test_id",
author="test_author",
name="test_name",
description=I18nObject(en_US="Test description"),
icon="test_icon",
label=I18nObject(en_US="Test label"),
type=ToolProviderType.WORKFLOW,
workflow_app_id=None,
)
result = entity.to_dict()
assert "workflow_app_id" not in result
def test_to_dict_excludes_workflow_app_id_when_not_workflow_type(self):
"""Test that to_dict() excludes workflow_app_id when type is not WORKFLOW."""
workflow_app_id = "app_123"
entity = ToolProviderApiEntity(
id="test_id",
author="test_author",
name="test_name",
description=I18nObject(en_US="Test description"),
icon="test_icon",
label=I18nObject(en_US="Test label"),
type=ToolProviderType.BUILT_IN,
workflow_app_id=workflow_app_id,
)
result = entity.to_dict()
assert "workflow_app_id" not in result
def test_to_dict_includes_workflow_app_id_for_workflow_type_with_empty_string(self):
"""Test that to_dict() excludes workflow_app_id when value is empty string (falsy)."""
entity = ToolProviderApiEntity(
id="test_id",
author="test_author",
name="test_name",
description=I18nObject(en_US="Test description"),
icon="test_icon",
label=I18nObject(en_US="Test label"),
type=ToolProviderType.WORKFLOW,
workflow_app_id="",
)
result = entity.to_dict()
assert "workflow_app_id" not in result

View File

@ -1,5 +1,3 @@
from types import SimpleNamespace
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
@ -216,76 +214,3 @@ def test_create_variable_message():
assert message.message.variable_name == var_name
assert message.message.variable_value == var_value
assert message.message.stream is False
def test_resolve_user_from_database_falls_back_to_end_user(monkeypatch: pytest.MonkeyPatch):
"""Ensure worker context can resolve EndUser when Account is missing."""
class StubSession:
def __init__(self, results: list):
self.results = results
def scalar(self, _stmt):
return self.results.pop(0)
tenant = SimpleNamespace(id="tenant_id")
end_user = SimpleNamespace(id="end_user_id", tenant_id="tenant_id")
db_stub = SimpleNamespace(session=StubSession([tenant, None, end_user]))
monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub)
entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
parameters=[],
description=None,
has_runtime_parameters=False,
)
runtime = ToolRuntime(tenant_id="tenant_id", invoke_from=InvokeFrom.SERVICE_API)
tool = WorkflowTool(
workflow_app_id="",
workflow_as_tool_id="",
version="1",
workflow_entities={},
workflow_call_depth=1,
entity=entity,
runtime=runtime,
)
resolved_user = tool._resolve_user_from_database(user_id=end_user.id)
assert resolved_user is end_user
def test_resolve_user_from_database_returns_none_when_no_tenant(monkeypatch: pytest.MonkeyPatch):
"""Return None if tenant cannot be found in worker context."""
class StubSession:
def __init__(self, results: list):
self.results = results
def scalar(self, _stmt):
return self.results.pop(0)
db_stub = SimpleNamespace(session=StubSession([None]))
monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub)
entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
parameters=[],
description=None,
has_runtime_parameters=False,
)
runtime = ToolRuntime(tenant_id="missing_tenant", invoke_from=InvokeFrom.SERVICE_API)
tool = WorkflowTool(
workflow_app_id="",
workflow_as_tool_id="",
version="1",
workflow_entities={},
workflow_call_depth=1,
entity=entity,
runtime=runtime,
)
resolved_user = tool._resolve_user_from_database(user_id="any")
assert resolved_user is None

View File

@ -29,7 +29,7 @@ class _TestNode(Node[_TestNodeData]):
@classmethod
def version(cls) -> str:
return "1"
return "test"
def __init__(
self,

View File

@ -744,7 +744,7 @@ def test_graph_run_emits_partial_success_when_node_failure_recovered():
)
llm_node = graph.nodes["llm"]
base_node_data = llm_node.node_data
base_node_data = llm_node.get_base_node_data()
base_node_data.error_strategy = ErrorStrategy.DEFAULT_VALUE
base_node_data.default_value = [DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING)]

View File

@ -7,31 +7,9 @@ This module tests the iteration node's ability to:
"""
from .test_database_utils import skip_if_database_unavailable
from .test_mock_config import MockConfigBuilder, NodeMockConfig
from .test_table_runner import TableTestRunner, WorkflowTestCase
def _create_iteration_mock_config():
"""Helper to create a mock config for iteration tests."""
def code_inner_handler(node):
pool = node.graph_runtime_state.variable_pool
item_seg = pool.get(["iteration_node", "item"])
if item_seg is not None:
item = item_seg.to_object()
return {"result": [item, item * 2]}
# This fallback is likely unreachable, but if it is,
# it doesn't simulate iteration with different values as the comment suggests.
return {"result": [1, 2]}
return (
MockConfigBuilder()
.with_node_output("code_node", {"result": [1, 2, 3]})
.with_node_config(NodeMockConfig(node_id="code_inner_node", custom_handler=code_inner_handler))
.build()
)
@skip_if_database_unavailable()
def test_iteration_with_flatten_output_enabled():
"""
@ -49,8 +27,7 @@ def test_iteration_with_flatten_output_enabled():
inputs={},
expected_outputs={"output": [1, 2, 2, 4, 3, 6]},
description="Iteration with flatten_output=True flattens nested arrays",
use_auto_mock=True, # Use auto-mock to avoid sandbox service
mock_config=_create_iteration_mock_config(),
use_auto_mock=False, # Run code nodes directly
)
result = runner.run_test_case(test_case)
@ -79,8 +56,7 @@ def test_iteration_with_flatten_output_disabled():
inputs={},
expected_outputs={"output": [[1, 2], [2, 4], [3, 6]]},
description="Iteration with flatten_output=False preserves nested structure",
use_auto_mock=True, # Use auto-mock to avoid sandbox service
mock_config=_create_iteration_mock_config(),
use_auto_mock=False, # Run code nodes directly
)
result = runner.run_test_case(test_case)
@ -105,16 +81,14 @@ def test_iteration_flatten_output_comparison():
inputs={},
expected_outputs={"output": [1, 2, 2, 4, 3, 6]},
description="flatten_output=True: Flattened output",
use_auto_mock=True, # Use auto-mock to avoid sandbox service
mock_config=_create_iteration_mock_config(),
use_auto_mock=False, # Run code nodes directly
),
WorkflowTestCase(
fixture_path="iteration_flatten_output_disabled_workflow",
inputs={},
expected_outputs={"output": [[1, 2], [2, 4], [3, 6]]},
description="flatten_output=False: Nested output",
use_auto_mock=True, # Use auto-mock to avoid sandbox service
mock_config=_create_iteration_mock_config(),
use_auto_mock=False, # Run code nodes directly
),
]

View File

@ -92,7 +92,7 @@ class MockLLMNode(MockNodeMixin, LLMNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "1"
return "mock-1"
def _run(self) -> Generator:
"""Execute mock LLM node."""
@ -189,7 +189,7 @@ class MockAgentNode(MockNodeMixin, AgentNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "1"
return "mock-1"
def _run(self) -> Generator:
"""Execute mock agent node."""
@ -241,7 +241,7 @@ class MockToolNode(MockNodeMixin, ToolNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "1"
return "mock-1"
def _run(self) -> Generator:
"""Execute mock tool node."""
@ -294,7 +294,7 @@ class MockKnowledgeRetrievalNode(MockNodeMixin, KnowledgeRetrievalNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "1"
return "mock-1"
def _run(self) -> Generator:
"""Execute mock knowledge retrieval node."""
@ -351,7 +351,7 @@ class MockHttpRequestNode(MockNodeMixin, HttpRequestNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "1"
return "mock-1"
def _run(self) -> Generator:
"""Execute mock HTTP request node."""
@ -404,7 +404,7 @@ class MockQuestionClassifierNode(MockNodeMixin, QuestionClassifierNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "1"
return "mock-1"
def _run(self) -> Generator:
"""Execute mock question classifier node."""
@ -452,7 +452,7 @@ class MockParameterExtractorNode(MockNodeMixin, ParameterExtractorNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "1"
return "mock-1"
def _run(self) -> Generator:
"""Execute mock parameter extractor node."""
@ -502,7 +502,7 @@ class MockDocumentExtractorNode(MockNodeMixin, DocumentExtractorNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "1"
return "mock-1"
def _run(self) -> Generator:
"""Execute mock document extractor node."""
@ -557,7 +557,7 @@ class MockIterationNode(MockNodeMixin, IterationNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "1"
return "mock-1"
def _create_graph_engine(self, index: int, item: Any):
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
@ -632,7 +632,7 @@ class MockLoopNode(MockNodeMixin, LoopNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "1"
return "mock-1"
def _create_graph_engine(self, start_at, root_node_id: str):
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
@ -694,7 +694,7 @@ class MockTemplateTransformNode(MockNodeMixin, TemplateTransformNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "1"
return "mock-1"
def _run(self) -> NodeRunResult:
"""Execute mock template transform node."""
@ -780,7 +780,7 @@ class MockCodeNode(MockNodeMixin, CodeNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "1"
return "mock-1"
def _run(self) -> NodeRunResult:
"""Execute mock code node."""

View File

@ -33,10 +33,6 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined
type_version_set: set[tuple[NodeType, str]] = set()
for cls in classes:
# Only validate production node classes; skip test-defined subclasses and external helpers
module_name = getattr(cls, "__module__", "")
if not module_name.startswith("core."):
continue
# Validate that 'version' is directly defined in the class (not inherited) by checking the class's __dict__
assert "version" in cls.__dict__, f"class {cls} should have version method defined (NOT INHERITED.)"
node_type = cls.node_type

View File

@ -1,84 +0,0 @@
import types
from collections.abc import Mapping
from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.base.node import Node
# Import concrete nodes we will assert on (numeric version path)
from core.workflow.nodes.variable_assigner.v1.node import (
VariableAssignerNode as VariableAssignerV1,
)
from core.workflow.nodes.variable_assigner.v2.node import (
VariableAssignerNode as VariableAssignerV2,
)
def test_variable_assigner_latest_prefers_highest_numeric_version():
# Act
mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
# Assert basic presence
assert NodeType.VARIABLE_ASSIGNER in mapping
va_versions = mapping[NodeType.VARIABLE_ASSIGNER]
# Both concrete versions must be present
assert va_versions.get("1") is VariableAssignerV1
assert va_versions.get("2") is VariableAssignerV2
# And latest should point to numerically-highest version ("2")
assert va_versions.get("latest") is VariableAssignerV2
def test_latest_prefers_highest_numeric_version():
# Arrange: define two ephemeral subclasses with numeric versions under a NodeType
# that has no concrete implementations in production to avoid interference.
class _Version1(Node[BaseNodeData]): # type: ignore[misc]
node_type = NodeType.LEGACY_VARIABLE_AGGREGATOR
def init_node_data(self, data):
pass
def _run(self):
raise NotImplementedError
@classmethod
def version(cls) -> str:
return "1"
def _get_error_strategy(self):
return None
def _get_retry_config(self):
return types.SimpleNamespace() # not used
def _get_title(self) -> str:
return "version1"
def _get_description(self):
return None
def _get_default_value_dict(self):
return {}
def get_base_node_data(self):
return types.SimpleNamespace(title="version1")
class _Version2(_Version1): # type: ignore[misc]
@classmethod
def version(cls) -> str:
return "2"
def _get_title(self) -> str:
return "version2"
# Act: build a fresh mapping (it should now see our ephemeral subclasses)
mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
# Assert: both numeric versions exist for this NodeType; 'latest' points to the higher numeric version
assert NodeType.LEGACY_VARIABLE_AGGREGATOR in mapping
legacy_versions = mapping[NodeType.LEGACY_VARIABLE_AGGREGATOR]
assert legacy_versions.get("1") is _Version1
assert legacy_versions.get("2") is _Version2
assert legacy_versions.get("latest") is _Version2

View File

@ -471,8 +471,8 @@ class TestCodeNodeInitialization:
assert node._get_description() is None
def test_node_data_property(self):
"""Test node_data property returns node data."""
def test_get_base_node_data(self):
"""Test get_base_node_data returns node data."""
node = CodeNode.__new__(CodeNode)
node._node_data = CodeNodeData(
title="Base Test",
@ -482,7 +482,7 @@ class TestCodeNodeInitialization:
outputs={},
)
result = node.node_data
result = node.get_base_node_data()
assert result == node._node_data
assert result.title == "Base Test"

View File

@ -240,8 +240,8 @@ class TestIterationNodeInitialization:
assert node._get_description() == "This is a description"
def test_node_data_property(self):
"""Test node_data property returns node data."""
def test_get_base_node_data(self):
"""Test get_base_node_data returns node data."""
node = IterationNode.__new__(IterationNode)
node._node_data = IterationNodeData(
title="Base Test",
@ -249,7 +249,7 @@ class TestIterationNodeInitialization:
output_selector=["y"],
)
result = node.node_data
result = node.get_base_node_data()
assert result == node._node_data

View File

@ -19,7 +19,7 @@ class _SampleNode(Node[_SampleNodeData]):
@classmethod
def version(cls) -> str:
return "1"
return "sample-test"
def _run(self):
raise NotImplementedError

View File

@ -263,62 +263,3 @@ class TestResponseUnmodified:
)
assert response.text == _RESPONSE_NEEDLE
assert response.status_code == 200
class TestRequestFinishedInfoAccessLine:
def test_info_access_log_includes_method_path_status_duration_trace_id(self, monkeypatch, caplog):
"""Ensure INFO access line contains expected fields with computed duration and trace id."""
app = _get_test_app()
# Push a real request context so flask.request and g are available
with app.test_request_context("/foo", method="GET"):
# Seed start timestamp via the extension's own start hook and control perf_counter deterministically
seq = iter([100.0, 100.123456])
monkeypatch.setattr(ext_request_logging.time, "perf_counter", lambda: next(seq))
# Provide a deterministic trace id
monkeypatch.setattr(
ext_request_logging,
"get_trace_id_from_otel_context",
lambda: "trace-xyz",
)
# Simulate request_started to record start timestamp on g
ext_request_logging._log_request_started(app)
# Capture logs from the real logger at INFO level only (skip DEBUG branch)
caplog.set_level(logging.INFO, logger=ext_request_logging.__name__)
response = Response(json.dumps({"ok": True}), mimetype="application/json", status=200)
_log_request_finished(app, response)
# Verify a single INFO record with the five fields in order
info_records = [rec for rec in caplog.records if rec.levelno == logging.INFO]
assert len(info_records) == 1
msg = info_records[0].getMessage()
# Expected format: METHOD PATH STATUS DURATION_MS TRACE_ID
assert "GET" in msg
assert "/foo" in msg
assert "200" in msg
assert "123.456" in msg # rounded to 3 decimals
assert "trace-xyz" in msg
def test_info_access_log_uses_dash_without_start_timestamp(self, monkeypatch, caplog):
app = _get_test_app()
with app.test_request_context("/bar", method="POST"):
# No g.__request_started_ts set -> duration should be '-'
monkeypatch.setattr(
ext_request_logging,
"get_trace_id_from_otel_context",
lambda: "tid-no-start",
)
caplog.set_level(logging.INFO, logger=ext_request_logging.__name__)
response = Response("OK", mimetype="text/plain", status=204)
_log_request_finished(app, response)
info_records = [rec for rec in caplog.records if rec.levelno == logging.INFO]
assert len(info_records) == 1
msg = info_records[0].getMessage()
assert "POST" in msg
assert "/bar" in msg
assert "204" in msg
# Duration placeholder
# The fields are space separated; ensure a standalone '-' appears
assert " - " in msg or msg.endswith(" -")
assert "tid-no-start" in msg

View File

@ -1,718 +0,0 @@
"""
Comprehensive unit tests for AudioService.
This test suite provides complete coverage of audio processing operations in Dify,
following TDD principles with the Arrange-Act-Assert pattern.
## Test Coverage
### 1. Speech-to-Text (ASR) Operations (TestAudioServiceASR)
Tests audio transcription functionality:
- Successful transcription for different app modes
- File validation (size, type, presence)
- Feature flag validation (speech-to-text enabled)
- Error handling for various failure scenarios
- Model instance availability checks
### 2. Text-to-Speech (TTS) Operations (TestAudioServiceTTS)
Tests text-to-audio conversion:
- TTS with text input
- TTS with message ID
- Voice selection (explicit and default)
- Feature flag validation (text-to-speech enabled)
- Draft workflow handling
- Streaming response handling
- Error handling for missing/invalid inputs
### 3. TTS Voice Listing (TestAudioServiceTTSVoices)
Tests available voice retrieval:
- Get available voices for a tenant
- Language filtering
- Error handling for missing provider
## Testing Approach
- **Mocking Strategy**: All external dependencies (ModelManager, db, FileStorage) are mocked
for fast, isolated unit tests
- **Factory Pattern**: AudioServiceTestDataFactory provides consistent test data
- **Fixtures**: Mock objects are configured per test method
- **Assertions**: Each test verifies return values, side effects, and error conditions
## Key Concepts
**Audio Formats:**
- Supported: mp3, wav, m4a, flac, ogg, opus, webm
- File size limit: 30 MB
**App Modes:**
- ADVANCED_CHAT/WORKFLOW: Use workflow features
- CHAT/COMPLETION: Use app_model_config
**Feature Flags:**
- speech_to_text: Enables ASR functionality
- text_to_speech: Enables TTS functionality
"""
from unittest.mock import MagicMock, Mock, create_autospec, patch
import pytest
from werkzeug.datastructures import FileStorage
from models.enums import MessageStatus
from models.model import App, AppMode, AppModelConfig, Message
from models.workflow import Workflow
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
ProviderNotSupportSpeechToTextServiceError,
ProviderNotSupportTextToSpeechServiceError,
UnsupportedAudioTypeServiceError,
)
class AudioServiceTestDataFactory:
"""
Factory for creating test data and mock objects.
Provides reusable methods to create consistent mock objects for testing
audio-related operations.
"""
@staticmethod
def create_app_mock(
app_id: str = "app-123",
mode: AppMode = AppMode.CHAT,
tenant_id: str = "tenant-123",
**kwargs,
) -> Mock:
"""
Create a mock App object.
Args:
app_id: Unique identifier for the app
mode: App mode (CHAT, ADVANCED_CHAT, WORKFLOW, etc.)
tenant_id: Tenant identifier
**kwargs: Additional attributes to set on the mock
Returns:
Mock App object with specified attributes
"""
app = create_autospec(App, instance=True)
app.id = app_id
app.mode = mode
app.tenant_id = tenant_id
app.workflow = kwargs.get("workflow")
app.app_model_config = kwargs.get("app_model_config")
for key, value in kwargs.items():
setattr(app, key, value)
return app
@staticmethod
def create_workflow_mock(features_dict: dict | None = None, **kwargs) -> Mock:
"""
Create a mock Workflow object.
Args:
features_dict: Dictionary of workflow features
**kwargs: Additional attributes to set on the mock
Returns:
Mock Workflow object with specified attributes
"""
workflow = create_autospec(Workflow, instance=True)
workflow.features_dict = features_dict or {}
for key, value in kwargs.items():
setattr(workflow, key, value)
return workflow
@staticmethod
def create_app_model_config_mock(
speech_to_text_dict: dict | None = None,
text_to_speech_dict: dict | None = None,
**kwargs,
) -> Mock:
"""
Create a mock AppModelConfig object.
Args:
speech_to_text_dict: Speech-to-text configuration
text_to_speech_dict: Text-to-speech configuration
**kwargs: Additional attributes to set on the mock
Returns:
Mock AppModelConfig object with specified attributes
"""
config = create_autospec(AppModelConfig, instance=True)
config.speech_to_text_dict = speech_to_text_dict or {"enabled": False}
config.text_to_speech_dict = text_to_speech_dict or {"enabled": False}
for key, value in kwargs.items():
setattr(config, key, value)
return config
@staticmethod
def create_file_storage_mock(
filename: str = "test.mp3",
mimetype: str = "audio/mp3",
content: bytes = b"fake audio content",
**kwargs,
) -> Mock:
"""
Create a mock FileStorage object.
Args:
filename: Name of the file
mimetype: MIME type of the file
content: File content as bytes
**kwargs: Additional attributes to set on the mock
Returns:
Mock FileStorage object with specified attributes
"""
file = Mock(spec=FileStorage)
file.filename = filename
file.mimetype = mimetype
file.read = Mock(return_value=content)
for key, value in kwargs.items():
setattr(file, key, value)
return file
@staticmethod
def create_message_mock(
message_id: str = "msg-123",
answer: str = "Test answer",
status: MessageStatus = MessageStatus.NORMAL,
**kwargs,
) -> Mock:
"""
Create a mock Message object.
Args:
message_id: Unique identifier for the message
answer: Message answer text
status: Message status
**kwargs: Additional attributes to set on the mock
Returns:
Mock Message object with specified attributes
"""
message = create_autospec(Message, instance=True)
message.id = message_id
message.answer = answer
message.status = status
for key, value in kwargs.items():
setattr(message, key, value)
return message
@pytest.fixture
def factory():
"""Provide the test data factory to all tests."""
return AudioServiceTestDataFactory
class TestAudioServiceASR:
"""Test speech-to-text (ASR) operations."""
@patch("services.audio_service.ModelManager")
def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory):
"""Test successful ASR transcription in CHAT mode."""
# Arrange
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
app = factory.create_app_mock(
mode=AppMode.CHAT,
app_model_config=app_model_config,
)
file = factory.create_file_storage_mock()
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_instance = MagicMock()
mock_model_instance.invoke_speech2text.return_value = "Transcribed text"
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
# Act
result = AudioService.transcript_asr(app_model=app, file=file, end_user="user-123")
# Assert
assert result == {"text": "Transcribed text"}
mock_model_instance.invoke_speech2text.assert_called_once()
call_args = mock_model_instance.invoke_speech2text.call_args
assert call_args.kwargs["user"] == "user-123"
@patch("services.audio_service.ModelManager")
def test_transcript_asr_success_advanced_chat_mode(self, mock_model_manager_class, factory):
"""Test successful ASR transcription in ADVANCED_CHAT mode."""
# Arrange
workflow = factory.create_workflow_mock(features_dict={"speech_to_text": {"enabled": True}})
app = factory.create_app_mock(
mode=AppMode.ADVANCED_CHAT,
workflow=workflow,
)
file = factory.create_file_storage_mock()
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_instance = MagicMock()
mock_model_instance.invoke_speech2text.return_value = "Workflow transcribed text"
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
# Act
result = AudioService.transcript_asr(app_model=app, file=file)
# Assert
assert result == {"text": "Workflow transcribed text"}
def test_transcript_asr_raises_error_when_feature_disabled_chat_mode(self, factory):
"""Test that ASR raises error when speech-to-text is disabled in CHAT mode."""
# Arrange
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": False})
app = factory.create_app_mock(
mode=AppMode.CHAT,
app_model_config=app_model_config,
)
file = factory.create_file_storage_mock()
# Act & Assert
with pytest.raises(ValueError, match="Speech to text is not enabled"):
AudioService.transcript_asr(app_model=app, file=file)
def test_transcript_asr_raises_error_when_feature_disabled_workflow_mode(self, factory):
"""Test that ASR raises error when speech-to-text is disabled in WORKFLOW mode."""
# Arrange
workflow = factory.create_workflow_mock(features_dict={"speech_to_text": {"enabled": False}})
app = factory.create_app_mock(
mode=AppMode.WORKFLOW,
workflow=workflow,
)
file = factory.create_file_storage_mock()
# Act & Assert
with pytest.raises(ValueError, match="Speech to text is not enabled"):
AudioService.transcript_asr(app_model=app, file=file)
def test_transcript_asr_raises_error_when_workflow_missing(self, factory):
"""Test that ASR raises error when workflow is missing in WORKFLOW mode."""
# Arrange
app = factory.create_app_mock(
mode=AppMode.WORKFLOW,
workflow=None,
)
file = factory.create_file_storage_mock()
# Act & Assert
with pytest.raises(ValueError, match="Speech to text is not enabled"):
AudioService.transcript_asr(app_model=app, file=file)
def test_transcript_asr_raises_error_when_no_file_uploaded(self, factory):
"""Test that ASR raises error when no file is uploaded."""
# Arrange
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
app = factory.create_app_mock(
mode=AppMode.CHAT,
app_model_config=app_model_config,
)
# Act & Assert
with pytest.raises(NoAudioUploadedServiceError):
AudioService.transcript_asr(app_model=app, file=None)
def test_transcript_asr_raises_error_for_unsupported_audio_type(self, factory):
"""Test that ASR raises error for unsupported audio file types."""
# Arrange
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
app = factory.create_app_mock(
mode=AppMode.CHAT,
app_model_config=app_model_config,
)
file = factory.create_file_storage_mock(mimetype="video/mp4")
# Act & Assert
with pytest.raises(UnsupportedAudioTypeServiceError):
AudioService.transcript_asr(app_model=app, file=file)
def test_transcript_asr_raises_error_for_large_file(self, factory):
"""Test that ASR raises error when file exceeds size limit (30MB)."""
# Arrange
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
app = factory.create_app_mock(
mode=AppMode.CHAT,
app_model_config=app_model_config,
)
# Create file larger than 30MB
large_content = b"x" * (31 * 1024 * 1024)
file = factory.create_file_storage_mock(content=large_content)
# Act & Assert
with pytest.raises(AudioTooLargeServiceError, match="Audio size larger than 30 mb"):
AudioService.transcript_asr(app_model=app, file=file)
@patch("services.audio_service.ModelManager")
def test_transcript_asr_raises_error_when_no_model_instance(self, mock_model_manager_class, factory):
"""Test that ASR raises error when no model instance is available."""
# Arrange
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
app = factory.create_app_mock(
mode=AppMode.CHAT,
app_model_config=app_model_config,
)
file = factory.create_file_storage_mock()
# Mock ModelManager to return None
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager.get_default_model_instance.return_value = None
# Act & Assert
with pytest.raises(ProviderNotSupportSpeechToTextServiceError):
AudioService.transcript_asr(app_model=app, file=file)
class TestAudioServiceTTS:
"""Test text-to-speech (TTS) operations."""
@patch("services.audio_service.ModelManager")
def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory):
"""Test successful TTS with text input."""
# Arrange
app_model_config = factory.create_app_model_config_mock(
text_to_speech_dict={"enabled": True, "voice": "en-US-Neural"}
)
app = factory.create_app_mock(
mode=AppMode.CHAT,
app_model_config=app_model_config,
)
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_instance = MagicMock()
mock_model_instance.invoke_tts.return_value = b"audio data"
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
# Act
result = AudioService.transcript_tts(
app_model=app,
text="Hello world",
voice="en-US-Neural",
end_user="user-123",
)
# Assert
assert result == b"audio data"
mock_model_instance.invoke_tts.assert_called_once_with(
content_text="Hello world",
user="user-123",
tenant_id=app.tenant_id,
voice="en-US-Neural",
)
@patch("services.audio_service.db.session")
@patch("services.audio_service.ModelManager")
def test_transcript_tts_with_message_id_success(self, mock_model_manager_class, mock_db_session, factory):
"""Test successful TTS with message ID."""
# Arrange
app_model_config = factory.create_app_model_config_mock(
text_to_speech_dict={"enabled": True, "voice": "en-US-Neural"}
)
app = factory.create_app_mock(
mode=AppMode.CHAT,
app_model_config=app_model_config,
)
message = factory.create_message_mock(
message_id="550e8400-e29b-41d4-a716-446655440000",
answer="Message answer text",
)
# Mock database query
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = message
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_instance = MagicMock()
mock_model_instance.invoke_tts.return_value = b"audio from message"
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
# Act
result = AudioService.transcript_tts(
app_model=app,
message_id="550e8400-e29b-41d4-a716-446655440000",
)
# Assert
assert result == b"audio from message"
mock_model_instance.invoke_tts.assert_called_once()
@patch("services.audio_service.ModelManager")
def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory):
"""Test TTS uses default voice when none specified."""
# Arrange
app_model_config = factory.create_app_model_config_mock(
text_to_speech_dict={"enabled": True, "voice": "default-voice"}
)
app = factory.create_app_mock(
mode=AppMode.CHAT,
app_model_config=app_model_config,
)
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_instance = MagicMock()
mock_model_instance.invoke_tts.return_value = b"audio data"
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
# Act
result = AudioService.transcript_tts(
app_model=app,
text="Test",
)
# Assert
assert result == b"audio data"
# Verify default voice was used
call_args = mock_model_instance.invoke_tts.call_args
assert call_args.kwargs["voice"] == "default-voice"
@patch("services.audio_service.ModelManager")
def test_transcript_tts_gets_first_available_voice_when_none_configured(self, mock_model_manager_class, factory):
"""Test TTS gets first available voice when none is configured."""
# Arrange
app_model_config = factory.create_app_model_config_mock(
text_to_speech_dict={"enabled": True} # No voice specified
)
app = factory.create_app_mock(
mode=AppMode.CHAT,
app_model_config=app_model_config,
)
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_instance = MagicMock()
mock_model_instance.get_tts_voices.return_value = [{"value": "auto-voice"}]
mock_model_instance.invoke_tts.return_value = b"audio data"
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
# Act
result = AudioService.transcript_tts(
app_model=app,
text="Test",
)
# Assert
assert result == b"audio data"
call_args = mock_model_instance.invoke_tts.call_args
assert call_args.kwargs["voice"] == "auto-voice"
@patch("services.audio_service.WorkflowService")
@patch("services.audio_service.ModelManager")
def test_transcript_tts_workflow_mode_with_draft(
self, mock_model_manager_class, mock_workflow_service_class, factory
):
"""Test TTS in WORKFLOW mode with draft workflow."""
# Arrange
draft_workflow = factory.create_workflow_mock(
features_dict={"text_to_speech": {"enabled": True, "voice": "draft-voice"}}
)
app = factory.create_app_mock(
mode=AppMode.WORKFLOW,
)
# Mock WorkflowService
mock_workflow_service = MagicMock()
mock_workflow_service_class.return_value = mock_workflow_service
mock_workflow_service.get_draft_workflow.return_value = draft_workflow
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_instance = MagicMock()
mock_model_instance.invoke_tts.return_value = b"draft audio"
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
# Act
result = AudioService.transcript_tts(
app_model=app,
text="Draft test",
is_draft=True,
)
# Assert
assert result == b"draft audio"
mock_workflow_service.get_draft_workflow.assert_called_once_with(app_model=app)
def test_transcript_tts_raises_error_when_text_missing(self, factory):
"""Test that TTS raises error when text is missing."""
# Arrange
app = factory.create_app_mock()
# Act & Assert
with pytest.raises(ValueError, match="Text is required"):
AudioService.transcript_tts(app_model=app, text=None)
@patch("services.audio_service.db.session")
def test_transcript_tts_returns_none_for_invalid_message_id(self, mock_db_session, factory):
"""Test that TTS returns None for invalid message ID format."""
# Arrange
app = factory.create_app_mock()
# Act
result = AudioService.transcript_tts(
app_model=app,
message_id="invalid-uuid",
)
# Assert
assert result is None
@patch("services.audio_service.db.session")
def test_transcript_tts_returns_none_for_nonexistent_message(self, mock_db_session, factory):
"""Test that TTS returns None when message doesn't exist."""
# Arrange
app = factory.create_app_mock()
# Mock database query returning None
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = None
# Act
result = AudioService.transcript_tts(
app_model=app,
message_id="550e8400-e29b-41d4-a716-446655440000",
)
# Assert
assert result is None
@patch("services.audio_service.db.session")
def test_transcript_tts_returns_none_for_empty_message_answer(self, mock_db_session, factory):
"""Test that TTS returns None when message answer is empty."""
# Arrange
app = factory.create_app_mock()
message = factory.create_message_mock(
answer="",
status=MessageStatus.NORMAL,
)
# Mock database query
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = message
# Act
result = AudioService.transcript_tts(
app_model=app,
message_id="550e8400-e29b-41d4-a716-446655440000",
)
# Assert
assert result is None
@patch("services.audio_service.ModelManager")
def test_transcript_tts_raises_error_when_no_voices_available(self, mock_model_manager_class, factory):
"""Test that TTS raises error when no voices are available."""
# Arrange
app_model_config = factory.create_app_model_config_mock(
text_to_speech_dict={"enabled": True} # No voice specified
)
app = factory.create_app_mock(
mode=AppMode.CHAT,
app_model_config=app_model_config,
)
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_instance = MagicMock()
mock_model_instance.get_tts_voices.return_value = [] # No voices available
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
# Act & Assert
with pytest.raises(ValueError, match="Sorry, no voice available"):
AudioService.transcript_tts(app_model=app, text="Test")
class TestAudioServiceTTSVoices:
"""Test TTS voice listing operations."""
@patch("services.audio_service.ModelManager")
def test_transcript_tts_voices_success(self, mock_model_manager_class, factory):
"""Test successful retrieval of TTS voices."""
# Arrange
tenant_id = "tenant-123"
language = "en-US"
expected_voices = [
{"name": "Voice 1", "value": "voice-1"},
{"name": "Voice 2", "value": "voice-2"},
]
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_instance = MagicMock()
mock_model_instance.get_tts_voices.return_value = expected_voices
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
# Act
result = AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language)
# Assert
assert result == expected_voices
mock_model_instance.get_tts_voices.assert_called_once_with(language)
@patch("services.audio_service.ModelManager")
def test_transcript_tts_voices_raises_error_when_no_model_instance(self, mock_model_manager_class, factory):
"""Test that TTS voices raises error when no model instance is available."""
# Arrange
tenant_id = "tenant-123"
language = "en-US"
# Mock ModelManager to return None
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager.get_default_model_instance.return_value = None
# Act & Assert
with pytest.raises(ProviderNotSupportTextToSpeechServiceError):
AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language)
@patch("services.audio_service.ModelManager")
def test_transcript_tts_voices_propagates_exceptions(self, mock_model_manager_class, factory):
"""Test that TTS voices propagates exceptions from model instance."""
# Arrange
tenant_id = "tenant-123"
language = "en-US"
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_instance = MagicMock()
mock_model_instance.get_tts_voices.side_effect = RuntimeError("Model error")
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
# Act & Assert
with pytest.raises(RuntimeError, match="Model error"):
AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language)

View File

@ -1,177 +0,0 @@
import types
from unittest.mock import Mock, create_autospec
import pytest
from redis.exceptions import LockNotOwnedError
from models.account import Account
from models.dataset import Dataset, Document
from services.dataset_service import DocumentService, SegmentService
class FakeLock:
"""Lock that always fails on enter with LockNotOwnedError."""
def __enter__(self):
raise LockNotOwnedError("simulated")
def __exit__(self, exc_type, exc, tb):
# Normal contextmanager signature; return False so exceptions propagate
return False
@pytest.fixture
def fake_current_user(monkeypatch):
user = create_autospec(Account, instance=True)
user.id = "user-1"
user.current_tenant_id = "tenant-1"
monkeypatch.setattr("services.dataset_service.current_user", user)
return user
@pytest.fixture
def fake_features(monkeypatch):
"""Features.billing.enabled == False to skip quota logic."""
features = types.SimpleNamespace(
billing=types.SimpleNamespace(enabled=False, subscription=types.SimpleNamespace(plan="ENTERPRISE")),
documents_upload_quota=types.SimpleNamespace(limit=10_000, size=0),
)
monkeypatch.setattr(
"services.dataset_service.FeatureService.get_features",
lambda tenant_id: features,
)
return features
@pytest.fixture
def fake_lock(monkeypatch):
"""Patch redis_client.lock to always raise LockNotOwnedError on enter."""
def _fake_lock(name, timeout=None, *args, **kwargs):
return FakeLock()
# DatasetService imports redis_client directly from extensions.ext_redis
monkeypatch.setattr("services.dataset_service.redis_client.lock", _fake_lock)
# ---------------------------------------------------------------------------
# 1. Knowledge Pipeline document creation (save_document_with_dataset_id)
# ---------------------------------------------------------------------------
def test_save_document_with_dataset_id_ignores_lock_not_owned(
monkeypatch,
fake_current_user,
fake_features,
fake_lock,
):
# Arrange
dataset = create_autospec(Dataset, instance=True)
dataset.id = "ds-1"
dataset.tenant_id = fake_current_user.current_tenant_id
dataset.data_source_type = "upload_file"
dataset.indexing_technique = "high_quality" # so we skip re-initialization branch
# Minimal knowledge_config stub that satisfies pre-lock code
info_list = types.SimpleNamespace(data_source_type="upload_file")
data_source = types.SimpleNamespace(info_list=info_list)
knowledge_config = types.SimpleNamespace(
doc_form="qa_model",
original_document_id=None, # go into "new document" branch
data_source=data_source,
indexing_technique="high_quality",
embedding_model=None,
embedding_model_provider=None,
retrieval_model=None,
process_rule=None,
duplicate=False,
doc_language="en",
)
account = fake_current_user
# Avoid touching real doc_form logic
monkeypatch.setattr("services.dataset_service.DatasetService.check_doc_form", lambda *a, **k: None)
# Avoid real DB interactions
monkeypatch.setattr("services.dataset_service.db", Mock())
# Act: this would hit the redis lock, whose __enter__ raises LockNotOwnedError.
# Our implementation should catch it and still return (documents, batch).
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=account,
)
# Assert
# We mainly care that:
# - No exception is raised
# - The function returns a sensible tuple
assert isinstance(documents, list)
assert isinstance(batch, str)
# ---------------------------------------------------------------------------
# 2. Single-segment creation (add_segment)
# ---------------------------------------------------------------------------
def test_add_segment_ignores_lock_not_owned(
monkeypatch,
fake_current_user,
fake_lock,
):
# Arrange
dataset = create_autospec(Dataset, instance=True)
dataset.id = "ds-1"
dataset.tenant_id = fake_current_user.current_tenant_id
dataset.indexing_technique = "economy" # skip embedding/token calculation branch
document = create_autospec(Document, instance=True)
document.id = "doc-1"
document.dataset_id = dataset.id
document.word_count = 0
document.doc_form = "qa_model"
# Minimal args required by add_segment
args = {
"content": "question text",
"answer": "answer text",
"keywords": ["k1", "k2"],
}
# Avoid real DB operations
db_mock = Mock()
db_mock.session = Mock()
monkeypatch.setattr("services.dataset_service.db", db_mock)
monkeypatch.setattr("services.dataset_service.VectorService", Mock())
# Act
result = SegmentService.create_segment(args=args, document=document, dataset=dataset)
# Assert
# Under LockNotOwnedError except, add_segment should swallow the error and return None.
assert result is None
# ---------------------------------------------------------------------------
# 3. Multi-segment creation (multi_create_segment)
# ---------------------------------------------------------------------------
def test_multi_create_segment_ignores_lock_not_owned(
monkeypatch,
fake_current_user,
fake_lock,
):
# Arrange
dataset = create_autospec(Dataset, instance=True)
dataset.id = "ds-1"
dataset.tenant_id = fake_current_user.current_tenant_id
dataset.indexing_technique = "economy" # again, skip high_quality path
document = create_autospec(Document, instance=True)
document.id = "doc-1"
document.dataset_id = dataset.id
document.word_count = 0
document.doc_form = "qa_model"

File diff suppressed because it is too large Load Diff

View File

@ -1,440 +0,0 @@
"""
Comprehensive unit tests for RecommendedAppService.
This test suite provides complete coverage of recommended app operations in Dify,
following TDD principles with the Arrange-Act-Assert pattern.
## Test Coverage
### 1. Get Recommended Apps and Categories (TestRecommendedAppServiceGetApps)
Tests fetching recommended apps with categories:
- Successful retrieval with recommended apps
- Fallback to builtin when no recommended apps
- Different language support
- Factory mode selection (remote, builtin, db)
- Empty result handling
### 2. Get Recommend App Detail (TestRecommendedAppServiceGetDetail)
Tests fetching individual app details:
- Successful app detail retrieval
- Different factory modes
- App not found scenarios
- Language-specific details
## Testing Approach
- **Mocking Strategy**: All external dependencies (dify_config, RecommendAppRetrievalFactory)
are mocked for fast, isolated unit tests
- **Factory Pattern**: Tests verify correct factory selection based on mode
- **Fixtures**: Mock objects are configured per test method
- **Assertions**: Each test verifies return values and factory method calls
## Key Concepts
**Factory Modes:**
- remote: Fetch from remote API
- builtin: Use built-in templates
- db: Fetch from database
**Fallback Logic:**
- If remote/db returns no apps, fallback to builtin en-US templates
- Ensures users always see some recommended apps
"""
from unittest.mock import MagicMock, patch
import pytest
from services.recommended_app_service import RecommendedAppService
class RecommendedAppServiceTestDataFactory:
"""
Factory for creating test data and mock objects.
Provides reusable methods to create consistent mock objects for testing
recommended app operations.
"""
@staticmethod
def create_recommended_apps_response(
recommended_apps: list[dict] | None = None,
categories: list[str] | None = None,
) -> dict:
"""
Create a mock response for recommended apps.
Args:
recommended_apps: List of recommended app dictionaries
categories: List of category names
Returns:
Dictionary with recommended_apps and categories
"""
if recommended_apps is None:
recommended_apps = [
{
"id": "app-1",
"name": "Test App 1",
"description": "Test description 1",
"category": "productivity",
},
{
"id": "app-2",
"name": "Test App 2",
"description": "Test description 2",
"category": "communication",
},
]
if categories is None:
categories = ["productivity", "communication", "utilities"]
return {
"recommended_apps": recommended_apps,
"categories": categories,
}
@staticmethod
def create_app_detail_response(
app_id: str = "app-123",
name: str = "Test App",
description: str = "Test description",
**kwargs,
) -> dict:
"""
Create a mock response for app detail.
Args:
app_id: App identifier
name: App name
description: App description
**kwargs: Additional fields
Returns:
Dictionary with app details
"""
detail = {
"id": app_id,
"name": name,
"description": description,
"category": kwargs.get("category", "productivity"),
"icon": kwargs.get("icon", "🚀"),
"model_config": kwargs.get("model_config", {}),
}
detail.update(kwargs)
return detail
@pytest.fixture
def factory():
"""Provide the test data factory to all tests."""
return RecommendedAppServiceTestDataFactory
class TestRecommendedAppServiceGetApps:
"""Test get_recommended_apps_and_categories operations."""
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
def test_get_recommended_apps_success_with_apps(self, mock_config, mock_factory_class, factory):
"""Test successful retrieval of recommended apps when apps are returned."""
# Arrange
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
expected_response = factory.create_recommended_apps_response()
# Mock factory and retrieval instance
mock_retrieval_instance = MagicMock()
mock_retrieval_instance.get_recommended_apps_and_categories.return_value = expected_response
mock_factory = MagicMock()
mock_factory.return_value = mock_retrieval_instance
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
# Assert
assert result == expected_response
assert len(result["recommended_apps"]) == 2
assert len(result["categories"]) == 3
mock_factory_class.get_recommend_app_factory.assert_called_once_with("remote")
mock_retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
def test_get_recommended_apps_fallback_to_builtin_when_empty(self, mock_config, mock_factory_class, factory):
"""Test fallback to builtin when no recommended apps are returned."""
# Arrange
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
# Remote returns empty recommended_apps
empty_response = {"recommended_apps": [], "categories": []}
# Builtin fallback response
builtin_response = factory.create_recommended_apps_response(
recommended_apps=[{"id": "builtin-1", "name": "Builtin App", "category": "default"}]
)
# Mock remote retrieval instance (returns empty)
mock_remote_instance = MagicMock()
mock_remote_instance.get_recommended_apps_and_categories.return_value = empty_response
mock_remote_factory = MagicMock()
mock_remote_factory.return_value = mock_remote_instance
mock_factory_class.get_recommend_app_factory.return_value = mock_remote_factory
# Mock builtin retrieval instance
mock_builtin_instance = MagicMock()
mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response
mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance
# Act
result = RecommendedAppService.get_recommended_apps_and_categories("zh-CN")
# Assert
assert result == builtin_response
assert len(result["recommended_apps"]) == 1
assert result["recommended_apps"][0]["id"] == "builtin-1"
# Verify fallback was called with en-US (hardcoded)
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
def test_get_recommended_apps_fallback_when_none_recommended_apps(self, mock_config, mock_factory_class, factory):
"""Test fallback when recommended_apps key is None."""
# Arrange
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "db"
# Response with None recommended_apps
none_response = {"recommended_apps": None, "categories": ["test"]}
# Builtin fallback response
builtin_response = factory.create_recommended_apps_response()
# Mock db retrieval instance (returns None)
mock_db_instance = MagicMock()
mock_db_instance.get_recommended_apps_and_categories.return_value = none_response
mock_db_factory = MagicMock()
mock_db_factory.return_value = mock_db_instance
mock_factory_class.get_recommend_app_factory.return_value = mock_db_factory
# Mock builtin retrieval instance
mock_builtin_instance = MagicMock()
mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response
mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance
# Act
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
# Assert
assert result == builtin_response
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once()
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
def test_get_recommended_apps_with_different_languages(self, mock_config, mock_factory_class, factory):
"""Test retrieval with different language codes."""
# Arrange
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin"
languages = ["en-US", "zh-CN", "ja-JP", "fr-FR"]
for language in languages:
# Create language-specific response
lang_response = factory.create_recommended_apps_response(
recommended_apps=[{"id": f"app-{language}", "name": f"App {language}", "category": "test"}]
)
# Mock retrieval instance
mock_instance = MagicMock()
mock_instance.get_recommended_apps_and_categories.return_value = lang_response
mock_factory = MagicMock()
mock_factory.return_value = mock_instance
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = RecommendedAppService.get_recommended_apps_and_categories(language)
# Assert
assert result["recommended_apps"][0]["id"] == f"app-{language}"
mock_instance.get_recommended_apps_and_categories.assert_called_with(language)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
def test_get_recommended_apps_uses_correct_factory_mode(self, mock_config, mock_factory_class, factory):
"""Test that correct factory is selected based on mode."""
# Arrange
modes = ["remote", "builtin", "db"]
for mode in modes:
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode
response = factory.create_recommended_apps_response()
# Mock retrieval instance
mock_instance = MagicMock()
mock_instance.get_recommended_apps_and_categories.return_value = response
mock_factory = MagicMock()
mock_factory.return_value = mock_instance
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
RecommendedAppService.get_recommended_apps_and_categories("en-US")
# Assert
mock_factory_class.get_recommend_app_factory.assert_called_with(mode)
class TestRecommendedAppServiceGetDetail:
"""Test get_recommend_app_detail operations."""
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
def test_get_recommend_app_detail_success(self, mock_config, mock_factory_class, factory):
"""Test successful retrieval of app detail."""
# Arrange
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
app_id = "app-123"
expected_detail = factory.create_app_detail_response(
app_id=app_id,
name="Productivity App",
description="A great productivity app",
category="productivity",
)
# Mock retrieval instance
mock_instance = MagicMock()
mock_instance.get_recommend_app_detail.return_value = expected_detail
mock_factory = MagicMock()
mock_factory.return_value = mock_instance
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = RecommendedAppService.get_recommend_app_detail(app_id)
# Assert
assert result == expected_detail
assert result["id"] == app_id
assert result["name"] == "Productivity App"
mock_instance.get_recommend_app_detail.assert_called_once_with(app_id)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
def test_get_recommend_app_detail_with_different_modes(self, mock_config, mock_factory_class, factory):
"""Test app detail retrieval with different factory modes."""
# Arrange
modes = ["remote", "builtin", "db"]
app_id = "test-app"
for mode in modes:
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode
detail = factory.create_app_detail_response(app_id=app_id, name=f"App from {mode}")
# Mock retrieval instance
mock_instance = MagicMock()
mock_instance.get_recommend_app_detail.return_value = detail
mock_factory = MagicMock()
mock_factory.return_value = mock_instance
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = RecommendedAppService.get_recommend_app_detail(app_id)
# Assert
assert result["name"] == f"App from {mode}"
mock_factory_class.get_recommend_app_factory.assert_called_with(mode)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
def test_get_recommend_app_detail_returns_none_when_not_found(self, mock_config, mock_factory_class, factory):
"""Test that None is returned when app is not found."""
# Arrange
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
app_id = "nonexistent-app"
# Mock retrieval instance returning None
mock_instance = MagicMock()
mock_instance.get_recommend_app_detail.return_value = None
mock_factory = MagicMock()
mock_factory.return_value = mock_instance
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = RecommendedAppService.get_recommend_app_detail(app_id)
# Assert
assert result is None
mock_instance.get_recommend_app_detail.assert_called_once_with(app_id)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
def test_get_recommend_app_detail_returns_empty_dict(self, mock_config, mock_factory_class, factory):
"""Test handling of empty dict response."""
# Arrange
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin"
app_id = "app-empty"
# Mock retrieval instance returning empty dict
mock_instance = MagicMock()
mock_instance.get_recommend_app_detail.return_value = {}
mock_factory = MagicMock()
mock_factory.return_value = mock_instance
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = RecommendedAppService.get_recommend_app_detail(app_id)
# Assert
assert result == {}
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
def test_get_recommend_app_detail_with_complex_model_config(self, mock_config, mock_factory_class, factory):
"""Test app detail with complex model configuration."""
# Arrange
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
app_id = "complex-app"
complex_model_config = {
"provider": "openai",
"model": "gpt-4",
"parameters": {
"temperature": 0.7,
"max_tokens": 2000,
"top_p": 1.0,
},
}
expected_detail = factory.create_app_detail_response(
app_id=app_id,
name="Complex App",
model_config=complex_model_config,
workflows=["workflow-1", "workflow-2"],
tools=["tool-1", "tool-2", "tool-3"],
)
# Mock retrieval instance
mock_instance = MagicMock()
mock_instance.get_recommend_app_detail.return_value = expected_detail
mock_factory = MagicMock()
mock_factory.return_value = mock_instance
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
# Act
result = RecommendedAppService.get_recommend_app_detail(app_id)
# Assert
assert result["model_config"] == complex_model_config
assert len(result["workflows"]) == 2
assert len(result["tools"]) == 3

View File

@ -1,626 +0,0 @@
"""
Comprehensive unit tests for SavedMessageService.
This test suite provides complete coverage of saved message operations in Dify,
following TDD principles with the Arrange-Act-Assert pattern.
## Test Coverage
### 1. Pagination (TestSavedMessageServicePagination)
Tests saved message listing and pagination:
- Pagination with valid user (Account and EndUser)
- Pagination without user raises ValueError
- Pagination with last_id parameter
- Empty results when no saved messages exist
- Integration with MessageService pagination
### 2. Save Operations (TestSavedMessageServiceSave)
Tests saving messages:
- Save message for Account user
- Save message for EndUser
- Save without user (no-op)
- Prevent duplicate saves (idempotent)
- Message validation through MessageService
### 3. Delete Operations (TestSavedMessageServiceDelete)
Tests deleting saved messages:
- Delete saved message for Account user
- Delete saved message for EndUser
- Delete without user (no-op)
- Delete non-existent saved message (no-op)
- Proper database cleanup
## Testing Approach
- **Mocking Strategy**: All external dependencies (database, MessageService) are mocked
for fast, isolated unit tests
- **Factory Pattern**: SavedMessageServiceTestDataFactory provides consistent test data
- **Fixtures**: Mock objects are configured per test method
- **Assertions**: Each test verifies return values and side effects
(database operations, method calls)
## Key Concepts
**User Types:**
- Account: Workspace members (console users)
- EndUser: API users (end users)
**Saved Messages:**
- Users can save messages for later reference
- Each user has their own saved message list
- Saving is idempotent (duplicate saves ignored)
- Deletion is safe (non-existent deletes ignored)
"""
from datetime import UTC, datetime
from unittest.mock import MagicMock, Mock, create_autospec, patch
import pytest
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
from models.model import App, EndUser, Message
from models.web import SavedMessage
from services.saved_message_service import SavedMessageService
class SavedMessageServiceTestDataFactory:
"""
Factory for creating test data and mock objects.
Provides reusable methods to create consistent mock objects for testing
saved message operations.
"""
@staticmethod
def create_account_mock(account_id: str = "account-123", **kwargs) -> Mock:
"""
Create a mock Account object.
Args:
account_id: Unique identifier for the account
**kwargs: Additional attributes to set on the mock
Returns:
Mock Account object with specified attributes
"""
account = create_autospec(Account, instance=True)
account.id = account_id
for key, value in kwargs.items():
setattr(account, key, value)
return account
@staticmethod
def create_end_user_mock(user_id: str = "user-123", **kwargs) -> Mock:
"""
Create a mock EndUser object.
Args:
user_id: Unique identifier for the end user
**kwargs: Additional attributes to set on the mock
Returns:
Mock EndUser object with specified attributes
"""
user = create_autospec(EndUser, instance=True)
user.id = user_id
for key, value in kwargs.items():
setattr(user, key, value)
return user
@staticmethod
def create_app_mock(app_id: str = "app-123", tenant_id: str = "tenant-123", **kwargs) -> Mock:
"""
Create a mock App object.
Args:
app_id: Unique identifier for the app
tenant_id: Tenant/workspace identifier
**kwargs: Additional attributes to set on the mock
Returns:
Mock App object with specified attributes
"""
app = create_autospec(App, instance=True)
app.id = app_id
app.tenant_id = tenant_id
app.name = kwargs.get("name", "Test App")
app.mode = kwargs.get("mode", "chat")
for key, value in kwargs.items():
setattr(app, key, value)
return app
@staticmethod
def create_message_mock(
message_id: str = "msg-123",
app_id: str = "app-123",
**kwargs,
) -> Mock:
"""
Create a mock Message object.
Args:
message_id: Unique identifier for the message
app_id: Associated app identifier
**kwargs: Additional attributes to set on the mock
Returns:
Mock Message object with specified attributes
"""
message = create_autospec(Message, instance=True)
message.id = message_id
message.app_id = app_id
message.query = kwargs.get("query", "Test query")
message.answer = kwargs.get("answer", "Test answer")
message.created_at = kwargs.get("created_at", datetime.now(UTC))
for key, value in kwargs.items():
setattr(message, key, value)
return message
@staticmethod
def create_saved_message_mock(
saved_message_id: str = "saved-123",
app_id: str = "app-123",
message_id: str = "msg-123",
created_by: str = "user-123",
created_by_role: str = "account",
**kwargs,
) -> Mock:
"""
Create a mock SavedMessage object.
Args:
saved_message_id: Unique identifier for the saved message
app_id: Associated app identifier
message_id: Associated message identifier
created_by: User who saved the message
created_by_role: Role of the user ('account' or 'end_user')
**kwargs: Additional attributes to set on the mock
Returns:
Mock SavedMessage object with specified attributes
"""
saved_message = create_autospec(SavedMessage, instance=True)
saved_message.id = saved_message_id
saved_message.app_id = app_id
saved_message.message_id = message_id
saved_message.created_by = created_by
saved_message.created_by_role = created_by_role
saved_message.created_at = kwargs.get("created_at", datetime.now(UTC))
for key, value in kwargs.items():
setattr(saved_message, key, value)
return saved_message
@pytest.fixture
def factory():
"""Provide the test data factory to all tests."""
return SavedMessageServiceTestDataFactory
class TestSavedMessageServicePagination:
"""Test saved message pagination operations."""
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
@patch("services.saved_message_service.db.session")
def test_pagination_with_account_user(self, mock_db_session, mock_message_pagination, factory):
"""Test pagination with an Account user."""
# Arrange
app = factory.create_app_mock()
user = factory.create_account_mock()
# Create saved messages for this user
saved_messages = [
factory.create_saved_message_mock(
saved_message_id=f"saved-{i}",
app_id=app.id,
message_id=f"msg-{i}",
created_by=user.id,
created_by_role="account",
)
for i in range(3)
]
# Mock database query
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = saved_messages
# Mock MessageService pagination response
expected_pagination = InfiniteScrollPagination(data=[], limit=20, has_more=False)
mock_message_pagination.return_value = expected_pagination
# Act
result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=20)
# Assert
assert result == expected_pagination
mock_db_session.query.assert_called_once_with(SavedMessage)
# Verify MessageService was called with correct message IDs
mock_message_pagination.assert_called_once_with(
app_model=app,
user=user,
last_id=None,
limit=20,
include_ids=["msg-0", "msg-1", "msg-2"],
)
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
@patch("services.saved_message_service.db.session")
def test_pagination_with_end_user(self, mock_db_session, mock_message_pagination, factory):
"""Test pagination with an EndUser."""
# Arrange
app = factory.create_app_mock()
user = factory.create_end_user_mock()
# Create saved messages for this end user
saved_messages = [
factory.create_saved_message_mock(
saved_message_id=f"saved-{i}",
app_id=app.id,
message_id=f"msg-{i}",
created_by=user.id,
created_by_role="end_user",
)
for i in range(2)
]
# Mock database query
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = saved_messages
# Mock MessageService pagination response
expected_pagination = InfiniteScrollPagination(data=[], limit=10, has_more=False)
mock_message_pagination.return_value = expected_pagination
# Act
result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=10)
# Assert
assert result == expected_pagination
# Verify correct role was used in query
mock_message_pagination.assert_called_once_with(
app_model=app,
user=user,
last_id=None,
limit=10,
include_ids=["msg-0", "msg-1"],
)
def test_pagination_without_user_raises_error(self, factory):
"""Test that pagination without user raises ValueError."""
# Arrange
app = factory.create_app_mock()
# Act & Assert
with pytest.raises(ValueError, match="User is required"):
SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=20)
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
@patch("services.saved_message_service.db.session")
def test_pagination_with_last_id(self, mock_db_session, mock_message_pagination, factory):
"""Test pagination with last_id parameter."""
# Arrange
app = factory.create_app_mock()
user = factory.create_account_mock()
last_id = "msg-last"
saved_messages = [
factory.create_saved_message_mock(
message_id=f"msg-{i}",
app_id=app.id,
created_by=user.id,
)
for i in range(5)
]
# Mock database query
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = saved_messages
# Mock MessageService pagination response
expected_pagination = InfiniteScrollPagination(data=[], limit=10, has_more=True)
mock_message_pagination.return_value = expected_pagination
# Act
result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=last_id, limit=10)
# Assert
assert result == expected_pagination
# Verify last_id was passed to MessageService
mock_message_pagination.assert_called_once()
call_args = mock_message_pagination.call_args
assert call_args.kwargs["last_id"] == last_id
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
@patch("services.saved_message_service.db.session")
def test_pagination_with_empty_saved_messages(self, mock_db_session, mock_message_pagination, factory):
"""Test pagination when user has no saved messages."""
# Arrange
app = factory.create_app_mock()
user = factory.create_account_mock()
# Mock database query returning empty list
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.all.return_value = []
# Mock MessageService pagination response
expected_pagination = InfiniteScrollPagination(data=[], limit=20, has_more=False)
mock_message_pagination.return_value = expected_pagination
# Act
result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=20)
# Assert
assert result == expected_pagination
# Verify MessageService was called with empty include_ids
mock_message_pagination.assert_called_once_with(
app_model=app,
user=user,
last_id=None,
limit=20,
include_ids=[],
)
class TestSavedMessageServiceSave:
"""Test save message operations."""
@patch("services.saved_message_service.MessageService.get_message")
@patch("services.saved_message_service.db.session")
def test_save_message_for_account(self, mock_db_session, mock_get_message, factory):
"""Test saving a message for an Account user."""
# Arrange
app = factory.create_app_mock()
user = factory.create_account_mock()
message = factory.create_message_mock(message_id="msg-123", app_id=app.id)
# Mock database query - no existing saved message
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = None
# Mock MessageService.get_message
mock_get_message.return_value = message
# Act
SavedMessageService.save(app_model=app, user=user, message_id=message.id)
# Assert
mock_db_session.add.assert_called_once()
saved_message = mock_db_session.add.call_args[0][0]
assert saved_message.app_id == app.id
assert saved_message.message_id == message.id
assert saved_message.created_by == user.id
assert saved_message.created_by_role == "account"
mock_db_session.commit.assert_called_once()
@patch("services.saved_message_service.MessageService.get_message")
@patch("services.saved_message_service.db.session")
def test_save_message_for_end_user(self, mock_db_session, mock_get_message, factory):
"""Test saving a message for an EndUser."""
# Arrange
app = factory.create_app_mock()
user = factory.create_end_user_mock()
message = factory.create_message_mock(message_id="msg-456", app_id=app.id)
# Mock database query - no existing saved message
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = None
# Mock MessageService.get_message
mock_get_message.return_value = message
# Act
SavedMessageService.save(app_model=app, user=user, message_id=message.id)
# Assert
mock_db_session.add.assert_called_once()
saved_message = mock_db_session.add.call_args[0][0]
assert saved_message.app_id == app.id
assert saved_message.message_id == message.id
assert saved_message.created_by == user.id
assert saved_message.created_by_role == "end_user"
mock_db_session.commit.assert_called_once()
@patch("services.saved_message_service.db.session")
def test_save_without_user_does_nothing(self, mock_db_session, factory):
"""Test that saving without user is a no-op."""
# Arrange
app = factory.create_app_mock()
# Act
SavedMessageService.save(app_model=app, user=None, message_id="msg-123")
# Assert
mock_db_session.query.assert_not_called()
mock_db_session.add.assert_not_called()
mock_db_session.commit.assert_not_called()
@patch("services.saved_message_service.MessageService.get_message")
@patch("services.saved_message_service.db.session")
def test_save_duplicate_message_is_idempotent(self, mock_db_session, mock_get_message, factory):
"""Test that saving an already saved message is idempotent."""
# Arrange
app = factory.create_app_mock()
user = factory.create_account_mock()
message_id = "msg-789"
# Mock database query - existing saved message found
existing_saved = factory.create_saved_message_mock(
app_id=app.id,
message_id=message_id,
created_by=user.id,
created_by_role="account",
)
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = existing_saved
# Act
SavedMessageService.save(app_model=app, user=user, message_id=message_id)
# Assert - no new saved message created
mock_db_session.add.assert_not_called()
mock_db_session.commit.assert_not_called()
mock_get_message.assert_not_called()
@patch("services.saved_message_service.MessageService.get_message")
@patch("services.saved_message_service.db.session")
def test_save_validates_message_exists(self, mock_db_session, mock_get_message, factory):
"""Test that save validates message exists through MessageService."""
# Arrange
app = factory.create_app_mock()
user = factory.create_account_mock()
message = factory.create_message_mock()
# Mock database query - no existing saved message
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = None
# Mock MessageService.get_message
mock_get_message.return_value = message
# Act
SavedMessageService.save(app_model=app, user=user, message_id=message.id)
# Assert - MessageService.get_message was called for validation
mock_get_message.assert_called_once_with(app_model=app, user=user, message_id=message.id)
class TestSavedMessageServiceDelete:
"""Test delete saved message operations."""
@patch("services.saved_message_service.db.session")
def test_delete_saved_message_for_account(self, mock_db_session, factory):
"""Test deleting a saved message for an Account user."""
# Arrange
app = factory.create_app_mock()
user = factory.create_account_mock()
message_id = "msg-123"
# Mock database query - existing saved message found
saved_message = factory.create_saved_message_mock(
app_id=app.id,
message_id=message_id,
created_by=user.id,
created_by_role="account",
)
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = saved_message
# Act
SavedMessageService.delete(app_model=app, user=user, message_id=message_id)
# Assert
mock_db_session.delete.assert_called_once_with(saved_message)
mock_db_session.commit.assert_called_once()
@patch("services.saved_message_service.db.session")
def test_delete_saved_message_for_end_user(self, mock_db_session, factory):
"""Test deleting a saved message for an EndUser."""
# Arrange
app = factory.create_app_mock()
user = factory.create_end_user_mock()
message_id = "msg-456"
# Mock database query - existing saved message found
saved_message = factory.create_saved_message_mock(
app_id=app.id,
message_id=message_id,
created_by=user.id,
created_by_role="end_user",
)
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = saved_message
# Act
SavedMessageService.delete(app_model=app, user=user, message_id=message_id)
# Assert
mock_db_session.delete.assert_called_once_with(saved_message)
mock_db_session.commit.assert_called_once()
@patch("services.saved_message_service.db.session")
def test_delete_without_user_does_nothing(self, mock_db_session, factory):
"""Test that deleting without user is a no-op."""
# Arrange
app = factory.create_app_mock()
# Act
SavedMessageService.delete(app_model=app, user=None, message_id="msg-123")
# Assert
mock_db_session.query.assert_not_called()
mock_db_session.delete.assert_not_called()
mock_db_session.commit.assert_not_called()
@patch("services.saved_message_service.db.session")
def test_delete_non_existent_saved_message_does_nothing(self, mock_db_session, factory):
"""Test that deleting a non-existent saved message is a no-op."""
# Arrange
app = factory.create_app_mock()
user = factory.create_account_mock()
message_id = "msg-nonexistent"
# Mock database query - no saved message found
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = None
# Act
SavedMessageService.delete(app_model=app, user=user, message_id=message_id)
# Assert - no deletion occurred
mock_db_session.delete.assert_not_called()
mock_db_session.commit.assert_not_called()
@patch("services.saved_message_service.db.session")
def test_delete_only_affects_user_own_saved_messages(self, mock_db_session, factory):
"""Test that delete only removes the user's own saved message."""
# Arrange
app = factory.create_app_mock()
user1 = factory.create_account_mock(account_id="user-1")
message_id = "msg-shared"
# Mock database query - finds user1's saved message
saved_message = factory.create_saved_message_mock(
app_id=app.id,
message_id=message_id,
created_by=user1.id,
created_by_role="account",
)
mock_query = MagicMock()
mock_db_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = saved_message
# Act
SavedMessageService.delete(app_model=app, user=user1, message_id=message_id)
# Assert - only user1's saved message is deleted
mock_db_session.delete.assert_called_once_with(saved_message)
# Verify the query filters by user
assert mock_query.where.called

File diff suppressed because it is too large Load Diff

View File

@ -1,9 +1,9 @@
from unittest.mock import Mock
from core.tools.__base.tool import Tool
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.api_entities import ToolApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
from core.tools.entities.tool_entities import ToolParameter
from services.tools.tools_transform_service import ToolTransformService
@ -299,154 +299,3 @@ class TestToolTransformService:
param2 = result.parameters[1]
assert param2.name == "param2"
assert param2.label == "Runtime Param 2"
class TestWorkflowProviderToUserProvider:
"""Test cases for ToolTransformService.workflow_provider_to_user_provider method"""
def test_workflow_provider_to_user_provider_with_workflow_app_id(self):
"""Test that workflow_provider_to_user_provider correctly sets workflow_app_id."""
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
# Create mock workflow tool provider controller
workflow_app_id = "app_123"
provider_id = "provider_123"
mock_controller = Mock(spec=WorkflowToolProviderController)
mock_controller.provider_id = provider_id
mock_controller.entity = Mock()
mock_controller.entity.identity = Mock()
mock_controller.entity.identity.author = "test_author"
mock_controller.entity.identity.name = "test_workflow_tool"
mock_controller.entity.identity.description = I18nObject(en_US="Test description")
mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
mock_controller.entity.identity.icon_dark = None
mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")
# Call the method
result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=mock_controller,
labels=["label1", "label2"],
workflow_app_id=workflow_app_id,
)
# Verify the result
assert isinstance(result, ToolProviderApiEntity)
assert result.id == provider_id
assert result.author == "test_author"
assert result.name == "test_workflow_tool"
assert result.type == ToolProviderType.WORKFLOW
assert result.workflow_app_id == workflow_app_id
assert result.labels == ["label1", "label2"]
assert result.is_team_authorization is True
assert result.plugin_id is None
assert result.plugin_unique_identifier is None
assert result.tools == []
def test_workflow_provider_to_user_provider_without_workflow_app_id(self):
"""Test that workflow_provider_to_user_provider works when workflow_app_id is not provided."""
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
# Create mock workflow tool provider controller
provider_id = "provider_123"
mock_controller = Mock(spec=WorkflowToolProviderController)
mock_controller.provider_id = provider_id
mock_controller.entity = Mock()
mock_controller.entity.identity = Mock()
mock_controller.entity.identity.author = "test_author"
mock_controller.entity.identity.name = "test_workflow_tool"
mock_controller.entity.identity.description = I18nObject(en_US="Test description")
mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
mock_controller.entity.identity.icon_dark = None
mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")
# Call the method without workflow_app_id
result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=mock_controller,
labels=["label1"],
)
# Verify the result
assert isinstance(result, ToolProviderApiEntity)
assert result.id == provider_id
assert result.workflow_app_id is None
assert result.labels == ["label1"]
def test_workflow_provider_to_user_provider_workflow_app_id_none(self):
"""Test that workflow_provider_to_user_provider handles None workflow_app_id explicitly."""
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
# Create mock workflow tool provider controller
provider_id = "provider_123"
mock_controller = Mock(spec=WorkflowToolProviderController)
mock_controller.provider_id = provider_id
mock_controller.entity = Mock()
mock_controller.entity.identity = Mock()
mock_controller.entity.identity.author = "test_author"
mock_controller.entity.identity.name = "test_workflow_tool"
mock_controller.entity.identity.description = I18nObject(en_US="Test description")
mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
mock_controller.entity.identity.icon_dark = None
mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")
# Call the method with explicit None values
result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=mock_controller,
labels=None,
workflow_app_id=None,
)
# Verify the result
assert isinstance(result, ToolProviderApiEntity)
assert result.id == provider_id
assert result.workflow_app_id is None
assert result.labels == []
def test_workflow_provider_to_user_provider_preserves_other_fields(self):
"""Test that workflow_provider_to_user_provider preserves all other entity fields."""
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
# Create mock workflow tool provider controller with various fields
workflow_app_id = "app_456"
provider_id = "provider_456"
mock_controller = Mock(spec=WorkflowToolProviderController)
mock_controller.provider_id = provider_id
mock_controller.entity = Mock()
mock_controller.entity.identity = Mock()
mock_controller.entity.identity.author = "another_author"
mock_controller.entity.identity.name = "another_workflow_tool"
mock_controller.entity.identity.description = I18nObject(
en_US="Another description", zh_Hans="Another description"
)
mock_controller.entity.identity.icon = {"type": "emoji", "content": "⚙️"}
mock_controller.entity.identity.icon_dark = {"type": "emoji", "content": "🔧"}
mock_controller.entity.identity.label = I18nObject(
en_US="Another Workflow Tool", zh_Hans="Another Workflow Tool"
)
# Call the method
result = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=mock_controller,
labels=["automation", "workflow"],
workflow_app_id=workflow_app_id,
)
# Verify all fields are preserved correctly
assert isinstance(result, ToolProviderApiEntity)
assert result.id == provider_id
assert result.author == "another_author"
assert result.name == "another_workflow_tool"
assert result.description.en_US == "Another description"
assert result.description.zh_Hans == "Another description"
assert result.icon == {"type": "emoji", "content": "⚙️"}
assert result.icon_dark == {"type": "emoji", "content": "🔧"}
assert result.label.en_US == "Another Workflow Tool"
assert result.label.zh_Hans == "Another Workflow Tool"
assert result.type == ToolProviderType.WORKFLOW
assert result.workflow_app_id == workflow_app_id
assert result.labels == ["automation", "workflow"]
assert result.masked_credentials == {}
assert result.is_team_authorization is True
assert result.allow_delete is True
assert result.plugin_id is None
assert result.plugin_unique_identifier is None
assert result.tools == []

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

4630
api/uv.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -233,7 +233,7 @@ NEXT_PUBLIC_ENABLE_SINGLE_DOLLAR_LATEX=false
# Database type, supported values are `postgresql` and `mysql`
DB_TYPE=postgresql
# For MySQL, only `root` user is supported for now
DB_USERNAME=postgres
DB_PASSWORD=difyai123456
DB_HOST=db_postgres
@ -1076,10 +1076,24 @@ MAX_TREE_DEPTH=50
# ------------------------------
# Environment Variables for database Service
# ------------------------------
# The name of the default postgres user.
POSTGRES_USER=${DB_USERNAME}
# The password for the default postgres user.
POSTGRES_PASSWORD=${DB_PASSWORD}
# The name of the default postgres database.
POSTGRES_DB=${DB_DATABASE}
# Postgres data directory
PGDATA=/var/lib/postgresql/data/pgdata
# MySQL Default Configuration
# The name of the default mysql user.
MYSQL_USERNAME=${DB_USERNAME}
# The password for the default mysql user.
MYSQL_PASSWORD=${DB_PASSWORD}
# The name of the default mysql database.
MYSQL_DATABASE=${DB_DATABASE}
# MySQL data directory
MYSQL_HOST_VOLUME=./volumes/mysql/data
# ------------------------------

View File

@ -139,9 +139,9 @@ services:
- postgresql
restart: always
environment:
POSTGRES_USER: ${DB_USERNAME:-postgres}
POSTGRES_PASSWORD: ${DB_PASSWORD:-difyai123456}
POSTGRES_DB: ${DB_DATABASE:-dify}
POSTGRES_USER: ${POSTGRES_USER:-postgres}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456}
POSTGRES_DB: ${POSTGRES_DB:-dify}
PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata}
command: >
postgres -c 'max_connections=${POSTGRES_MAX_CONNECTIONS:-100}'
@ -161,7 +161,7 @@ services:
"-h",
"db_postgres",
"-U",
"${DB_USERNAME:-postgres}",
"${PGUSER:-postgres}",
"-d",
"${DB_DATABASE:-dify}",
]
@ -176,8 +176,8 @@ services:
- mysql
restart: always
environment:
MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-difyai123456}
MYSQL_DATABASE: ${DB_DATABASE:-dify}
MYSQL_ROOT_PASSWORD: ${MYSQL_PASSWORD:-difyai123456}
MYSQL_DATABASE: ${MYSQL_DATABASE:-dify}
command: >
--max_connections=1000
--innodb_buffer_pool_size=${MYSQL_INNODB_BUFFER_POOL_SIZE:-512M}
@ -193,7 +193,7 @@ services:
"ping",
"-u",
"root",
"-p${DB_PASSWORD:-difyai123456}",
"-p${MYSQL_PASSWORD:-difyai123456}",
]
interval: 1s
timeout: 3s

View File

@ -9,8 +9,8 @@ services:
env_file:
- ./middleware.env
environment:
POSTGRES_PASSWORD: ${DB_PASSWORD:-difyai123456}
POSTGRES_DB: ${DB_DATABASE:-dify}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456}
POSTGRES_DB: ${POSTGRES_DB:-dify}
PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata}
command: >
postgres -c 'max_connections=${POSTGRES_MAX_CONNECTIONS:-100}'
@ -32,9 +32,9 @@ services:
"-h",
"db_postgres",
"-U",
"${DB_USERNAME:-postgres}",
"${PGUSER:-postgres}",
"-d",
"${DB_DATABASE:-dify}",
"${POSTGRES_DB:-dify}",
]
interval: 1s
timeout: 3s
@ -48,8 +48,8 @@ services:
env_file:
- ./middleware.env
environment:
MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-difyai123456}
MYSQL_DATABASE: ${DB_DATABASE:-dify}
MYSQL_ROOT_PASSWORD: ${MYSQL_PASSWORD:-difyai123456}
MYSQL_DATABASE: ${MYSQL_DATABASE:-dify}
command: >
--max_connections=1000
--innodb_buffer_pool_size=${MYSQL_INNODB_BUFFER_POOL_SIZE:-512M}
@ -67,7 +67,7 @@ services:
"ping",
"-u",
"root",
"-p${DB_PASSWORD:-difyai123456}",
"-p${MYSQL_PASSWORD:-difyai123456}",
]
interval: 1s
timeout: 3s

View File

@ -455,7 +455,13 @@ x-shared-env: &shared-api-worker-env
TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000}
ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false}
MAX_TREE_DEPTH: ${MAX_TREE_DEPTH:-50}
POSTGRES_USER: ${POSTGRES_USER:-${DB_USERNAME}}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-${DB_PASSWORD}}
POSTGRES_DB: ${POSTGRES_DB:-${DB_DATABASE}}
PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata}
MYSQL_USERNAME: ${MYSQL_USERNAME:-${DB_USERNAME}}
MYSQL_PASSWORD: ${MYSQL_PASSWORD:-${DB_PASSWORD}}
MYSQL_DATABASE: ${MYSQL_DATABASE:-${DB_DATABASE}}
MYSQL_HOST_VOLUME: ${MYSQL_HOST_VOLUME:-./volumes/mysql/data}
SANDBOX_API_KEY: ${SANDBOX_API_KEY:-dify-sandbox}
SANDBOX_GIN_MODE: ${SANDBOX_GIN_MODE:-release}
@ -768,9 +774,9 @@ services:
- postgresql
restart: always
environment:
POSTGRES_USER: ${DB_USERNAME:-postgres}
POSTGRES_PASSWORD: ${DB_PASSWORD:-difyai123456}
POSTGRES_DB: ${DB_DATABASE:-dify}
POSTGRES_USER: ${POSTGRES_USER:-postgres}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456}
POSTGRES_DB: ${POSTGRES_DB:-dify}
PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata}
command: >
postgres -c 'max_connections=${POSTGRES_MAX_CONNECTIONS:-100}'
@ -790,7 +796,7 @@ services:
"-h",
"db_postgres",
"-U",
"${DB_USERNAME:-postgres}",
"${PGUSER:-postgres}",
"-d",
"${DB_DATABASE:-dify}",
]
@ -805,8 +811,8 @@ services:
- mysql
restart: always
environment:
MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-difyai123456}
MYSQL_DATABASE: ${DB_DATABASE:-dify}
MYSQL_ROOT_PASSWORD: ${MYSQL_PASSWORD:-difyai123456}
MYSQL_DATABASE: ${MYSQL_DATABASE:-dify}
command: >
--max_connections=1000
--innodb_buffer_pool_size=${MYSQL_INNODB_BUFFER_POOL_SIZE:-512M}
@ -822,7 +828,7 @@ services:
"ping",
"-u",
"root",
"-p${DB_PASSWORD:-difyai123456}",
"-p${MYSQL_PASSWORD:-difyai123456}",
]
interval: 1s
timeout: 3s

View File

@ -4,7 +4,6 @@
# Database Configuration
# Database type, supported values are `postgresql` and `mysql`
DB_TYPE=postgresql
# For MySQL, only `root` user is supported for now
DB_USERNAME=postgres
DB_PASSWORD=difyai123456
DB_HOST=db_postgres
@ -12,6 +11,11 @@ DB_PORT=5432
DB_DATABASE=dify
# PostgreSQL Configuration
POSTGRES_USER=${DB_USERNAME}
# The password for the default postgres user.
POSTGRES_PASSWORD=${DB_PASSWORD}
# The name of the default postgres database.
POSTGRES_DB=${DB_DATABASE}
# postgres data directory
PGDATA=/var/lib/postgresql/data/pgdata
PGDATA_HOST_VOLUME=./volumes/db/data
@ -61,6 +65,11 @@ POSTGRES_STATEMENT_TIMEOUT=0
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=0
# MySQL Configuration
MYSQL_USERNAME=${DB_USERNAME}
# MySQL password
MYSQL_PASSWORD=${DB_PASSWORD}
# MySQL database name
MYSQL_DATABASE=${DB_DATABASE}
# MySQL data directory host volume
MYSQL_HOST_VOLUME=./volumes/mysql/data

View File

@ -3,7 +3,6 @@ import type { ReactNode } from 'react'
import SwrInitializer from '@/app/components/swr-initializer'
import { AppContextProvider } from '@/context/app-context'
import GA, { GaType } from '@/app/components/base/ga'
import AmplitudeProvider from '@/app/components/base/amplitude'
import HeaderWrapper from '@/app/components/header/header-wrapper'
import Header from '@/app/components/header'
import { EventEmitterContextProvider } from '@/context/event-emitter'
@ -19,7 +18,6 @@ const Layout = ({ children }: { children: ReactNode }) => {
return (
<>
<GA gaType={GaType.admin} />
<AmplitudeProvider />
<SwrInitializer>
<AppContextProvider>
<EventEmitterContextProvider>

View File

@ -8,7 +8,7 @@ const PluginList = async () => {
return (
<PluginPage
plugins={<PluginsPanel />}
marketplace={<Marketplace locale={locale} pluginTypeSwitchClassName='top-[60px]' showSearchParams={false} />}
marketplace={<Marketplace locale={locale} pluginTypeSwitchClassName='top-[60px]' searchBoxAutoAnimate={false} showSearchParams={false} />}
/>
)
}

View File

@ -1,5 +1,6 @@
'use client'
import { useState } from 'react'
import useSWR from 'swr'
import { useTranslation } from 'react-i18next'
import {
RiGraduationCapFill,
@ -22,9 +23,8 @@ import PremiumBadge from '@/app/components/base/premium-badge'
import { useGlobalPublicStore } from '@/context/global-public-context'
import EmailChangeModal from './email-change-modal'
import { validPassword } from '@/config'
import { fetchAppList } from '@/service/apps'
import type { App } from '@/types/app'
import { useAppList } from '@/service/use-apps'
const titleClassName = `
system-sm-semibold text-text-secondary
@ -36,7 +36,7 @@ const descriptionClassName = `
export default function AccountPage() {
const { t } = useTranslation()
const { systemFeatures } = useGlobalPublicStore()
const { data: appList } = useAppList({ page: 1, limit: 100, name: '' })
const { data: appList } = useSWR({ url: '/apps', params: { page: 1, limit: 100, name: '' } }, fetchAppList)
const apps = appList?.data || []
const { mutateUserProfile, userProfile } = useAppContext()
const { isEducationAccount } = useProviderContext()

View File

@ -12,7 +12,6 @@ import { useProviderContext } from '@/context/provider-context'
import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general'
import PremiumBadge from '@/app/components/base/premium-badge'
import { useLogout } from '@/service/use-common'
import { resetUser } from '@/app/components/base/amplitude/utils'
export type IAppSelector = {
isMobile: boolean
@ -29,7 +28,6 @@ export default function AppSelector() {
await logout()
localStorage.removeItem('setup_status')
resetUser()
// Tokens are now stored in cookies and cleared by backend
router.push('/signin')

View File

@ -4,7 +4,6 @@ import Header from './header'
import SwrInitor from '@/app/components/swr-initializer'
import { AppContextProvider } from '@/context/app-context'
import GA, { GaType } from '@/app/components/base/ga'
import AmplitudeProvider from '@/app/components/base/amplitude'
import HeaderWrapper from '@/app/components/header/header-wrapper'
import { EventEmitterContextProvider } from '@/context/event-emitter'
import { ProviderContextProvider } from '@/context/provider-context'
@ -14,7 +13,6 @@ const Layout = ({ children }: { children: ReactNode }) => {
return (
<>
<GA gaType={GaType.admin} />
<AmplitudeProvider />
<SwrInitor>
<AppContextProvider>
<EventEmitterContextProvider>

View File

@ -1,6 +1,6 @@
'use client'
import type { FC, PropsWithChildren } from 'react'
import useAccessControlStore from '@/context/access-control-store'
import useAccessControlStore from '../../../../context/access-control-store'
import type { AccessMode } from '@/models/access-control'
type AccessControlItemProps = PropsWithChildren<{
@ -8,8 +8,7 @@ type AccessControlItemProps = PropsWithChildren<{
}>
const AccessControlItem: FC<AccessControlItemProps> = ({ type, children }) => {
const currentMenu = useAccessControlStore(s => s.currentMenu)
const setCurrentMenu = useAccessControlStore(s => s.setCurrentMenu)
const { currentMenu, setCurrentMenu } = useAccessControlStore(s => ({ currentMenu: s.currentMenu, setCurrentMenu: s.setCurrentMenu }))
if (currentMenu !== type) {
return <div
className="cursor-pointer rounded-[10px] border-[1px]

View File

@ -251,7 +251,6 @@ const AgentTools: FC = () => {
{!item.notAuthor && (
<Tooltip
popupContent={t('tools.setBuiltInTools.infoAndSetting')}
needsDelay={false}
>
<div className='cursor-pointer rounded-md p-1 hover:bg-black/5' onClick={() => {
setCurrentTool(item)

View File

@ -28,7 +28,6 @@ import Input from '@/app/components/base/input'
import { AppModeEnum } from '@/types/app'
import { DSLImportMode } from '@/models/app'
import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks'
import { trackEvent } from '@/app/components/base/amplitude'
type AppsProps = {
onSuccess?: () => void
@ -142,15 +141,6 @@ const Apps = ({
icon_background,
description,
})
// Track app creation from template
trackEvent('create_app_with_template', {
app_mode: mode,
template_id: currApp?.app.id,
template_name: currApp?.app.name,
description,
})
setIsShowCreateModal(false)
Toast.notify({
type: 'success',

View File

@ -30,7 +30,6 @@ import { getRedirection } from '@/utils/app-redirection'
import FullScreenModal from '@/app/components/base/fullscreen-modal'
import useTheme from '@/hooks/use-theme'
import { useDocLink } from '@/context/i18n'
import { trackEvent } from '@/app/components/base/amplitude'
type CreateAppProps = {
onSuccess: () => void
@ -83,13 +82,6 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }:
icon_background: appIcon.type === 'emoji' ? appIcon.background : undefined,
mode: appMode,
})
// Track app creation success
trackEvent('create_app', {
app_mode: appMode,
description,
})
notify({ type: 'success', message: t('app.newApp.appCreated') })
onSuccess()
onClose()

View File

@ -28,7 +28,6 @@ import { getRedirection } from '@/utils/app-redirection'
import cn from '@/utils/classnames'
import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks'
import { noop } from 'lodash-es'
import { trackEvent } from '@/app/components/base/amplitude'
type CreateFromDSLModalProps = {
show: boolean
@ -113,13 +112,6 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
return
const { id, status, app_id, app_mode, imported_dsl_version, current_dsl_version } = response
if (status === DSLImportStatus.COMPLETED || status === DSLImportStatus.COMPLETED_WITH_WARNINGS) {
// Track app creation from DSL import
trackEvent('create_app_with_dsl', {
app_mode,
creation_method: currentTab === CreateFromDSLModalTab.FROM_FILE ? 'dsl_file' : 'dsl_url',
has_warnings: status === DSLImportStatus.COMPLETED_WITH_WARNINGS,
})
if (onSuccess)
onSuccess()
if (onClose)

View File

@ -3,6 +3,7 @@ import type { FC } from 'react'
import React from 'react'
import ReactECharts from 'echarts-for-react'
import type { EChartsOption } from 'echarts'
import useSWR from 'swr'
import type { Dayjs } from 'dayjs'
import dayjs from 'dayjs'
import { get } from 'lodash-es'
@ -12,20 +13,7 @@ import { formatNumber } from '@/utils/format'
import Basic from '@/app/components/app-sidebar/basic'
import Loading from '@/app/components/base/loading'
import type { AppDailyConversationsResponse, AppDailyEndUsersResponse, AppDailyMessagesResponse, AppTokenCostsResponse } from '@/models/app'
import {
useAppAverageResponseTime,
useAppAverageSessionInteractions,
useAppDailyConversations,
useAppDailyEndUsers,
useAppDailyMessages,
useAppSatisfactionRate,
useAppTokenCosts,
useAppTokensPerSecond,
useWorkflowAverageInteractions,
useWorkflowDailyConversations,
useWorkflowDailyTerminals,
useWorkflowTokenCosts,
} from '@/service/use-apps'
import { getAppDailyConversations, getAppDailyEndUsers, getAppDailyMessages, getAppStatistics, getAppTokenCosts, getWorkflowDailyConversations } from '@/service/apps'
const valueFormatter = (v: string | number) => v
const COLOR_TYPE_MAP = {
@ -284,8 +272,8 @@ const getDefaultChartData = ({ start, end, key = 'count' }: { start: string; end
export const MessagesChart: FC<IBizChartProps> = ({ id, period }) => {
const { t } = useTranslation()
const { data: response, isLoading } = useAppDailyMessages(id, period.query)
if (isLoading || !response)
const { data: response } = useSWR({ url: `/apps/${id}/statistics/daily-messages`, params: period.query }, getAppDailyMessages)
if (!response)
return <Loading />
const noDataFlag = !response.data || response.data.length === 0
return <Chart
@ -298,8 +286,8 @@ export const MessagesChart: FC<IBizChartProps> = ({ id, period }) => {
export const ConversationsChart: FC<IBizChartProps> = ({ id, period }) => {
const { t } = useTranslation()
const { data: response, isLoading } = useAppDailyConversations(id, period.query)
if (isLoading || !response)
const { data: response } = useSWR({ url: `/apps/${id}/statistics/daily-conversations`, params: period.query }, getAppDailyConversations)
if (!response)
return <Loading />
const noDataFlag = !response.data || response.data.length === 0
return <Chart
@ -313,8 +301,8 @@ export const ConversationsChart: FC<IBizChartProps> = ({ id, period }) => {
export const EndUsersChart: FC<IBizChartProps> = ({ id, period }) => {
const { t } = useTranslation()
const { data: response, isLoading } = useAppDailyEndUsers(id, period.query)
if (isLoading || !response)
const { data: response } = useSWR({ url: `/apps/${id}/statistics/daily-end-users`, id, params: period.query }, getAppDailyEndUsers)
if (!response)
return <Loading />
const noDataFlag = !response.data || response.data.length === 0
return <Chart
@ -327,8 +315,8 @@ export const EndUsersChart: FC<IBizChartProps> = ({ id, period }) => {
export const AvgSessionInteractions: FC<IBizChartProps> = ({ id, period }) => {
const { t } = useTranslation()
const { data: response, isLoading } = useAppAverageSessionInteractions(id, period.query)
if (isLoading || !response)
const { data: response } = useSWR({ url: `/apps/${id}/statistics/average-session-interactions`, params: period.query }, getAppStatistics)
if (!response)
return <Loading />
const noDataFlag = !response.data || response.data.length === 0
return <Chart
@ -343,8 +331,8 @@ export const AvgSessionInteractions: FC<IBizChartProps> = ({ id, period }) => {
export const AvgResponseTime: FC<IBizChartProps> = ({ id, period }) => {
const { t } = useTranslation()
const { data: response, isLoading } = useAppAverageResponseTime(id, period.query)
if (isLoading || !response)
const { data: response } = useSWR({ url: `/apps/${id}/statistics/average-response-time`, params: period.query }, getAppStatistics)
if (!response)
return <Loading />
const noDataFlag = !response.data || response.data.length === 0
return <Chart
@ -360,8 +348,8 @@ export const AvgResponseTime: FC<IBizChartProps> = ({ id, period }) => {
export const TokenPerSecond: FC<IBizChartProps> = ({ id, period }) => {
const { t } = useTranslation()
const { data: response, isLoading } = useAppTokensPerSecond(id, period.query)
if (isLoading || !response)
const { data: response } = useSWR({ url: `/apps/${id}/statistics/tokens-per-second`, params: period.query }, getAppStatistics)
if (!response)
return <Loading />
const noDataFlag = !response.data || response.data.length === 0
return <Chart
@ -378,8 +366,8 @@ export const TokenPerSecond: FC<IBizChartProps> = ({ id, period }) => {
export const UserSatisfactionRate: FC<IBizChartProps> = ({ id, period }) => {
const { t } = useTranslation()
const { data: response, isLoading } = useAppSatisfactionRate(id, period.query)
if (isLoading || !response)
const { data: response } = useSWR({ url: `/apps/${id}/statistics/user-satisfaction-rate`, params: period.query }, getAppStatistics)
if (!response)
return <Loading />
const noDataFlag = !response.data || response.data.length === 0
return <Chart
@ -396,8 +384,8 @@ export const UserSatisfactionRate: FC<IBizChartProps> = ({ id, period }) => {
export const CostChart: FC<IBizChartProps> = ({ id, period }) => {
const { t } = useTranslation()
const { data: response, isLoading } = useAppTokenCosts(id, period.query)
if (isLoading || !response)
const { data: response } = useSWR({ url: `/apps/${id}/statistics/token-costs`, params: period.query }, getAppTokenCosts)
if (!response)
return <Loading />
const noDataFlag = !response.data || response.data.length === 0
return <Chart
@ -410,8 +398,8 @@ export const CostChart: FC<IBizChartProps> = ({ id, period }) => {
export const WorkflowMessagesChart: FC<IBizChartProps> = ({ id, period }) => {
const { t } = useTranslation()
const { data: response, isLoading } = useWorkflowDailyConversations(id, period.query)
if (isLoading || !response)
const { data: response } = useSWR({ url: `/apps/${id}/workflow/statistics/daily-conversations`, params: period.query }, getWorkflowDailyConversations)
if (!response)
return <Loading />
const noDataFlag = !response.data || response.data.length === 0
return <Chart
@ -426,8 +414,8 @@ export const WorkflowMessagesChart: FC<IBizChartProps> = ({ id, period }) => {
export const WorkflowDailyTerminalsChart: FC<IBizChartProps> = ({ id, period }) => {
const { t } = useTranslation()
const { data: response, isLoading } = useWorkflowDailyTerminals(id, period.query)
if (isLoading || !response)
const { data: response } = useSWR({ url: `/apps/${id}/workflow/statistics/daily-terminals`, id, params: period.query }, getAppDailyEndUsers)
if (!response)
return <Loading />
const noDataFlag = !response.data || response.data.length === 0
return <Chart
@ -441,8 +429,8 @@ export const WorkflowDailyTerminalsChart: FC<IBizChartProps> = ({ id, period })
export const WorkflowCostChart: FC<IBizChartProps> = ({ id, period }) => {
const { t } = useTranslation()
const { data: response, isLoading } = useWorkflowTokenCosts(id, period.query)
if (isLoading || !response)
const { data: response } = useSWR({ url: `/apps/${id}/workflow/statistics/token-costs`, params: period.query }, getAppTokenCosts)
if (!response)
return <Loading />
const noDataFlag = !response.data || response.data.length === 0
return <Chart
@ -455,8 +443,8 @@ export const WorkflowCostChart: FC<IBizChartProps> = ({ id, period }) => {
export const AvgUserInteractions: FC<IBizChartProps> = ({ id, period }) => {
const { t } = useTranslation()
const { data: response, isLoading } = useWorkflowAverageInteractions(id, period.query)
if (isLoading || !response)
const { data: response } = useSWR({ url: `/apps/${id}/workflow/statistics/average-app-interactions`, params: period.query }, getAppStatistics)
if (!response)
return <Loading />
const noDataFlag = !response.data || response.data.length === 0
return <Chart

View File

@ -8,7 +8,6 @@ import quarterOfYear from 'dayjs/plugin/quarterOfYear'
import type { QueryParam } from './index'
import Chip from '@/app/components/base/chip'
import Input from '@/app/components/base/input'
import { trackEvent } from '@/app/components/base/amplitude/utils'
dayjs.extend(quarterOfYear)
const today = dayjs()
@ -38,9 +37,6 @@ const Filter: FC<IFilterProps> = ({ queryParams, setQueryParams }: IFilterProps)
value={queryParams.status || 'all'}
onSelect={(item) => {
setQueryParams({ ...queryParams, status: item.value as string })
trackEvent('workflow_log_filter_status_selected', {
workflow_log_filter_status: item.value as string,
})
}}
onClear={() => setQueryParams({ ...queryParams, status: 'all' })}
items={[{ value: 'all', name: 'All' },

View File

@ -23,7 +23,7 @@ const Empty = () => {
return (
<>
<DefaultCards />
<div className='pointer-events-none absolute inset-0 z-20 flex items-center justify-center bg-gradient-to-t from-background-body to-transparent'>
<div className='absolute bottom-0 left-0 right-0 top-0 flex items-center justify-center bg-gradient-to-t from-background-body to-transparent'>
<span className='system-md-medium text-text-tertiary'>
{t('app.newApp.noAppsFound')}
</span>

View File

@ -4,6 +4,7 @@ import { useCallback, useEffect, useRef, useState } from 'react'
import {
useRouter,
} from 'next/navigation'
import useSWRInfinite from 'swr/infinite'
import { useTranslation } from 'react-i18next'
import { useDebounceFn } from 'ahooks'
import {
@ -18,6 +19,8 @@ import AppCard from './app-card'
import NewAppCard from './new-app-card'
import useAppsQueryState from './hooks/use-apps-query-state'
import { useDSLDragDrop } from './hooks/use-dsl-drag-drop'
import type { AppListResponse } from '@/models/app'
import { fetchAppList } from '@/service/apps'
import { useAppContext } from '@/context/app-context'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { CheckModal } from '@/hooks/use-pay'
@ -32,7 +35,6 @@ import Empty from './empty'
import Footer from './footer'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { AppModeEnum } from '@/types/app'
import { useInfiniteAppList } from '@/service/use-apps'
const TagManagementModal = dynamic(() => import('@/app/components/base/tag-management'), {
ssr: false,
@ -41,6 +43,30 @@ const CreateFromDSLModal = dynamic(() => import('@/app/components/app/create-fro
ssr: false,
})
const getKey = (
pageIndex: number,
previousPageData: AppListResponse,
activeTab: string,
isCreatedByMe: boolean,
tags: string[],
keywords: string,
) => {
if (!pageIndex || previousPageData.has_more) {
const params: any = { url: 'apps', params: { page: pageIndex + 1, limit: 30, name: keywords, is_created_by_me: isCreatedByMe } }
if (activeTab !== 'all')
params.params.mode = activeTab
else
delete params.params.mode
if (tags.length)
params.params.tag_ids = tags
return params
}
return null
}
const List = () => {
const { t } = useTranslation()
const { systemFeatures } = useGlobalPublicStore()
@ -76,24 +102,16 @@ const List = () => {
enabled: isCurrentWorkspaceEditor,
})
const appListQueryParams = {
page: 1,
limit: 30,
name: searchKeywords,
tag_ids: tagIDs,
is_created_by_me: isCreatedByMe,
...(activeTab !== 'all' ? { mode: activeTab as AppModeEnum } : {}),
}
const {
data,
isLoading,
isFetchingNextPage,
fetchNextPage,
hasNextPage,
error,
refetch,
} = useInfiniteAppList(appListQueryParams, { enabled: !isCurrentWorkspaceDatasetOperator })
const { data, isLoading, error, setSize, mutate } = useSWRInfinite(
(pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, activeTab, isCreatedByMe, tagIDs, searchKeywords),
fetchAppList,
{
revalidateFirstPage: true,
shouldRetryOnError: false,
dedupingInterval: 500,
errorRetryCount: 3,
},
)
const anchorRef = useRef<HTMLDivElement>(null)
const options = [
@ -108,9 +126,9 @@ const List = () => {
useEffect(() => {
if (localStorage.getItem(NEED_REFRESH_APP_LIST_KEY) === '1') {
localStorage.removeItem(NEED_REFRESH_APP_LIST_KEY)
refetch()
mutate()
}
}, [refetch])
}, [mutate, t])
useEffect(() => {
if (isCurrentWorkspaceDatasetOperator)
@ -118,9 +136,7 @@ const List = () => {
}, [router, isCurrentWorkspaceDatasetOperator])
useEffect(() => {
if (isCurrentWorkspaceDatasetOperator)
return
const hasMore = hasNextPage ?? true
const hasMore = data?.at(-1)?.has_more ?? true
let observer: IntersectionObserver | undefined
if (error) {
@ -135,8 +151,8 @@ const List = () => {
const dynamicMargin = Math.max(100, Math.min(containerHeight * 0.2, 200)) // Clamps to 100-200px range, using 20% of container height as the base value
observer = new IntersectionObserver((entries) => {
if (entries[0].isIntersecting && !isLoading && !isFetchingNextPage && !error && hasMore)
fetchNextPage()
if (entries[0].isIntersecting && !isLoading && !error && hasMore)
setSize((size: number) => size + 1)
}, {
root: containerRef.current,
rootMargin: `${dynamicMargin}px`,
@ -145,7 +161,7 @@ const List = () => {
observer.observe(anchorRef.current)
}
return () => observer?.disconnect()
}, [isLoading, isFetchingNextPage, fetchNextPage, error, hasNextPage, isCurrentWorkspaceDatasetOperator])
}, [isLoading, setSize, data, error])
const { run: handleSearch } = useDebounceFn(() => {
setSearchKeywords(keywords)
@ -169,9 +185,6 @@ const List = () => {
setQuery(prev => ({ ...prev, isCreatedByMe: newValue }))
}, [isCreatedByMe, setQuery])
const pages = data?.pages ?? []
const hasAnyApp = (pages[0]?.total ?? 0) > 0
return (
<>
<div ref={containerRef} className='relative flex h-0 shrink-0 grow flex-col overflow-y-auto bg-background-body'>
@ -204,17 +217,17 @@ const List = () => {
/>
</div>
</div>
{hasAnyApp
{(data && data[0].total > 0)
? <div className='relative grid grow grid-cols-1 content-start gap-4 px-12 pt-2 sm:grid-cols-1 md:grid-cols-2 xl:grid-cols-4 2xl:grid-cols-5 2k:grid-cols-6'>
{isCurrentWorkspaceEditor
&& <NewAppCard ref={newAppCardRef} onSuccess={refetch} selectedAppType={activeTab} />}
{pages.map(({ data: apps }) => apps.map(app => (
<AppCard key={app.id} app={app} onRefresh={refetch} />
&& <NewAppCard ref={newAppCardRef} onSuccess={mutate} selectedAppType={activeTab} />}
{data.map(({ data: apps }) => apps.map(app => (
<AppCard key={app.id} app={app} onRefresh={mutate} />
)))}
</div>
: <div className='relative grid grow grid-cols-1 content-start gap-4 overflow-hidden px-12 pt-2 sm:grid-cols-1 md:grid-cols-2 xl:grid-cols-4 2xl:grid-cols-5 2k:grid-cols-6'>
{isCurrentWorkspaceEditor
&& <NewAppCard ref={newAppCardRef} className='z-10' onSuccess={refetch} selectedAppType={activeTab} />}
&& <NewAppCard ref={newAppCardRef} className='z-10' onSuccess={mutate} selectedAppType={activeTab} />}
<Empty />
</div>}
@ -248,7 +261,7 @@ const List = () => {
onSuccess={() => {
setShowCreateFromDSLModal(false)
setDroppedDSLFile(undefined)
refetch()
mutate()
}}
droppedFile={droppedDSLFile}
/>

View File

@ -1,46 +0,0 @@
'use client'
import type { FC } from 'react'
import React, { useEffect } from 'react'
import * as amplitude from '@amplitude/analytics-browser'
import { sessionReplayPlugin } from '@amplitude/plugin-session-replay-browser'
import { IS_CLOUD_EDITION } from '@/config'
export type IAmplitudeProps = {
apiKey?: string
sessionReplaySampleRate?: number
}
const AmplitudeProvider: FC<IAmplitudeProps> = ({
apiKey = process.env.NEXT_PUBLIC_AMPLITUDE_API_KEY ?? '',
sessionReplaySampleRate = 1,
}) => {
useEffect(() => {
// Only enable in Saas edition
if (!IS_CLOUD_EDITION)
return
// Initialize Amplitude
amplitude.init(apiKey, {
defaultTracking: {
sessions: true,
pageViews: true,
formInteractions: true,
fileDownloads: true,
},
// Enable debug logs in development environment
logLevel: amplitude.Types.LogLevel.Warn,
})
// Add Session Replay plugin
const sessionReplay = sessionReplayPlugin({
sampleRate: sessionReplaySampleRate,
})
amplitude.add(sessionReplay)
}, [])
// This is a client component that renders nothing
return null
}
export default React.memo(AmplitudeProvider)

View File

@ -1,2 +0,0 @@
export { default } from './AmplitudeProvider'
export { resetUser, setUserId, setUserProperties, trackEvent } from './utils'

View File

@ -1,37 +0,0 @@
import * as amplitude from '@amplitude/analytics-browser'
/**
* Track custom event
* @param eventName Event name
* @param eventProperties Event properties (optional)
*/
export const trackEvent = (eventName: string, eventProperties?: Record<string, any>) => {
amplitude.track(eventName, eventProperties)
}
/**
* Set user ID
* @param userId User ID
*/
export const setUserId = (userId: string) => {
amplitude.setUserId(userId)
}
/**
* Set user properties
* @param properties User properties
*/
export const setUserProperties = (properties: Record<string, any>) => {
const identifyEvent = new amplitude.Identify()
Object.entries(properties).forEach(([key, value]) => {
identifyEvent.set(key, value)
})
amplitude.identify(identifyEvent)
}
/**
* Reset user (e.g., when user logs out)
*/
export const resetUser = () => {
amplitude.reset()
}

Some files were not shown because too many files have changed in this diff Show More