mirror of https://github.com/langgenius/dify.git
synced 2026-01-20 03:59:30 +08:00

Compare commits: feat/track...fix/codeow (1 commit)

| Author | SHA1 | Date |
|---|---|---|
|  | 5874b920b2 |  |
.github/CODEOWNERS (vendored, 8 lines changed)
@@ -18,10 +18,10 @@ api/core/workflow/node_events/ @laipz8200 @QuantumGhost
 api/core/model_runtime/ @laipz8200 @QuantumGhost

 # Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
-api/core/workflow/nodes/agent/ @Nov1c444
-api/core/workflow/nodes/iteration/ @Nov1c444
-api/core/workflow/nodes/loop/ @Nov1c444
-api/core/workflow/nodes/llm/ @Nov1c444
+api/core/workflow/nodes/agent/ Nov1c444
+api/core/workflow/nodes/iteration/ Nov1c444
+api/core/workflow/nodes/loop/ Nov1c444
+api/core/workflow/nodes/llm/ Nov1c444

 # Backend - RAG (Retrieval Augmented Generation)
 api/core/rag/ @JohnJyong
@@ -1,8 +1,6 @@
 import logging
 import time

-from opentelemetry.trace import get_current_span
-
 from configs import dify_config
 from contexts.wrapper import RecyclableContextVar
 from dify_app import DifyApp
@@ -28,25 +26,8 @@ def create_flask_app_with_configs() -> DifyApp:
         # add an unique identifier to each request
         RecyclableContextVar.increment_thread_recycles()

-    # add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
-    @dify_app.after_request
-    def add_trace_id_header(response):
-        try:
-            span = get_current_span()
-            ctx = span.get_span_context() if span else None
-            if ctx and ctx.is_valid:
-                trace_id_hex = format(ctx.trace_id, "032x")
-                # Avoid duplicates if some middleware added it
-                if "X-Trace-Id" not in response.headers:
-                    response.headers["X-Trace-Id"] = trace_id_hex
-        except Exception:
-            # Never break the response due to tracing header injection
-            logger.warning("Failed to add trace ID to response header", exc_info=True)
-        return response
-
     # Capture the decorator's return value to avoid pyright reportUnusedFunction
     _ = before_request
-    _ = add_trace_id_header

     return dify_app
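The removed after_request hook copied the active OpenTelemetry trace id into an X-Trace-Id response header on every request. A minimal sketch of how that behavior could be verified with Flask's test client — the /health route is hypothetical and an OpenTelemetry SDK is assumed to be configured; neither is part of this compare:

    from dify_app import DifyApp

    def check_trace_header(app: DifyApp) -> None:
        client = app.test_client()  # DifyApp is assumed to be a Flask subclass
        response = client.get("/health")  # hypothetical route, for illustration only
        # With the hook installed, format(ctx.trace_id, "032x") yields a
        # 32-character lowercase hex id in this header.
        assert len(response.headers.get("X-Trace-Id", "")) == 32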
@@ -70,7 +51,6 @@ def initialize_extensions(app: DifyApp):
         ext_commands,
         ext_compress,
         ext_database,
-        ext_forward_refs,
         ext_hosting_provider,
         ext_import_modules,
         ext_logging,
@@ -95,7 +75,6 @@ def initialize_extensions(app: DifyApp):
         ext_warnings,
         ext_import_modules,
         ext_orjson,
-        ext_forward_refs,
         ext_set_secretkey,
         ext_compress,
         ext_code_based_extension,
@@ -553,10 +553,7 @@ class LoggingConfig(BaseSettings):

     LOG_FORMAT: str = Field(
         description="Format string for log messages",
-        default=(
-            "%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] "
-            "[%(filename)s:%(lineno)d] %(trace_id)s - %(message)s"
-        ),
+        default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
     )

     LOG_DATEFORMAT: str | None = Field(
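The dropped %(trace_id)s placeholder only renders if every log record carries a trace_id attribute, which the format string itself cannot guarantee. One standard way to supply it is a logging.Filter; the sketch below is an assumption about how such wiring could look, not code from this compare:

    import logging

    from opentelemetry.trace import get_current_span

    class TraceIdFilter(logging.Filter):
        """Annotate each record so %(trace_id)s resolves in the formatter."""

        def filter(self, record: logging.LogRecord) -> bool:
            ctx = get_current_span().get_span_context()
            record.trace_id = format(ctx.trace_id, "032x") if ctx.is_valid else "-"
            return True  # never drop records; this filter only annotates them

    handler = logging.StreamHandler()
    handler.addFilter(TraceIdFilter())
    handler.setFormatter(
        logging.Formatter(
            "%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] "
            "[%(filename)s:%(lineno)d] %(trace_id)s - %(message)s"
        )
    )
    logging.getLogger().addHandler(handler)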
@@ -1,23 +1,16 @@
-from flask import request
-from flask_restx import Resource, fields
-from pydantic import BaseModel, Field
+from flask_restx import Resource, fields, reqparse

 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, setup_required
 from libs.login import login_required
 from services.advanced_prompt_template_service import AdvancedPromptTemplateService


-class AdvancedPromptTemplateQuery(BaseModel):
-    app_mode: str = Field(..., description="Application mode")
-    model_mode: str = Field(..., description="Model mode")
-    has_context: str = Field(default="true", description="Whether has context")
-    model_name: str = Field(..., description="Model name")
-
-
-console_ns.schema_model(
-    AdvancedPromptTemplateQuery.__name__,
-    AdvancedPromptTemplateQuery.model_json_schema(ref_template="#/definitions/{model}"),
-)
+parser = (
+    reqparse.RequestParser()
+    .add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
+    .add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
+    .add_argument("has_context", type=str, required=False, default="true", location="args", help="Whether has context")
+    .add_argument("model_name", type=str, required=True, location="args", help="Model name")
+)
@@ -25,7 +18,7 @@ console_ns.schema_model(
 class AdvancedPromptTemplateList(Resource):
     @console_ns.doc("get_advanced_prompt_templates")
     @console_ns.doc(description="Get advanced prompt templates based on app mode and model configuration")
-    @console_ns.expect(console_ns.models[AdvancedPromptTemplateQuery.__name__])
+    @console_ns.expect(parser)
     @console_ns.response(
         200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
     )
@@ -34,6 +27,6 @@ class AdvancedPromptTemplateList(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
+        args = parser.parse_args()

-        return AdvancedPromptTemplateService.get_prompt(args.model_dump())
+        return AdvancedPromptTemplateService.get_prompt(args)
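This hunk is the template for the rest of the compare: each controller swaps a pydantic model (validated in the handler and registered via console_ns.schema_model for Swagger) for a flask_restx reqparse.RequestParser declared next to the resource. A self-contained sketch of the two styles, reusing the app_mode field from above; the ExampleQuery name is hypothetical:

    from flask_restx import reqparse
    from pydantic import BaseModel, Field

    # Old style: declarative model, validated explicitly in the handler.
    class ExampleQuery(BaseModel):
        app_mode: str = Field(..., description="Application mode")

    # args = ExampleQuery.model_validate(request.args.to_dict(flat=True)).model_dump()

    # New style: parser declared once; parse_args() returns a dict-like result
    # and automatically answers 400 when a required argument is missing.
    parser = reqparse.RequestParser().add_argument(
        "app_mode", type=str, required=True, location="args"
    )
    # args = parser.parse_args()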
@@ -1,12 +1,9 @@
 import uuid
-from typing import Literal

-from flask import request
-from flask_restx import Resource, fields, marshal, marshal_with
-from pydantic import BaseModel, Field, field_validator
+from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
 from sqlalchemy import select
 from sqlalchemy.orm import Session
-from werkzeug.exceptions import BadRequest
+from werkzeug.exceptions import BadRequest, abort

 from controllers.console import console_ns
 from controllers.console.app.wraps import get_app_model
@@ -39,130 +36,6 @@ from services.enterprise.enterprise_service import EnterpriseService
 from services.feature_service import FeatureService

 ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-
-
-class AppListQuery(BaseModel):
-    page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
-    limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
-    mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
-        default="all", description="App mode filter"
-    )
-    name: str | None = Field(default=None, description="Filter by app name")
-    tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
-    is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
-
-    @field_validator("tag_ids", mode="before")
-    @classmethod
-    def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
-        if not value:
-            return None
-
-        if isinstance(value, str):
-            items = [item.strip() for item in value.split(",") if item.strip()]
-        elif isinstance(value, list):
-            items = [str(item).strip() for item in value if item and str(item).strip()]
-        else:
-            raise TypeError("Unsupported tag_ids type.")
-
-        if not items:
-            return None
-
-        try:
-            return [str(uuid.UUID(item)) for item in items]
-        except ValueError as exc:
-            raise ValueError("Invalid UUID format in tag_ids.") from exc
-
-
-class CreateAppPayload(BaseModel):
-    name: str = Field(..., min_length=1, description="App name")
-    description: str | None = Field(default=None, description="App description (max 400 chars)")
-    mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
-    icon_type: str | None = Field(default=None, description="Icon type")
-    icon: str | None = Field(default=None, description="Icon")
-    icon_background: str | None = Field(default=None, description="Icon background color")
-
-    @field_validator("description")
-    @classmethod
-    def validate_description(cls, value: str | None) -> str | None:
-        if value is None:
-            return value
-        return validate_description_length(value)
-
-
-class UpdateAppPayload(BaseModel):
-    name: str = Field(..., min_length=1, description="App name")
-    description: str | None = Field(default=None, description="App description (max 400 chars)")
-    icon_type: str | None = Field(default=None, description="Icon type")
-    icon: str | None = Field(default=None, description="Icon")
-    icon_background: str | None = Field(default=None, description="Icon background color")
-    use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
-    max_active_requests: int | None = Field(default=None, description="Maximum active requests")
-
-    @field_validator("description")
-    @classmethod
-    def validate_description(cls, value: str | None) -> str | None:
-        if value is None:
-            return value
-        return validate_description_length(value)
-
-
-class CopyAppPayload(BaseModel):
-    name: str | None = Field(default=None, description="Name for the copied app")
-    description: str | None = Field(default=None, description="Description for the copied app")
-    icon_type: str | None = Field(default=None, description="Icon type")
-    icon: str | None = Field(default=None, description="Icon")
-    icon_background: str | None = Field(default=None, description="Icon background color")
-
-    @field_validator("description")
-    @classmethod
-    def validate_description(cls, value: str | None) -> str | None:
-        if value is None:
-            return value
-        return validate_description_length(value)
-
-
-class AppExportQuery(BaseModel):
-    include_secret: bool = Field(default=False, description="Include secrets in export")
-    workflow_id: str | None = Field(default=None, description="Specific workflow ID to export")
-
-
-class AppNamePayload(BaseModel):
-    name: str = Field(..., min_length=1, description="Name to check")
-
-
-class AppIconPayload(BaseModel):
-    icon: str | None = Field(default=None, description="Icon data")
-    icon_background: str | None = Field(default=None, description="Icon background color")
-
-
-class AppSiteStatusPayload(BaseModel):
-    enable_site: bool = Field(..., description="Enable or disable site")
-
-
-class AppApiStatusPayload(BaseModel):
-    enable_api: bool = Field(..., description="Enable or disable API")
-
-
-class AppTracePayload(BaseModel):
-    enabled: bool = Field(..., description="Enable or disable tracing")
-    tracing_provider: str = Field(..., description="Tracing provider")
-
-
-def reg(cls: type[BaseModel]):
-    console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
-
-
-reg(AppListQuery)
-reg(CreateAppPayload)
-reg(UpdateAppPayload)
-reg(CopyAppPayload)
-reg(AppExportQuery)
-reg(AppNamePayload)
-reg(AppIconPayload)
-reg(AppSiteStatusPayload)
-reg(AppApiStatusPayload)
-reg(AppTracePayload)

 # Register models for flask_restx to avoid dict type issues in Swagger
 # Register base models first
@@ -274,7 +147,22 @@ app_pagination_model = console_ns.model(
 class AppListApi(Resource):
     @console_ns.doc("list_apps")
     @console_ns.doc(description="Get list of applications with pagination and filtering")
-    @console_ns.expect(console_ns.models[AppListQuery.__name__])
+    @console_ns.expect(
+        console_ns.parser()
+        .add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1)
+        .add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20)
+        .add_argument(
+            "mode",
+            type=str,
+            location="args",
+            choices=["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"],
+            default="all",
+            help="App mode filter",
+        )
+        .add_argument("name", type=str, location="args", help="Filter by app name")
+        .add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs")
+        .add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator")
+    )
     @console_ns.response(200, "Success", app_pagination_model)
     @setup_required
     @login_required
@@ -284,12 +172,42 @@ class AppListApi(Resource):
         """Get app list"""
         current_user, current_tenant_id = current_account_with_tenant()

-        args = AppListQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
-        args_dict = args.model_dump()
+        def uuid_list(value):
+            try:
+                return [str(uuid.UUID(v)) for v in value.split(",")]
+            except ValueError:
+                abort(400, message="Invalid UUID format in tag_ids.")
+
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
+            .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
+            .add_argument(
+                "mode",
+                type=str,
+                choices=[
+                    "completion",
+                    "chat",
+                    "advanced-chat",
+                    "workflow",
+                    "agent-chat",
+                    "channel",
+                    "all",
+                ],
+                default="all",
+                location="args",
+                required=False,
+            )
+            .add_argument("name", type=str, location="args", required=False)
+            .add_argument("tag_ids", type=uuid_list, location="args", required=False)
+            .add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
+        )
+
+        args = parser.parse_args()

         # get app list
         app_service = AppService()
-        app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
+        app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args)
         if not app_pagination:
             return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
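Note how the local uuid_list converter replaces the pydantic validate_tag_ids validator: a reqparse type= callable receives the raw query-string value and either returns the converted value or signals failure (here by calling abort(400, ...) directly). A standalone sketch of the conversion half:

    import uuid

    def uuid_list(value: str) -> list[str]:
        # Raises ValueError if any comma-separated item is not a UUID;
        # the endpoint above turns that failure into an HTTP 400.
        return [str(uuid.UUID(v)) for v in value.split(",")]

    print(uuid_list("123e4567-e89b-12d3-a456-426614174000"))
    # ['123e4567-e89b-12d3-a456-426614174000']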
@@ -324,13 +242,10 @@ class AppListApi(Resource):
             NodeType.TRIGGER_PLUGIN,
         }
         for workflow in draft_workflows:
-            try:
-                for _, node_data in workflow.walk_nodes():
-                    if node_data.get("type") in trigger_node_types:
-                        draft_trigger_app_ids.add(str(workflow.app_id))
-                        break
-            except Exception:
-                continue
+            for _, node_data in workflow.walk_nodes():
+                if node_data.get("type") in trigger_node_types:
+                    draft_trigger_app_ids.add(str(workflow.app_id))
+                    break

         for app in app_pagination.items:
             app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
@@ -339,7 +254,19 @@ class AppListApi(Resource):

     @console_ns.doc("create_app")
     @console_ns.doc(description="Create a new application")
-    @console_ns.expect(console_ns.models[CreateAppPayload.__name__])
+    @console_ns.expect(
+        console_ns.model(
+            "CreateAppRequest",
+            {
+                "name": fields.String(required=True, description="App name"),
+                "description": fields.String(description="App description (max 400 chars)"),
+                "mode": fields.String(required=True, enum=ALLOW_CREATE_APP_MODES, description="App mode"),
+                "icon_type": fields.String(description="Icon type"),
+                "icon": fields.String(description="Icon"),
+                "icon_background": fields.String(description="Icon background color"),
+            },
+        )
+    )
     @console_ns.response(201, "App created successfully", app_detail_model)
     @console_ns.response(403, "Insufficient permissions")
     @console_ns.response(400, "Invalid request parameters")
@@ -352,10 +279,22 @@ class AppListApi(Resource):
     def post(self):
         """Create app"""
         current_user, current_tenant_id = current_account_with_tenant()
-        args = CreateAppPayload.model_validate(console_ns.payload)
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("name", type=str, required=True, location="json")
+            .add_argument("description", type=validate_description_length, location="json")
+            .add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
+            .add_argument("icon_type", type=str, location="json")
+            .add_argument("icon", type=str, location="json")
+            .add_argument("icon_background", type=str, location="json")
+        )
+        args = parser.parse_args()
+
+        if "mode" not in args or args["mode"] is None:
+            raise BadRequest("mode is required")

         app_service = AppService()
-        app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)
+        app = app_service.create_app(current_tenant_id, args, current_user)

         return app, 201
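Because mode is declared without required=True, parse_args() maps a missing field to None instead of failing, which is why the handler raises BadRequest itself. A minimal demonstration of that default, using a throwaway Flask app:

    from flask import Flask
    from flask_restx import reqparse

    app = Flask(__name__)
    parser = reqparse.RequestParser().add_argument("mode", type=str, location="json")

    with app.test_request_context(json={"name": "demo"}):
        args = parser.parse_args()
        assert args["mode"] is None  # optional argument: missing -> None, no 400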
@@ -387,7 +326,20 @@ class AppApi(Resource):
     @console_ns.doc("update_app")
     @console_ns.doc(description="Update application details")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(console_ns.models[UpdateAppPayload.__name__])
+    @console_ns.expect(
+        console_ns.model(
+            "UpdateAppRequest",
+            {
+                "name": fields.String(required=True, description="App name"),
+                "description": fields.String(description="App description (max 400 chars)"),
+                "icon_type": fields.String(description="Icon type"),
+                "icon": fields.String(description="Icon"),
+                "icon_background": fields.String(description="Icon background color"),
+                "use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
+                "max_active_requests": fields.Integer(description="Maximum active requests"),
+            },
+        )
+    )
     @console_ns.response(200, "App updated successfully", app_detail_with_site_model)
     @console_ns.response(403, "Insufficient permissions")
     @console_ns.response(400, "Invalid request parameters")
@@ -399,18 +351,28 @@ class AppApi(Resource):
     @marshal_with(app_detail_with_site_model)
     def put(self, app_model):
         """Update app"""
-        args = UpdateAppPayload.model_validate(console_ns.payload)
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("name", type=str, required=True, nullable=False, location="json")
+            .add_argument("description", type=validate_description_length, location="json")
+            .add_argument("icon_type", type=str, location="json")
+            .add_argument("icon", type=str, location="json")
+            .add_argument("icon_background", type=str, location="json")
+            .add_argument("use_icon_as_answer_icon", type=bool, location="json")
+            .add_argument("max_active_requests", type=int, location="json")
+        )
+        args = parser.parse_args()

         app_service = AppService()

         args_dict: AppService.ArgsDict = {
-            "name": args.name,
-            "description": args.description or "",
-            "icon_type": args.icon_type or "",
-            "icon": args.icon or "",
-            "icon_background": args.icon_background or "",
-            "use_icon_as_answer_icon": args.use_icon_as_answer_icon or False,
-            "max_active_requests": args.max_active_requests or 0,
+            "name": args["name"],
+            "description": args.get("description", ""),
+            "icon_type": args.get("icon_type", ""),
+            "icon": args.get("icon", ""),
+            "icon_background": args.get("icon_background", ""),
+            "use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False),
+            "max_active_requests": args.get("max_active_requests", 0),
         }
         app_model = app_service.update_app(app_model, args_dict)
@@ -439,7 +401,18 @@ class AppCopyApi(Resource):
     @console_ns.doc("copy_app")
     @console_ns.doc(description="Create a copy of an existing application")
     @console_ns.doc(params={"app_id": "Application ID to copy"})
-    @console_ns.expect(console_ns.models[CopyAppPayload.__name__])
+    @console_ns.expect(
+        console_ns.model(
+            "CopyAppRequest",
+            {
+                "name": fields.String(description="Name for the copied app"),
+                "description": fields.String(description="Description for the copied app"),
+                "icon_type": fields.String(description="Icon type"),
+                "icon": fields.String(description="Icon"),
+                "icon_background": fields.String(description="Icon background color"),
+            },
+        )
+    )
     @console_ns.response(201, "App copied successfully", app_detail_with_site_model)
     @console_ns.response(403, "Insufficient permissions")
     @setup_required
@@ -453,7 +426,15 @@ class AppCopyApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         current_user, _ = current_account_with_tenant()

-        args = CopyAppPayload.model_validate(console_ns.payload or {})
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("name", type=str, location="json")
+            .add_argument("description", type=validate_description_length, location="json")
+            .add_argument("icon_type", type=str, location="json")
+            .add_argument("icon", type=str, location="json")
+            .add_argument("icon_background", type=str, location="json")
+        )
+        args = parser.parse_args()

         with Session(db.engine) as session:
             import_service = AppDslService(session)
@@ -462,11 +443,11 @@ class AppCopyApi(Resource):
                 account=current_user,
                 import_mode=ImportMode.YAML_CONTENT,
                 yaml_content=yaml_content,
-                name=args.name,
-                description=args.description,
-                icon_type=args.icon_type,
-                icon=args.icon,
-                icon_background=args.icon_background,
+                name=args.get("name"),
+                description=args.get("description"),
+                icon_type=args.get("icon_type"),
+                icon=args.get("icon"),
+                icon_background=args.get("icon_background"),
             )
             session.commit()
@@ -481,7 +462,11 @@ class AppExportApi(Resource):
     @console_ns.doc("export_app")
     @console_ns.doc(description="Export application configuration as DSL")
     @console_ns.doc(params={"app_id": "Application ID to export"})
-    @console_ns.expect(console_ns.models[AppExportQuery.__name__])
+    @console_ns.expect(
+        console_ns.parser()
+        .add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export")
+        .add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export")
+    )
     @console_ns.response(
         200,
         "App exported successfully",
@@ -495,23 +480,30 @@ class AppExportApi(Resource):
     @edit_permission_required
     def get(self, app_model):
         """Export app"""
-        args = AppExportQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
+        # Add include_secret params
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("include_secret", type=inputs.boolean, default=False, location="args")
+            .add_argument("workflow_id", type=str, location="args")
+        )
+        args = parser.parse_args()

         return {
             "data": AppDslService.export_dsl(
-                app_model=app_model,
-                include_secret=args.include_secret,
-                workflow_id=args.workflow_id,
+                app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
             )
         }


+parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json", help="Name to check")
+
+
 @console_ns.route("/apps/<uuid:app_id>/name")
 class AppNameApi(Resource):
     @console_ns.doc("check_app_name")
     @console_ns.doc(description="Check if app name is available")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(console_ns.models[AppNamePayload.__name__])
+    @console_ns.expect(parser)
     @console_ns.response(200, "Name availability checked")
     @setup_required
     @login_required
@@ -520,10 +512,10 @@ class AppNameApi(Resource):
     @marshal_with(app_detail_model)
     @edit_permission_required
     def post(self, app_model):
-        args = AppNamePayload.model_validate(console_ns.payload)
+        args = parser.parse_args()

         app_service = AppService()
-        app_model = app_service.update_app_name(app_model, args.name)
+        app_model = app_service.update_app_name(app_model, args["name"])

         return app_model
@@ -533,7 +525,16 @@ class AppIconApi(Resource):
     @console_ns.doc("update_app_icon")
     @console_ns.doc(description="Update application icon")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(console_ns.models[AppIconPayload.__name__])
+    @console_ns.expect(
+        console_ns.model(
+            "AppIconRequest",
+            {
+                "icon": fields.String(required=True, description="Icon data"),
+                "icon_type": fields.String(description="Icon type"),
+                "icon_background": fields.String(description="Icon background color"),
+            },
+        )
+    )
     @console_ns.response(200, "Icon updated successfully")
     @console_ns.response(403, "Insufficient permissions")
     @setup_required
@@ -543,10 +544,15 @@ class AppIconApi(Resource):
     @marshal_with(app_detail_model)
     @edit_permission_required
     def post(self, app_model):
-        args = AppIconPayload.model_validate(console_ns.payload or {})
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("icon", type=str, location="json")
+            .add_argument("icon_background", type=str, location="json")
+        )
+        args = parser.parse_args()

         app_service = AppService()
-        app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")
+        app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "")

         return app_model
@@ -556,7 +562,11 @@ class AppSiteStatus(Resource):
     @console_ns.doc("update_app_site_status")
     @console_ns.doc(description="Enable or disable app site")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__])
+    @console_ns.expect(
+        console_ns.model(
+            "AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")}
+        )
+    )
     @console_ns.response(200, "Site status updated successfully", app_detail_model)
     @console_ns.response(403, "Insufficient permissions")
     @setup_required
@@ -566,10 +576,11 @@ class AppSiteStatus(Resource):
     @marshal_with(app_detail_model)
     @edit_permission_required
     def post(self, app_model):
-        args = AppSiteStatusPayload.model_validate(console_ns.payload)
+        parser = reqparse.RequestParser().add_argument("enable_site", type=bool, required=True, location="json")
+        args = parser.parse_args()

         app_service = AppService()
-        app_model = app_service.update_app_site_status(app_model, args.enable_site)
+        app_model = app_service.update_app_site_status(app_model, args["enable_site"])

         return app_model
@@ -579,7 +590,11 @@ class AppApiStatus(Resource):
     @console_ns.doc("update_app_api_status")
     @console_ns.doc(description="Enable or disable app API")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(console_ns.models[AppApiStatusPayload.__name__])
+    @console_ns.expect(
+        console_ns.model(
+            "AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")}
+        )
+    )
     @console_ns.response(200, "API status updated successfully", app_detail_model)
     @console_ns.response(403, "Insufficient permissions")
     @setup_required
@@ -589,10 +604,11 @@ class AppApiStatus(Resource):
     @get_app_model
     @marshal_with(app_detail_model)
     def post(self, app_model):
-        args = AppApiStatusPayload.model_validate(console_ns.payload)
+        parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
+        args = parser.parse_args()

         app_service = AppService()
-        app_model = app_service.update_app_api_status(app_model, args.enable_api)
+        app_model = app_service.update_app_api_status(app_model, args["enable_api"])

         return app_model
@@ -615,7 +631,15 @@ class AppTraceApi(Resource):
     @console_ns.doc("update_app_trace")
     @console_ns.doc(description="Update app tracing configuration")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(console_ns.models[AppTracePayload.__name__])
+    @console_ns.expect(
+        console_ns.model(
+            "AppTraceRequest",
+            {
+                "enabled": fields.Boolean(required=True, description="Enable or disable tracing"),
+                "tracing_provider": fields.String(required=True, description="Tracing provider"),
+            },
+        )
+    )
     @console_ns.response(200, "Trace configuration updated successfully")
     @console_ns.response(403, "Insufficient permissions")
     @setup_required
@@ -624,12 +648,17 @@ class AppTraceApi(Resource):
     @edit_permission_required
     def post(self, app_id):
         # add app trace
-        args = AppTracePayload.model_validate(console_ns.payload)
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("enabled", type=bool, required=True, location="json")
+            .add_argument("tracing_provider", type=str, required=True, location="json")
+        )
+        args = parser.parse_args()

         OpsTraceManager.update_app_tracing_config(
             app_id=app_id,
-            enabled=args.enabled,
-            tracing_provider=args.tracing_provider,
+            enabled=args["enabled"],
+            tracing_provider=args["tracing_provider"],
         )

         return {"result": "success"}
@@ -1,9 +1,7 @@
 import logging
-from typing import Any, Literal

 from flask import request
-from flask_restx import Resource
-from pydantic import BaseModel, Field, field_validator
+from flask_restx import Resource, fields, reqparse
 from werkzeug.exceptions import InternalServerError, NotFound

 import services
@@ -37,41 +35,6 @@ from services.app_task_service import AppTaskService
 from services.errors.llm import InvokeRateLimitError

 logger = logging.getLogger(__name__)
-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-
-
-class BaseMessagePayload(BaseModel):
-    inputs: dict[str, Any]
-    model_config_data: dict[str, Any] = Field(..., alias="model_config")
-    files: list[Any] | None = Field(default=None, description="Uploaded files")
-    response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
-    retriever_from: str = Field(default="dev", description="Retriever source")
-
-
-class CompletionMessagePayload(BaseMessagePayload):
-    query: str = Field(default="", description="Query text")
-
-
-class ChatMessagePayload(BaseMessagePayload):
-    query: str = Field(..., description="User query")
-    conversation_id: str | None = Field(default=None, description="Conversation ID")
-    parent_message_id: str | None = Field(default=None, description="Parent message ID")
-
-    @field_validator("conversation_id", "parent_message_id")
-    @classmethod
-    def validate_uuid(cls, value: str | None) -> str | None:
-        if value is None:
-            return value
-        return uuid_value(value)
-
-
-console_ns.schema_model(
-    CompletionMessagePayload.__name__,
-    CompletionMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
-)
-console_ns.schema_model(
-    ChatMessagePayload.__name__, ChatMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
-)


 # define completion message api for user
@@ -80,7 +43,19 @@ class CompletionMessageApi(Resource):
     @console_ns.doc("create_completion_message")
     @console_ns.doc(description="Generate completion message for debugging")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
+    @console_ns.expect(
+        console_ns.model(
+            "CompletionMessageRequest",
+            {
+                "inputs": fields.Raw(required=True, description="Input variables"),
+                "query": fields.String(description="Query text", default=""),
+                "files": fields.List(fields.Raw(), description="Uploaded files"),
+                "model_config": fields.Raw(required=True, description="Model configuration"),
+                "response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
+                "retriever_from": fields.String(default="dev", description="Retriever source"),
+            },
+        )
+    )
     @console_ns.response(200, "Completion generated successfully")
     @console_ns.response(400, "Invalid request parameters")
     @console_ns.response(404, "App not found")
@@ -89,10 +64,18 @@ class CompletionMessageApi(Resource):
     @account_initialization_required
     @get_app_model(mode=AppMode.COMPLETION)
     def post(self, app_model):
-        args_model = CompletionMessagePayload.model_validate(console_ns.payload)
-        args = args_model.model_dump(exclude_none=True, by_alias=True)
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("inputs", type=dict, required=True, location="json")
+            .add_argument("query", type=str, location="json", default="")
+            .add_argument("files", type=list, required=False, location="json")
+            .add_argument("model_config", type=dict, required=True, location="json")
+            .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
+            .add_argument("retriever_from", type=str, required=False, default="dev", location="json")
+        )
+        args = parser.parse_args()

-        streaming = args_model.response_mode != "blocking"
+        streaming = args["response_mode"] != "blocking"
         args["auto_generate_name"] = False

         try:
@@ -154,7 +137,21 @@ class ChatMessageApi(Resource):
     @console_ns.doc("create_chat_message")
     @console_ns.doc(description="Generate chat message for debugging")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
+    @console_ns.expect(
+        console_ns.model(
+            "ChatMessageRequest",
+            {
+                "inputs": fields.Raw(required=True, description="Input variables"),
+                "query": fields.String(required=True, description="User query"),
+                "files": fields.List(fields.Raw(), description="Uploaded files"),
+                "model_config": fields.Raw(required=True, description="Model configuration"),
+                "conversation_id": fields.String(description="Conversation ID"),
+                "parent_message_id": fields.String(description="Parent message ID"),
+                "response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
+                "retriever_from": fields.String(default="dev", description="Retriever source"),
+            },
+        )
+    )
     @console_ns.response(200, "Chat message generated successfully")
     @console_ns.response(400, "Invalid request parameters")
     @console_ns.response(404, "App or conversation not found")
@@ -164,10 +161,20 @@ class ChatMessageApi(Resource):
     @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
     @edit_permission_required
     def post(self, app_model):
-        args_model = ChatMessagePayload.model_validate(console_ns.payload)
-        args = args_model.model_dump(exclude_none=True, by_alias=True)
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("inputs", type=dict, required=True, location="json")
+            .add_argument("query", type=str, required=True, location="json")
+            .add_argument("files", type=list, required=False, location="json")
+            .add_argument("model_config", type=dict, required=True, location="json")
+            .add_argument("conversation_id", type=uuid_value, location="json")
+            .add_argument("parent_message_id", type=uuid_value, required=False, location="json")
+            .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
+            .add_argument("retriever_from", type=str, required=False, default="dev", location="json")
+        )
+        args = parser.parse_args()

-        streaming = args_model.response_mode != "blocking"
+        streaming = args["response_mode"] != "blocking"
         args["auto_generate_name"] = False

         external_trace_id = get_external_trace_id(request)
@@ -1,9 +1,7 @@
-from typing import Literal
-
 import sqlalchemy as sa
-from flask import abort, request
-from flask_restx import Resource, fields, marshal_with
-from pydantic import BaseModel, Field, field_validator
+from flask import abort
+from flask_restx import Resource, fields, marshal_with, reqparse
+from flask_restx.inputs import int_range
 from sqlalchemy import func, or_
 from sqlalchemy.orm import joinedload
 from werkzeug.exceptions import NotFound
@@ -16,53 +14,13 @@ from extensions.ext_database import db
 from fields.conversation_fields import MessageTextField
 from fields.raws import FilesContainedField
 from libs.datetime_utils import naive_utc_now, parse_time_range
-from libs.helper import TimestampField
+from libs.helper import DatetimeString, TimestampField
 from libs.login import current_account_with_tenant, login_required
 from models import Conversation, EndUser, Message, MessageAnnotation
 from models.model import AppMode
 from services.conversation_service import ConversationService
 from services.errors.conversation import ConversationNotExistsError

-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-
-
-class BaseConversationQuery(BaseModel):
-    keyword: str | None = Field(default=None, description="Search keyword")
-    start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
-    end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
-    annotation_status: Literal["annotated", "not_annotated", "all"] = Field(
-        default="all", description="Annotation status filter"
-    )
-    page: int = Field(default=1, ge=1, le=99999, description="Page number")
-    limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
-
-    @field_validator("start", "end", mode="before")
-    @classmethod
-    def blank_to_none(cls, value: str | None) -> str | None:
-        if value == "":
-            return None
-        return value
-
-
-class CompletionConversationQuery(BaseConversationQuery):
-    pass
-
-
-class ChatConversationQuery(BaseConversationQuery):
-    sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
-        default="-updated_at", description="Sort field and direction"
-    )
-
-
-console_ns.schema_model(
-    CompletionConversationQuery.__name__,
-    CompletionConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
-)
-console_ns.schema_model(
-    ChatConversationQuery.__name__,
-    ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
-)

 # Register models for flask_restx to avoid dict type issues in Swagger
 # Register in dependency order: base models first, then dependent models
@@ -325,7 +283,22 @@ class CompletionConversationApi(Resource):
     @console_ns.doc("list_completion_conversations")
     @console_ns.doc(description="Get completion conversations with pagination and filtering")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(console_ns.models[CompletionConversationQuery.__name__])
+    @console_ns.expect(
+        console_ns.parser()
+        .add_argument("keyword", type=str, location="args", help="Search keyword")
+        .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
+        .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
+        .add_argument(
+            "annotation_status",
+            type=str,
+            location="args",
+            choices=["annotated", "not_annotated", "all"],
+            default="all",
+            help="Annotation status filter",
+        )
+        .add_argument("page", type=int, location="args", default=1, help="Page number")
+        .add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
+    )
     @console_ns.response(200, "Success", conversation_pagination_model)
     @console_ns.response(403, "Insufficient permissions")
     @setup_required
@@ -336,17 +309,32 @@ class CompletionConversationApi(Resource):
     @edit_permission_required
     def get(self, app_model):
         current_user, _ = current_account_with_tenant()
-        args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("keyword", type=str, location="args")
+            .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+            .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+            .add_argument(
+                "annotation_status",
+                type=str,
+                choices=["annotated", "not_annotated", "all"],
+                default="all",
+                location="args",
+            )
+            .add_argument("page", type=int_range(1, 99999), default=1, location="args")
+            .add_argument("limit", type=int_range(1, 100), default=20, location="args")
+        )
+        args = parser.parse_args()

         query = sa.select(Conversation).where(
             Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
         )

-        if args.keyword:
+        if args["keyword"]:
             query = query.join(Message, Message.conversation_id == Conversation.id).where(
                 or_(
-                    Message.query.ilike(f"%{args.keyword}%"),
-                    Message.answer.ilike(f"%{args.keyword}%"),
+                    Message.query.ilike(f"%{args['keyword']}%"),
+                    Message.answer.ilike(f"%{args['keyword']}%"),
                 )
             )
@@ -354,7 +342,7 @@ class CompletionConversationApi(Resource):
         assert account.timezone is not None

         try:
-            start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
+            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
         except ValueError as e:
             abort(400, description=str(e))
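parse_time_range (from libs.datetime_utils) takes the two "YYYY-MM-DD HH:MM" strings — already format-checked by DatetimeString — plus the account's timezone name, and returns the corresponding UTC datetimes, raising ValueError on bad input (hence the abort(400) wrapper). A hedged usage sketch; the exact return types are inferred from this diff, not from the helper's source:

    from libs.datetime_utils import parse_time_range

    try:
        start_utc, end_utc = parse_time_range(
            "2026-01-01 00:00", "2026-01-02 00:00", "Asia/Shanghai"
        )
    except ValueError as e:
        print(f"rejected: {e}")  # the endpoint surfaces this as HTTP 400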
@@ -366,11 +354,11 @@ class CompletionConversationApi(Resource):
             query = query.where(Conversation.created_at < end_datetime_utc)

         # FIXME, the type ignore in this file
-        if args.annotation_status == "annotated":
+        if args["annotation_status"] == "annotated":
             query = query.options(joinedload(Conversation.message_annotations)).join(  # type: ignore
                 MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
             )
-        elif args.annotation_status == "not_annotated":
+        elif args["annotation_status"] == "not_annotated":
             query = (
                 query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
                 .group_by(Conversation.id)
@@ -379,7 +367,7 @@ class CompletionConversationApi(Resource):

         query = query.order_by(Conversation.created_at.desc())

-        conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
+        conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)

         return conversations
@@ -431,7 +419,31 @@ class ChatConversationApi(Resource):
     @console_ns.doc("list_chat_conversations")
     @console_ns.doc(description="Get chat conversations with pagination, filtering and summary")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(console_ns.models[ChatConversationQuery.__name__])
+    @console_ns.expect(
+        console_ns.parser()
+        .add_argument("keyword", type=str, location="args", help="Search keyword")
+        .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
+        .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
+        .add_argument(
+            "annotation_status",
+            type=str,
+            location="args",
+            choices=["annotated", "not_annotated", "all"],
+            default="all",
+            help="Annotation status filter",
+        )
+        .add_argument("message_count_gte", type=int, location="args", help="Minimum message count")
+        .add_argument("page", type=int, location="args", default=1, help="Page number")
+        .add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
+        .add_argument(
+            "sort_by",
+            type=str,
+            location="args",
+            choices=["created_at", "-created_at", "updated_at", "-updated_at"],
+            default="-updated_at",
+            help="Sort field and direction",
+        )
+    )
     @console_ns.response(200, "Success", conversation_with_summary_pagination_model)
     @console_ns.response(403, "Insufficient permissions")
     @setup_required
@@ -442,7 +454,31 @@ class ChatConversationApi(Resource):
     @edit_permission_required
     def get(self, app_model):
         current_user, _ = current_account_with_tenant()
-        args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("keyword", type=str, location="args")
+            .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+            .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+            .add_argument(
+                "annotation_status",
+                type=str,
+                choices=["annotated", "not_annotated", "all"],
+                default="all",
+                location="args",
+            )
+            .add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
+            .add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
+            .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
+            .add_argument(
+                "sort_by",
+                type=str,
+                choices=["created_at", "-created_at", "updated_at", "-updated_at"],
+                required=False,
+                default="-updated_at",
+                location="args",
+            )
+        )
+        args = parser.parse_args()

         subquery = (
             db.session.query(
@@ -454,8 +490,8 @@ class ChatConversationApi(Resource):

         query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))

-        if args.keyword:
-            keyword_filter = f"%{args.keyword}%"
+        if args["keyword"]:
+            keyword_filter = f"%{args['keyword']}%"
             query = (
                 query.join(
                     Message,
|
||||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
match args.sort_by:
|
||||
match args["sort_by"]:
|
||||
case "updated_at" | "-updated_at":
|
||||
query = query.where(Conversation.updated_at >= start_datetime_utc)
|
||||
case "created_at" | "-created_at" | _:
|
||||
@ -491,27 +527,35 @@ class ChatConversationApi(Resource):
|
||||
|
||||
if end_datetime_utc:
|
||||
end_datetime_utc = end_datetime_utc.replace(second=59)
|
||||
match args.sort_by:
|
||||
match args["sort_by"]:
|
||||
case "updated_at" | "-updated_at":
|
||||
query = query.where(Conversation.updated_at <= end_datetime_utc)
|
||||
case "created_at" | "-created_at" | _:
|
||||
query = query.where(Conversation.created_at <= end_datetime_utc)
|
||||
|
||||
if args.annotation_status == "annotated":
|
||||
if args["annotation_status"] == "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
elif args.annotation_status == "not_annotated":
|
||||
elif args["annotation_status"] == "not_annotated":
|
||||
query = (
|
||||
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
.having(func.count(MessageAnnotation.id) == 0)
|
||||
)
|
||||
|
||||
if args["message_count_gte"] and args["message_count_gte"] >= 1:
|
||||
query = (
|
||||
query.options(joinedload(Conversation.messages)) # type: ignore
|
||||
.join(Message, Message.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
.having(func.count(Message.id) >= args["message_count_gte"])
|
||||
)
|
||||
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
|
||||
|
||||
match args.sort_by:
|
||||
match args["sort_by"]:
|
||||
case "created_at":
|
||||
query = query.order_by(Conversation.created_at.asc())
|
||||
case "-created_at":
|
||||
@@ -523,7 +567,7 @@ class ChatConversationApi(Resource):
             case _:
                 query = query.order_by(Conversation.created_at.desc())

-        conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
+        conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)

         return conversations
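Both conversation listings hand the final select to Flask-SQLAlchemy's db.paginate, which executes the statement and wraps the rows in a Pagination object; error_out=False returns an empty page for out-of-range page numbers rather than a 404. A minimal sketch reusing the modules imported above:

    import sqlalchemy as sa

    from extensions.ext_database import db
    from models import Conversation

    query = sa.select(Conversation).order_by(Conversation.created_at.desc())
    page = db.paginate(query, page=1, per_page=20, error_out=False)
    print(page.total, len(page.items), page.has_next)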
@@ -1,6 +1,4 @@
-from flask import request
-from flask_restx import Resource, fields, marshal_with
-from pydantic import BaseModel, Field
+from flask_restx import Resource, fields, marshal_with, reqparse
 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -16,18 +14,6 @@ from libs.login import login_required
 from models import ConversationVariable
 from models.model import AppMode

-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-
-
-class ConversationVariablesQuery(BaseModel):
-    conversation_id: str = Field(..., description="Conversation ID to filter variables")
-
-
-console_ns.schema_model(
-    ConversationVariablesQuery.__name__,
-    ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
-)
-
 # Register models for flask_restx to avoid dict type issues in Swagger
 # Register base model first
 conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
@@ -47,7 +33,11 @@ class ConversationVariablesApi(Resource):
     @console_ns.doc("get_conversation_variables")
     @console_ns.doc(description="Get conversation variables for an application")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
+    @console_ns.expect(
+        console_ns.parser().add_argument(
+            "conversation_id", type=str, location="args", help="Conversation ID to filter variables"
+        )
+    )
     @console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
     @setup_required
     @login_required
@@ -55,14 +45,18 @@ class ConversationVariablesApi(Resource):
     @get_app_model(mode=AppMode.ADVANCED_CHAT)
     @marshal_with(paginated_conversation_variable_model)
     def get(self, app_model):
-        args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
+        parser = reqparse.RequestParser().add_argument("conversation_id", type=str, location="args")
+        args = parser.parse_args()

         stmt = (
             select(ConversationVariable)
             .where(ConversationVariable.app_id == app_model.id)
             .order_by(ConversationVariable.created_at)
         )
-        stmt = stmt.where(ConversationVariable.conversation_id == args.conversation_id)
+        if args["conversation_id"]:
+            stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"])
+        else:
+            raise ValueError("conversation_id is required")

         # NOTE: This is a temporary solution to avoid performance issues.
         page = 1
@@ -1,8 +1,6 @@
 from collections.abc import Sequence
-from typing import Any

-from flask_restx import Resource
-from pydantic import BaseModel, Field
+from flask_restx import Resource, fields, reqparse

 from controllers.console import console_ns
 from controllers.console.app.error import (
@ -23,54 +21,21 @@ from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class RuleGeneratePayload(BaseModel):
|
||||
instruction: str = Field(..., description="Rule generation instruction")
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
no_variable: bool = Field(default=False, description="Whether to exclude variables")
|
||||
|
||||
|
||||
class RuleCodeGeneratePayload(RuleGeneratePayload):
|
||||
code_language: str = Field(default="javascript", description="Programming language for code generation")
|
||||
|
||||
|
||||
class RuleStructuredOutputPayload(BaseModel):
|
||||
instruction: str = Field(..., description="Structured output generation instruction")
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
|
||||
|
||||
class InstructionGeneratePayload(BaseModel):
|
||||
flow_id: str = Field(..., description="Workflow/Flow ID")
|
||||
node_id: str = Field(default="", description="Node ID for workflow context")
|
||||
current: str = Field(default="", description="Current instruction text")
|
||||
language: str = Field(default="javascript", description="Programming language (javascript/python)")
|
||||
instruction: str = Field(..., description="Instruction for generation")
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
ideal_output: str = Field(default="", description="Expected ideal output")
|
||||
|
||||
|
||||
class InstructionTemplatePayload(BaseModel):
|
||||
type: str = Field(..., description="Instruction template type")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(RuleGeneratePayload)
|
||||
reg(RuleCodeGeneratePayload)
|
||||
reg(RuleStructuredOutputPayload)
|
||||
reg(InstructionGeneratePayload)
|
||||
reg(InstructionTemplatePayload)
|
||||
|
||||
|
||||
@console_ns.route("/rule-generate")
|
||||
class RuleGenerateApi(Resource):
|
||||
@console_ns.doc("generate_rule_config")
|
||||
@console_ns.doc(description="Generate rule configuration using LLM")
|
||||
@console_ns.expect(console_ns.models[RuleGeneratePayload.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"RuleGenerateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Rule generation instruction"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Rule configuration generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
@ -78,15 +43,21 @@ class RuleGenerateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = RuleGeneratePayload.model_validate(console_ns.payload)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
rules = LLMGenerator.generate_rule_config(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=args.no_variable,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
no_variable=args["no_variable"],
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -104,7 +75,19 @@ class RuleGenerateApi(Resource):
class RuleCodeGenerateApi(Resource):
@console_ns.doc("generate_rule_code")
@console_ns.doc(description="Generate code rules using LLM")
@console_ns.expect(console_ns.models[RuleCodeGeneratePayload.__name__])
@console_ns.expect(
console_ns.model(
"RuleCodeGenerateRequest",
{
"instruction": fields.String(required=True, description="Code generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
"code_language": fields.String(
default="javascript", description="Programming language for code generation"
),
},
)
)
@console_ns.response(200, "Code rules generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(402, "Provider quota exceeded")
@ -112,15 +95,22 @@ class RuleCodeGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
.add_argument("code_language", type=str, required=False, default="javascript", location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()

try:
code_result = LLMGenerator.generate_code(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.code_language,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["code_language"],
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -138,7 +128,15 @@ class RuleCodeGenerateApi(Resource):
class RuleStructuredOutputGenerateApi(Resource):
@console_ns.doc("generate_structured_output")
@console_ns.doc(description="Generate structured output rules using LLM")
@console_ns.expect(console_ns.models[RuleStructuredOutputPayload.__name__])
@console_ns.expect(
console_ns.model(
"StructuredOutputGenerateRequest",
{
"instruction": fields.String(required=True, description="Structured output generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
},
)
)
@console_ns.response(200, "Structured output generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(402, "Provider quota exceeded")
@ -146,14 +144,19 @@ class RuleStructuredOutputGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()

try:
structured_output = LLMGenerator.generate_structured_output(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
instruction=args["instruction"],
model_config=args["model_config"],
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -171,7 +174,20 @@ class RuleStructuredOutputGenerateApi(Resource):
class InstructionGenerateApi(Resource):
@console_ns.doc("generate_instruction")
@console_ns.doc(description="Generate instruction for workflow nodes or general use")
@console_ns.expect(console_ns.models[InstructionGeneratePayload.__name__])
@console_ns.expect(
console_ns.model(
"InstructionGenerateRequest",
{
"flow_id": fields.String(required=True, description="Workflow/Flow ID"),
"node_id": fields.String(description="Node ID for workflow context"),
"current": fields.String(description="Current instruction text"),
"language": fields.String(default="javascript", description="Programming language (javascript/python)"),
"instruction": fields.String(required=True, description="Instruction for generation"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"ideal_output": fields.String(description="Expected ideal output"),
},
)
)
@console_ns.response(200, "Instruction generated successfully")
@console_ns.response(400, "Invalid request parameters or flow/workflow not found")
@console_ns.response(402, "Provider quota exceeded")
@ -179,69 +195,79 @@ class InstructionGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
args = InstructionGeneratePayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("flow_id", type=str, required=True, default="", location="json")
.add_argument("node_id", type=str, required=False, default="", location="json")
.add_argument("current", type=str, required=False, default="", location="json")
.add_argument("language", type=str, required=False, default="javascript", location="json")
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("ideal_output", type=str, required=False, default="", location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] | None = next(
(p for p in providers if p.is_accept_language(args.language)), None
(p for p in providers if p.is_accept_language(args["language"])), None
)
code_template = code_provider.get_default_code() if code_provider else ""
try:
# Generate from nothing for a workflow node
if (args.current in (code_template, "")) and args.node_id != "":
app = db.session.query(App).where(App.id == args.flow_id).first()
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
app = db.session.query(App).where(App.id == args["flow_id"]).first()
if not app:
return {"error": f"app {args.flow_id} not found"}, 400
return {"error": f"app {args['flow_id']} not found"}, 400
workflow = WorkflowService().get_draft_workflow(app_model=app)
if not workflow:
return {"error": f"workflow {args.flow_id} not found"}, 400
return {"error": f"workflow {args['flow_id']} not found"}, 400
nodes: Sequence = workflow.graph_dict["nodes"]
node = [node for node in nodes if node["id"] == args.node_id]
node = [node for node in nodes if node["id"] == args["node_id"]]
if len(node) == 0:
return {"error": f"node {args.node_id} not found"}, 400
return {"error": f"node {args['node_id']} not found"}, 400
node_type = node[0]["data"]["type"]
match node_type:
case "llm":
return LLMGenerator.generate_rule_config(
current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=True,
)
case "agent":
return LLMGenerator.generate_rule_config(
current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=True,
)
case "code":
return LLMGenerator.generate_code(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.language,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["language"],
)
case _:
return {"error": f"invalid node type: {node_type}"}
if args.node_id == "" and args.current != "": # For legacy app without a workflow
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
return LLMGenerator.instruction_modify_legacy(
tenant_id=current_tenant_id,
flow_id=args.flow_id,
current=args.current,
instruction=args.instruction,
model_config=args.model_config_data,
ideal_output=args.ideal_output,
flow_id=args["flow_id"],
current=args["current"],
instruction=args["instruction"],
model_config=args["model_config"],
ideal_output=args["ideal_output"],
)
if args.node_id != "" and args.current != "": # For workflow node
if args["node_id"] != "" and args["current"] != "": # For workflow node
return LLMGenerator.instruction_modify_workflow(
tenant_id=current_tenant_id,
flow_id=args.flow_id,
node_id=args.node_id,
current=args.current,
instruction=args.instruction,
model_config=args.model_config_data,
ideal_output=args.ideal_output,
flow_id=args["flow_id"],
node_id=args["node_id"],
current=args["current"],
instruction=args["instruction"],
model_config=args["model_config"],
ideal_output=args["ideal_output"],
workflow_service=WorkflowService(),
)
return {"error": "incompatible parameters"}, 400
@ -259,15 +285,24 @@ class InstructionGenerateApi(Resource):
class InstructionGenerationTemplateApi(Resource):
@console_ns.doc("get_instruction_template")
@console_ns.doc(description="Get instruction generation template")
@console_ns.expect(console_ns.models[InstructionTemplatePayload.__name__])
@console_ns.expect(
console_ns.model(
"InstructionTemplateRequest",
{
"instruction": fields.String(required=True, description="Template instruction"),
"ideal_output": fields.String(description="Expected ideal output"),
},
)
)
@console_ns.response(200, "Template retrieved successfully")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
def post(self):
args = InstructionTemplatePayload.model_validate(console_ns.payload)
match args.type:
parser = reqparse.RequestParser().add_argument("type", type=str, required=True, default=False, location="json")
args = parser.parse_args()
match args["type"]:
case "prompt":
from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT

@ -277,4 +312,4 @@ class InstructionGenerationTemplateApi(Resource):

return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
case _:
raise ValueError(f"Invalid type: {args.type}")
raise ValueError(f"Invalid type: {args['type']}")

@ -1,9 +1,7 @@
import logging
from typing import Literal

from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy import exists, select
from werkzeug.exceptions import InternalServerError, NotFound

@ -35,67 +33,6 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft
from services.message_service import MessageService

logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


class ChatMessagesQuery(BaseModel):
conversation_id: str = Field(..., description="Conversation ID")
first_id: str | None = Field(default=None, description="First message ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")

@field_validator("first_id", mode="before")
@classmethod
def empty_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value

@field_validator("conversation_id", "first_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)


class MessageFeedbackPayload(BaseModel):
message_id: str = Field(..., description="Message ID")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")

@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str) -> str:
return uuid_value(value)


class FeedbackExportQuery(BaseModel):
from_source: Literal["user", "admin"] | None = Field(default=None, description="Filter by feedback source")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Filter by rating")
has_comment: bool | None = Field(default=None, description="Only include feedback with comments")
start_date: str | None = Field(default=None, description="Start date (YYYY-MM-DD)")
end_date: str | None = Field(default=None, description="End date (YYYY-MM-DD)")
format: Literal["csv", "json"] = Field(default="csv", description="Export format")

@field_validator("has_comment", mode="before")
@classmethod
def parse_bool(cls, value: bool | str | None) -> bool | None:
if isinstance(value, bool) or value is None:
return value
lowered = value.lower()
if lowered in {"true", "1", "yes", "on"}:
return True
if lowered in {"false", "0", "no", "off"}:
return False
raise ValueError("has_comment must be a boolean value")


def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))


reg(ChatMessagesQuery)
reg(MessageFeedbackPayload)
reg(FeedbackExportQuery)

# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@ -220,7 +157,12 @@ class ChatMessageListApi(Resource):
@console_ns.doc("list_chat_messages")
@console_ns.doc(description="Get chat messages for a conversation with pagination")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
@console_ns.expect(
console_ns.parser()
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID")
.add_argument("first_id", type=str, location="args", help="First message ID for pagination")
.add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)")
)
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
@console_ns.response(404, "Conversation not found")
@login_required
@ -230,21 +172,27 @@ class ChatMessageListApi(Resource):
@marshal_with(message_infinite_scroll_pagination_model)
@edit_permission_required
def get(self, app_model):
args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
.add_argument("first_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()

conversation = (
db.session.query(Conversation)
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
.where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
.first()
)

if not conversation:
raise NotFound("Conversation Not Exists.")

if args.first_id:
if args["first_id"]:
first_message = (
db.session.query(Message)
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
.where(Message.conversation_id == conversation.id, Message.id == args["first_id"])
.first()
)

@ -259,7 +207,7 @@ class ChatMessageListApi(Resource):
Message.id != first_message.id,
)
.order_by(Message.created_at.desc())
.limit(args.limit)
.limit(args["limit"])
.all()
)
else:
@ -267,12 +215,12 @@ class ChatMessageListApi(Resource):
db.session.query(Message)
.where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(args.limit)
.limit(args["limit"])
.all()
)

# Initialize has_more based on whether we have a full page
if len(history_messages) == args.limit:
if len(history_messages) == args["limit"]:
current_page_first_message = history_messages[-1]
# Check if there are more messages before the current page
has_more = db.session.scalar(
@ -290,7 +238,7 @@ class ChatMessageListApi(Resource):

history_messages = list(reversed(history_messages))

return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)


@console_ns.route("/apps/<uuid:app_id>/feedbacks")
|
||||
@ -298,7 +246,15 @@ class MessageFeedbackApi(Resource):
|
||||
@console_ns.doc("create_message_feedback")
|
||||
@console_ns.doc(description="Create or update message feedback (like/dislike)")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"MessageFeedbackRequest",
|
||||
{
|
||||
"message_id": fields.String(required=True, description="Message ID"),
|
||||
"rating": fields.String(enum=["like", "dislike"], description="Feedback rating"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Feedback updated successfully")
|
||||
@console_ns.response(404, "Message not found")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@ -309,9 +265,14 @@ class MessageFeedbackApi(Resource):
|
||||
def post(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = MessageFeedbackPayload.model_validate(console_ns.payload)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", required=True, type=uuid_value, location="json")
|
||||
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = str(args.message_id)
|
||||
message_id = str(args["message_id"])
|
||||
|
||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
|
||||
@ -320,21 +281,18 @@ class MessageFeedbackApi(Resource):
|
||||
|
||||
feedback = message.admin_feedback
|
||||
|
||||
if not args.rating and feedback:
|
||||
if not args["rating"] and feedback:
|
||||
db.session.delete(feedback)
|
||||
elif args.rating and feedback:
|
||||
feedback.rating = args.rating
|
||||
elif not args.rating and not feedback:
|
||||
elif args["rating"] and feedback:
|
||||
feedback.rating = args["rating"]
|
||||
elif not args["rating"] and not feedback:
|
||||
raise ValueError("rating cannot be None when feedback not exists")
|
||||
else:
|
||||
rating_value = args.rating
|
||||
if rating_value is None:
|
||||
raise ValueError("rating is required to create feedback")
|
||||
feedback = MessageFeedback(
|
||||
app_id=app_model.id,
|
||||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
rating=rating_value,
|
||||
rating=args["rating"],
|
||||
from_source="admin",
|
||||
from_account_id=current_user.id,
|
||||
)
|
||||
@ -411,12 +369,24 @@ class MessageSuggestedQuestionApi(Resource):
return {"data": questions}


# Shared parser for feedback export (used for both documentation and runtime parsing)
feedback_export_parser = (
console_ns.parser()
.add_argument("from_source", type=str, choices=["user", "admin"], location="args", help="Filter by feedback source")
.add_argument("rating", type=str, choices=["like", "dislike"], location="args", help="Filter by rating")
.add_argument("has_comment", type=bool, location="args", help="Only include feedback with comments")
.add_argument("start_date", type=str, location="args", help="Start date (YYYY-MM-DD)")
.add_argument("end_date", type=str, location="args", help="End date (YYYY-MM-DD)")
.add_argument("format", type=str, choices=["csv", "json"], default="csv", location="args", help="Export format")
)


@console_ns.route("/apps/<uuid:app_id>/feedbacks/export")
class MessageFeedbackExportApi(Resource):
@console_ns.doc("export_feedbacks")
@console_ns.doc(description="Export user feedback data for Google Sheets")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[FeedbackExportQuery.__name__])
@console_ns.expect(feedback_export_parser)
@console_ns.response(200, "Feedback data exported successfully")
@console_ns.response(400, "Invalid parameters")
@console_ns.response(500, "Internal server error")
@ -425,7 +395,7 @@ class MessageFeedbackExportApi(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = feedback_export_parser.parse_args()

# Import the service function
from services.feedback_service import FeedbackService
@ -433,12 +403,12 @@ class MessageFeedbackExportApi(Resource):
try:
export_data = FeedbackService.export_feedbacks(
app_id=app_model.id,
from_source=args.from_source,
rating=args.rating,
has_comment=args.has_comment,
start_date=args.start_date,
end_date=args.end_date,
format_type=args.format,
from_source=args.get("from_source"),
rating=args.get("rating"),
has_comment=args.get("has_comment"),
start_date=args.get("start_date"),
end_date=args.get("end_date"),
format_type=args.get("format", "csv"),
)

return export_data

@ -1,9 +1,8 @@
from decimal import Decimal

import sqlalchemy as sa
from flask import abort, jsonify, request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from flask import abort, jsonify
from flask_restx import Resource, fields, reqparse

from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
@ -11,37 +10,21 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import convert_datetime_to_date
from libs.helper import DatetimeString, convert_datetime_to_date
from libs.login import current_account_with_tenant, login_required
from models import AppMode

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


class StatisticTimeRangeQuery(BaseModel):
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")

@field_validator("start", "end", mode="before")
@classmethod
def empty_string_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value


console_ns.schema_model(
StatisticTimeRangeQuery.__name__,
StatisticTimeRangeQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)


@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
|
||||
class DailyMessageStatistic(Resource):
|
||||
@console_ns.doc("get_daily_message_statistics")
|
||||
@console_ns.doc(description="Get daily message statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily message statistics retrieved successfully",
|
||||
@ -54,7 +37,12 @@ class DailyMessageStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
@ -69,7 +57,7 @@ WHERE
|
||||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
@ -93,12 +81,19 @@ WHERE
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
|
||||
class DailyConversationStatistic(Resource):
|
||||
@console_ns.doc("get_daily_conversation_statistics")
|
||||
@console_ns.doc(description="Get daily conversation statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily conversation statistics retrieved successfully",
|
||||
@ -111,7 +106,7 @@ class DailyConversationStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = parser.parse_args()
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
@ -126,7 +121,7 @@ WHERE
|
||||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
@ -154,7 +149,7 @@ class DailyTerminalsStatistic(Resource):
|
||||
@console_ns.doc("get_daily_terminals_statistics")
|
||||
@console_ns.doc(description="Get daily terminal/end-user statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily terminal statistics retrieved successfully",
|
||||
@ -167,7 +162,7 @@ class DailyTerminalsStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = parser.parse_args()
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
@ -182,7 +177,7 @@ WHERE
|
||||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
@ -211,7 +206,7 @@ class DailyTokenCostStatistic(Resource):
|
||||
@console_ns.doc("get_daily_token_cost_statistics")
|
||||
@console_ns.doc(description="Get daily token cost statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily token cost statistics retrieved successfully",
|
||||
@ -224,7 +219,7 @@ class DailyTokenCostStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = parser.parse_args()
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
@ -240,7 +235,7 @@ WHERE
|
||||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
@ -271,7 +266,7 @@ class AverageSessionInteractionStatistic(Resource):
|
||||
@console_ns.doc("get_average_session_interaction_statistics")
|
||||
@console_ns.doc(description="Get average session interaction statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Average session interaction statistics retrieved successfully",
|
||||
@ -284,7 +279,7 @@ class AverageSessionInteractionStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = parser.parse_args()
|
||||
|
||||
converted_created_at = convert_datetime_to_date("c.created_at")
|
||||
sql_query = f"""SELECT
|
||||
@ -307,7 +302,7 @@ FROM
|
||||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
@ -347,7 +342,7 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
@console_ns.doc("get_user_satisfaction_rate_statistics")
|
||||
@console_ns.doc(description="Get user satisfaction rate statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"User satisfaction rate statistics retrieved successfully",
|
||||
@ -360,7 +355,7 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = parser.parse_args()
|
||||
|
||||
converted_created_at = convert_datetime_to_date("m.created_at")
|
||||
sql_query = f"""SELECT
|
||||
@ -379,7 +374,7 @@ WHERE
|
||||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
@ -413,7 +408,7 @@ class AverageResponseTimeStatistic(Resource):
|
||||
@console_ns.doc("get_average_response_time_statistics")
|
||||
@console_ns.doc(description="Get average response time statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Average response time statistics retrieved successfully",
|
||||
@ -426,7 +421,7 @@ class AverageResponseTimeStatistic(Resource):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = parser.parse_args()
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
@ -441,7 +436,7 @@ WHERE
|
||||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
@ -470,7 +465,7 @@ class TokensPerSecondStatistic(Resource):
|
||||
@console_ns.doc("get_tokens_per_second_statistics")
|
||||
@console_ns.doc(description="Get tokens per second statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Tokens per second statistics retrieved successfully",
|
||||
@ -482,7 +477,7 @@ class TokensPerSecondStatistic(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = parser.parse_args()
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
@ -500,7 +495,7 @@ WHERE
|
||||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
|
||||
@ -1,11 +1,10 @@
import json
import logging
from collections.abc import Sequence
from typing import Any
from typing import cast

from flask import abort, request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, fields, inputs, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

@ -50,7 +49,6 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE

logger = logging.getLogger(__name__)
LISTENING_RETRY_IN = 2000
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@ -109,104 +107,6 @@ if workflow_run_node_execution_model is None:
workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields)


class SyncDraftWorkflowPayload(BaseModel):
graph: dict[str, Any]
features: dict[str, Any]
hash: str | None = None
environment_variables: list[dict[str, Any]] = Field(default_factory=list)
conversation_variables: list[dict[str, Any]] = Field(default_factory=list)


class BaseWorkflowRunPayload(BaseModel):
files: list[dict[str, Any]] | None = None


class AdvancedChatWorkflowRunPayload(BaseWorkflowRunPayload):
inputs: dict[str, Any] | None = None
query: str = ""
conversation_id: str | None = None
parent_message_id: str | None = None

@field_validator("conversation_id", "parent_message_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)


class IterationNodeRunPayload(BaseModel):
inputs: dict[str, Any] | None = None


class LoopNodeRunPayload(BaseModel):
inputs: dict[str, Any] | None = None


class DraftWorkflowRunPayload(BaseWorkflowRunPayload):
inputs: dict[str, Any]


class DraftWorkflowNodeRunPayload(BaseWorkflowRunPayload):
inputs: dict[str, Any]
query: str = ""


class PublishWorkflowPayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)


class DefaultBlockConfigQuery(BaseModel):
q: str | None = None


class ConvertToWorkflowPayload(BaseModel):
name: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None


class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False


class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)


class DraftWorkflowTriggerRunPayload(BaseModel):
node_id: str


class DraftWorkflowTriggerRunAllPayload(BaseModel):
node_ids: list[str]


def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))


reg(SyncDraftWorkflowPayload)
reg(AdvancedChatWorkflowRunPayload)
reg(IterationNodeRunPayload)
reg(LoopNodeRunPayload)
reg(DraftWorkflowRunPayload)
reg(DraftWorkflowNodeRunPayload)
reg(PublishWorkflowPayload)
reg(DefaultBlockConfigQuery)
reg(ConvertToWorkflowPayload)
reg(WorkflowListQuery)
reg(WorkflowUpdatePayload)
reg(DraftWorkflowTriggerRunPayload)
reg(DraftWorkflowTriggerRunAllPayload)


# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
# at the controller level rather than in the workflow logic. This would improve separation
# of concerns and make the code more maintainable.
@ -258,7 +158,18 @@ class DraftWorkflowApi(Resource):
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@console_ns.doc("sync_draft_workflow")
@console_ns.doc(description="Sync draft workflow configuration")
@console_ns.expect(console_ns.models[SyncDraftWorkflowPayload.__name__])
@console_ns.expect(
console_ns.model(
"SyncDraftWorkflowRequest",
{
"graph": fields.Raw(required=True, description="Workflow graph configuration"),
"features": fields.Raw(required=True, description="Workflow features configuration"),
"hash": fields.String(description="Workflow hash for validation"),
"environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"),
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
},
)
)
@console_ns.response(
200,
"Draft workflow synced successfully",
@ -282,23 +193,36 @@ class DraftWorkflowApi(Resource):

content_type = request.headers.get("Content-Type", "")

payload_data: dict[str, Any] | None = None
if "application/json" in content_type:
payload_data = request.get_json(silent=True)
if not isinstance(payload_data, dict):
return {"message": "Invalid JSON data"}, 400
parser = (
reqparse.RequestParser()
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
.add_argument("features", type=dict, required=True, nullable=False, location="json")
.add_argument("hash", type=str, required=False, location="json")
.add_argument("environment_variables", type=list, required=True, location="json")
.add_argument("conversation_variables", type=list, required=False, location="json")
)
args = parser.parse_args()
elif "text/plain" in content_type:
try:
payload_data = json.loads(request.data.decode("utf-8"))
data = json.loads(request.data.decode("utf-8"))
if "graph" not in data or "features" not in data:
raise ValueError("graph or features not found in data")

if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
raise ValueError("graph or features is not a dict")

args = {
"graph": data.get("graph"),
"features": data.get("features"),
"hash": data.get("hash"),
"environment_variables": data.get("environment_variables"),
"conversation_variables": data.get("conversation_variables"),
}
except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400
if not isinstance(payload_data, dict):
return {"message": "Invalid JSON data"}, 400
else:
abort(415)

args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
args = args_model.model_dump()
workflow_service = WorkflowService()

try:
@ -334,7 +258,17 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
@console_ns.doc("run_advanced_chat_draft_workflow")
@console_ns.doc(description="Run draft workflow for advanced chat application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AdvancedChatWorkflowRunPayload.__name__])
@console_ns.expect(
console_ns.model(
"AdvancedChatWorkflowRunRequest",
{
"query": fields.String(required=True, description="User query"),
"inputs": fields.Raw(description="Input variables"),
"files": fields.List(fields.Raw, description="File uploads"),
"conversation_id": fields.String(description="Conversation ID"),
},
)
)
@console_ns.response(200, "Workflow run started successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(403, "Permission denied")
@ -349,8 +283,16 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
"""
current_user, _ = current_account_with_tenant()

args_model = AdvancedChatWorkflowRunPayload.model_validate(console_ns.payload or {})
args = args_model.model_dump(exclude_none=True)
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, location="json")
.add_argument("query", type=str, required=True, location="json", default="")
.add_argument("files", type=list, location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
)

args = parser.parse_args()

external_trace_id = get_external_trace_id(request)
if external_trace_id:
@ -380,7 +322,15 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
@console_ns.doc("run_advanced_chat_draft_iteration_node")
@console_ns.doc(description="Run draft workflow iteration node for advanced chat")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__])
@console_ns.expect(
console_ns.model(
"IterationNodeRunRequest",
{
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.response(200, "Iteration node run started successfully")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found")
@ -394,7 +344,8 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
Run draft workflow iteration node
"""
current_user, _ = current_account_with_tenant()
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()

try:
response = AppGenerateService.generate_single_iteration(
@ -418,7 +369,15 @@ class WorkflowDraftRunIterationNodeApi(Resource):
@console_ns.doc("run_workflow_draft_iteration_node")
@console_ns.doc(description="Run draft workflow iteration node")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__])
@console_ns.expect(
console_ns.model(
"WorkflowIterationNodeRunRequest",
{
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.response(200, "Workflow iteration node run started successfully")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found")
@ -432,7 +391,8 @@ class WorkflowDraftRunIterationNodeApi(Resource):
Run draft workflow iteration node
"""
current_user, _ = current_account_with_tenant()
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()

try:
response = AppGenerateService.generate_single_iteration(
@ -456,7 +416,15 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
@console_ns.doc("run_advanced_chat_draft_loop_node")
@console_ns.doc(description="Run draft workflow loop node for advanced chat")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__])
@console_ns.expect(
console_ns.model(
"LoopNodeRunRequest",
{
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.response(200, "Loop node run started successfully")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found")
@ -470,7 +438,8 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
Run draft workflow loop node
"""
current_user, _ = current_account_with_tenant()
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()

try:
response = AppGenerateService.generate_single_loop(
@ -494,7 +463,15 @@ class WorkflowDraftRunLoopNodeApi(Resource):
@console_ns.doc("run_workflow_draft_loop_node")
@console_ns.doc(description="Run draft workflow loop node")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__])
@console_ns.expect(
console_ns.model(
"WorkflowLoopNodeRunRequest",
{
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.response(200, "Workflow loop node run started successfully")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found")
@ -508,7 +485,8 @@ class WorkflowDraftRunLoopNodeApi(Resource):
Run draft workflow loop node
"""
current_user, _ = current_account_with_tenant()
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()

try:
response = AppGenerateService.generate_single_loop(
@ -532,7 +510,15 @@ class DraftWorkflowRunApi(Resource):
@console_ns.doc("run_draft_workflow")
@console_ns.doc(description="Run draft workflow")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__])
@console_ns.expect(
console_ns.model(
"DraftWorkflowRunRequest",
{
"inputs": fields.Raw(required=True, description="Input variables"),
"files": fields.List(fields.Raw, description="File uploads"),
},
)
)
@console_ns.response(200, "Draft workflow run started successfully")
@console_ns.response(403, "Permission denied")
@setup_required
@ -545,7 +531,12 @@ class DraftWorkflowRunApi(Resource):
Run draft workflow
"""
current_user, _ = current_account_with_tenant()
args = DraftWorkflowRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("files", type=list, required=False, location="json")
)
args = parser.parse_args()

external_trace_id = get_external_trace_id(request)
if external_trace_id:
@ -597,7 +588,14 @@ class DraftWorkflowNodeRunApi(Resource):
@console_ns.doc("run_draft_workflow_node")
@console_ns.doc(description="Run draft workflow node")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[DraftWorkflowNodeRunPayload.__name__])
@console_ns.expect(
console_ns.model(
"DraftWorkflowNodeRunRequest",
{
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.response(200, "Node run started successfully", workflow_run_node_execution_model)
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found")
@ -612,10 +610,15 @@ class DraftWorkflowNodeRunApi(Resource):
Run draft workflow node
"""
current_user, _ = current_account_with_tenant()
args_model = DraftWorkflowNodeRunPayload.model_validate(console_ns.payload or {})
args = args_model.model_dump(exclude_none=True)
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("query", type=str, required=False, location="json", default="")
.add_argument("files", type=list, location="json", default=[])
)
args = parser.parse_args()

user_inputs = args_model.inputs
user_inputs = args.get("inputs")
if user_inputs is None:
raise ValueError("missing inputs")

@ -640,6 +643,13 @@ class DraftWorkflowNodeRunApi(Resource):
|
||||
return workflow_node_execution
|
||||
|
||||
|
||||
parser_publish = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("marked_name", type=str, required=False, default="", location="json")
|
||||
.add_argument("marked_comment", type=str, required=False, default="", location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/publish")
|
||||
class PublishedWorkflowApi(Resource):
|
||||
@console_ns.doc("get_published_workflow")
|
||||
@ -664,7 +674,7 @@ class PublishedWorkflowApi(Resource):
|
||||
# return workflow, if not found, return None
|
||||
return workflow
|
||||
|
||||
@console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__])
|
||||
@console_ns.expect(parser_publish)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -676,7 +686,13 @@ class PublishedWorkflowApi(Resource):
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
|
||||
args = parser_publish.parse_args()
|
||||
|
||||
# Validate name and comment length
|
||||
if args.marked_name and len(args.marked_name) > 20:
|
||||
raise ValueError("Marked name cannot exceed 20 characters")
|
||||
if args.marked_comment and len(args.marked_comment) > 100:
|
||||
raise ValueError("Marked comment cannot exceed 100 characters")
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine) as session:
|
||||
@ -725,6 +741,9 @@ class DefaultBlockConfigsApi(Resource):
|
||||
return workflow_service.get_default_block_configs()
|
||||
|
||||
|
||||
parser_block = reqparse.RequestParser().add_argument("q", type=str, location="args")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>")
|
||||
class DefaultBlockConfigApi(Resource):
|
||||
@console_ns.doc("get_default_block_config")
|
||||
@ -732,7 +751,7 @@ class DefaultBlockConfigApi(Resource):
|
||||
@console_ns.doc(params={"app_id": "Application ID", "block_type": "Block type"})
|
||||
@console_ns.response(200, "Default block configuration retrieved successfully")
|
||||
@console_ns.response(404, "Block type not found")
|
||||
@console_ns.expect(console_ns.models[DefaultBlockConfigQuery.__name__])
|
||||
@console_ns.expect(parser_block)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -742,12 +761,14 @@ class DefaultBlockConfigApi(Resource):
|
||||
"""
|
||||
Get default block config
|
||||
"""
|
||||
args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = parser_block.parse_args()
|
||||
|
||||
q = args.get("q")
|
||||
|
||||
filters = None
|
||||
if args.q:
|
||||
if q:
|
||||
try:
|
||||
filters = json.loads(args.q)
|
||||
filters = json.loads(args.get("q", ""))
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("Invalid filters")
|
||||
|
||||
@ -756,9 +777,18 @@ class DefaultBlockConfigApi(Resource):
|
||||
return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
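The `q` argument above carries structured filters as a JSON-encoded query string. A small hypothetical sketch of the same decode-and-guard step in isolation:

import json

def parse_filters(q: str | None) -> dict | None:
    # q arrives as a raw string, e.g. ?q=%7B%22type%22%3A%22llm%22%7D
    if not q:
        return None
    try:
        return json.loads(q)
    except json.JSONDecodeError:
        raise ValueError("Invalid filters")

assert parse_filters('{"type": "llm"}') == {"type": "llm"}
assert parse_filters(None) is None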
parser_convert = (
    reqparse.RequestParser()
    .add_argument("name", type=str, required=False, nullable=True, location="json")
    .add_argument("icon_type", type=str, required=False, nullable=True, location="json")
    .add_argument("icon", type=str, required=False, nullable=True, location="json")
    .add_argument("icon_background", type=str, required=False, nullable=True, location="json")
)


@console_ns.route("/apps/<uuid:app_id>/convert-to-workflow")
class ConvertToWorkflowApi(Resource):
    @console_ns.expect(console_ns.models[ConvertToWorkflowPayload.__name__])
    @console_ns.expect(parser_convert)
    @console_ns.doc("convert_to_workflow")
    @console_ns.doc(description="Convert application to workflow mode")
    @console_ns.doc(params={"app_id": "Application ID"})
@ -778,8 +808,10 @@ class ConvertToWorkflowApi(Resource):
        """
        current_user, _ = current_account_with_tenant()

        payload = console_ns.payload or {}
        args = ConvertToWorkflowPayload.model_validate(payload).model_dump(exclude_none=True)
        if request.data:
            args = parser_convert.parse_args()
        else:
            args = {}

        # convert to workflow mode
        workflow_service = WorkflowService()
@ -791,9 +823,18 @@ class ConvertToWorkflowApi(Resource):
        }
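The `request.data` check above makes the whole JSON payload optional. A hypothetical minimal reproduction of the guard (route name invented for illustration):

from flask import Flask, request
from flask_restx import Api, Resource, reqparse

app = Flask(__name__)
api = Api(app)

parser = reqparse.RequestParser().add_argument("name", type=str, location="json")

@api.route("/convert")
class Convert(Resource):
    def post(self):
        # parse_args() on a JSON-location parser can fail on an empty or
        # non-JSON body; checking request.data first tolerates a bare POST.
        args = parser.parse_args() if request.data else {}
        return {"name": args.get("name")}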
parser_workflows = (
    reqparse.RequestParser()
    .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
    .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
    .add_argument("user_id", type=str, required=False, location="args")
    .add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)


@console_ns.route("/apps/<uuid:app_id>/workflows")
class PublishedAllWorkflowApi(Resource):
    @console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
    @console_ns.expect(parser_workflows)
    @console_ns.doc("get_all_published_workflows")
    @console_ns.doc(description="Get all published workflows for an application")
    @console_ns.doc(params={"app_id": "Application ID"})
@ -810,15 +851,16 @@ class PublishedAllWorkflowApi(Resource):
        """
        current_user, _ = current_account_with_tenant()

        args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
        page = args.page
        limit = args.limit
        user_id = args.user_id
        named_only = args.named_only
        args = parser_workflows.parse_args()
        page = args["page"]
        limit = args["limit"]
        user_id = args.get("user_id")
        named_only = args.get("named_only", False)

        if user_id:
            if user_id != current_user.id:
                raise Forbidden()
            user_id = cast(str, user_id)

        workflow_service = WorkflowService()
        with Session(db.engine) as session:
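flask-restx ships validating argument types for exactly this pagination pattern. A small sketch, assuming a plain Flask test context:

from flask import Flask
from flask_restx import inputs, reqparse

app = Flask(__name__)

parser = (
    reqparse.RequestParser()
    .add_argument("page", type=inputs.int_range(1, 99999), default=1, location="args")
    .add_argument("named_only", type=inputs.boolean, default=False, location="args")
)

with app.test_request_context("/?page=2&named_only=true"):
    args = parser.parse_args()
    # inputs.int_range rejects out-of-range values with a 400 response;
    # inputs.boolean maps "true"/"false" strings to real booleans.
    assert args["page"] == 2 and args["named_only"] is True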
@ -844,7 +886,15 @@ class WorkflowByIdApi(Resource):
    @console_ns.doc("update_workflow_by_id")
    @console_ns.doc(description="Update workflow by ID")
    @console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"})
    @console_ns.expect(console_ns.models[WorkflowUpdatePayload.__name__])
    @console_ns.expect(
        console_ns.model(
            "UpdateWorkflowRequest",
            {
                "environment_variables": fields.List(fields.Raw, description="Environment variables"),
                "conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
            },
        )
    )
    @console_ns.response(200, "Workflow updated successfully", workflow_model)
    @console_ns.response(404, "Workflow not found")
    @console_ns.response(403, "Permission denied")
@ -859,14 +909,25 @@ class WorkflowByIdApi(Resource):
        Update workflow attributes
        """
        current_user, _ = current_account_with_tenant()
        args = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
        parser = (
            reqparse.RequestParser()
            .add_argument("marked_name", type=str, required=False, location="json")
            .add_argument("marked_comment", type=str, required=False, location="json")
        )
        args = parser.parse_args()

        # Validate name and comment length
        if args.marked_name and len(args.marked_name) > 20:
            raise ValueError("Marked name cannot exceed 20 characters")
        if args.marked_comment and len(args.marked_comment) > 100:
            raise ValueError("Marked comment cannot exceed 100 characters")

        # Prepare update data
        update_data = {}
        if args.marked_name is not None:
            update_data["marked_name"] = args.marked_name
        if args.marked_comment is not None:
            update_data["marked_comment"] = args.marked_comment
        if args.get("marked_name") is not None:
            update_data["marked_name"] = args["marked_name"]
        if args.get("marked_comment") is not None:
            update_data["marked_comment"] = args["marked_comment"]

        if not update_data:
            return {"message": "No valid fields to update"}, 400
@ -979,8 +1040,11 @@ class DraftWorkflowTriggerRunApi(Resource):
        Poll for trigger events and execute full workflow when event arrives
        """
        current_user, _ = current_account_with_tenant()
        args = DraftWorkflowTriggerRunPayload.model_validate(console_ns.payload or {})
        node_id = args.node_id
        parser = reqparse.RequestParser().add_argument(
            "node_id", type=str, required=True, location="json", nullable=False
        )
        args = parser.parse_args()
        node_id = args["node_id"]
        workflow_service = WorkflowService()
        draft_workflow = workflow_service.get_draft_workflow(app_model)
        if not draft_workflow:
@ -1108,7 +1172,14 @@ class DraftWorkflowTriggerRunAllApi(Resource):
    @console_ns.doc("draft_workflow_trigger_run_all")
    @console_ns.doc(description="Full workflow debug when the start node is a trigger")
    @console_ns.doc(params={"app_id": "Application ID"})
    @console_ns.expect(console_ns.models[DraftWorkflowTriggerRunAllPayload.__name__])
    @console_ns.expect(
        console_ns.model(
            "DraftWorkflowTriggerRunAllRequest",
            {
                "node_ids": fields.List(fields.String, required=True, description="Node IDs"),
            },
        )
    )
    @console_ns.response(200, "Workflow executed successfully")
    @console_ns.response(403, "Permission denied")
    @console_ns.response(500, "Internal server error")
@ -1123,8 +1194,11 @@ class DraftWorkflowTriggerRunAllApi(Resource):
        """
        current_user, _ = current_account_with_tenant()

        args = DraftWorkflowTriggerRunAllPayload.model_validate(console_ns.payload or {})
        node_ids = args.node_ids
        parser = reqparse.RequestParser().add_argument(
            "node_ids", type=list, required=True, location="json", nullable=False
        )
        args = parser.parse_args()
        node_ids = args["node_ids"]
        workflow_service = WorkflowService()
        draft_workflow = workflow_service.get_draft_workflow(app_model)
        if not draft_workflow:

@ -1,9 +1,6 @@
from datetime import datetime

from dateutil.parser import isoparse
from flask import request
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy.orm import Session

from controllers.console import console_ns
@ -17,48 +14,6 @@ from models import App
from models.model import AppMode
from services.workflow_app_service import WorkflowAppService

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


class WorkflowAppLogQuery(BaseModel):
    keyword: str | None = Field(default=None, description="Search keyword for filtering logs")
    status: WorkflowExecutionStatus | None = Field(
        default=None, description="Execution status filter (succeeded, failed, stopped, partial-succeeded)"
    )
    created_at__before: datetime | None = Field(default=None, description="Filter logs created before this timestamp")
    created_at__after: datetime | None = Field(default=None, description="Filter logs created after this timestamp")
    created_by_end_user_session_id: str | None = Field(default=None, description="Filter by end user session ID")
    created_by_account: str | None = Field(default=None, description="Filter by account")
    detail: bool = Field(default=False, description="Whether to return detailed logs")
    page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
    limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")

    @field_validator("created_at__before", "created_at__after", mode="before")
    @classmethod
    def parse_datetime(cls, value: str | None) -> datetime | None:
        if value in (None, ""):
            return None
        return isoparse(value)  # type: ignore

    @field_validator("detail", mode="before")
    @classmethod
    def parse_bool(cls, value: bool | str | None) -> bool:
        if isinstance(value, bool):
            return value
        if value is None:
            return False
        lowered = value.lower()
        if lowered in {"1", "true", "yes", "on"}:
            return True
        if lowered in {"0", "false", "no", "off"}:
            return False
        raise ValueError("Invalid boolean value for detail")


console_ns.schema_model(
    WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)

# Register model for flask_restx to avoid dict type issues in Swagger
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)

@ -68,7 +23,19 @@ class WorkflowAppLogApi(Resource):
    @console_ns.doc("get_workflow_app_logs")
    @console_ns.doc(description="Get workflow application execution logs")
    @console_ns.doc(params={"app_id": "Application ID"})
    @console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
    @console_ns.doc(
        params={
            "keyword": "Search keyword for filtering logs",
            "status": "Filter by execution status (succeeded, failed, stopped, partial-succeeded)",
            "created_at__before": "Filter logs created before this timestamp",
            "created_at__after": "Filter logs created after this timestamp",
            "created_by_end_user_session_id": "Filter by end user session ID",
            "created_by_account": "Filter by account",
            "detail": "Whether to return detailed logs",
            "page": "Page number (1-99999)",
            "limit": "Number of items per page (1-100)",
        }
    )
    @console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
    @setup_required
    @login_required
@ -79,7 +46,44 @@ class WorkflowAppLogApi(Resource):
        """
        Get workflow app logs
        """
        args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
        parser = (
            reqparse.RequestParser()
            .add_argument("keyword", type=str, location="args")
            .add_argument(
                "status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
            )
            .add_argument(
                "created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
            )
            .add_argument(
                "created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
            )
            .add_argument(
                "created_by_end_user_session_id",
                type=str,
                location="args",
                required=False,
                default=None,
            )
            .add_argument(
                "created_by_account",
                type=str,
                location="args",
                required=False,
                default=None,
            )
            .add_argument("detail", type=bool, location="args", required=False, default=False)
            .add_argument("page", type=int_range(1, 99999), default=1, location="args")
            .add_argument("limit", type=int_range(1, 100), default=20, location="args")
        )
        args = parser.parse_args()

        args.status = WorkflowExecutionStatus(args.status) if args.status else None
        if args.created_at__before:
            args.created_at__before = isoparse(args.created_at__before)

        if args.created_at__after:
            args.created_at__after = isoparse(args.created_at__after)

        # get paginate workflow app logs
        workflow_app_service = WorkflowAppService()
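Because these query values all arrive as plain strings, the post-parse conversions above do the real typing. A hedged sketch of the same idea in isolation (the helper name is illustrative, not from the diff):

from dateutil.parser import isoparse

def coerce_log_filters(raw: dict) -> dict:
    out = dict(raw)
    # ISO-8601 strings become timezone-aware datetimes.
    for key in ("created_at__before", "created_at__after"):
        if out.get(key):
            out[key] = isoparse(out[key])
    # Note: bool("false") is True in Python, so boolean query params are
    # safer parsed explicitly than with a bare type=bool.
    if isinstance(out.get("detail"), str):
        out["detail"] = out["detail"].lower() in {"1", "true", "yes", "on"}
    return out

print(coerce_log_filters({"created_at__after": "2024-01-01T00:00:00Z", "detail": "true"}))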
@ -1,8 +1,7 @@
from typing import Literal, cast
from typing import cast

from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range

from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
@ -93,51 +92,70 @@ workflow_run_node_execution_list_model = console_ns.model(
    "WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
)

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

def _parse_workflow_run_list_args():
    """
    Parse common arguments for workflow run list endpoints.

class WorkflowRunListQuery(BaseModel):
    last_id: str | None = Field(default=None, description="Last run ID for pagination")
    limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
    status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
        default=None, description="Workflow run status filter"
    Returns:
        Parsed arguments containing last_id, limit, status, and triggered_from filters
    """
    parser = (
        reqparse.RequestParser()
        .add_argument("last_id", type=uuid_value, location="args")
        .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
        .add_argument(
            "status",
            type=str,
            choices=WORKFLOW_RUN_STATUS_CHOICES,
            location="args",
            required=False,
        )
        .add_argument(
            "triggered_from",
            type=str,
            choices=["debugging", "app-run"],
            location="args",
            required=False,
            help="Filter by trigger source: debugging or app-run",
        )
    )
    triggered_from: Literal["debugging", "app-run"] | None = Field(
        default=None, description="Filter by trigger source: debugging or app-run"
    return parser.parse_args()


def _parse_workflow_run_count_args():
    """
    Parse common arguments for workflow run count endpoints.

    Returns:
        Parsed arguments containing status, time_range, and triggered_from filters
    """
    parser = (
        reqparse.RequestParser()
        .add_argument(
            "status",
            type=str,
            choices=WORKFLOW_RUN_STATUS_CHOICES,
            location="args",
            required=False,
        )
        .add_argument(
            "time_range",
            type=time_duration,
            location="args",
            required=False,
            help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
        )
        .add_argument(
            "triggered_from",
            type=str,
            choices=["debugging", "app-run"],
            location="args",
            required=False,
            help="Filter by trigger source: debugging or app-run",
        )
    )

    @field_validator("last_id")
    @classmethod
    def validate_last_id(cls, value: str | None) -> str | None:
        if value is None:
            return value
        return uuid_value(value)


class WorkflowRunCountQuery(BaseModel):
    status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
        default=None, description="Workflow run status filter"
    )
    time_range: str | None = Field(default=None, description="Time range filter (e.g., 7d, 4h, 30m, 30s)")
    triggered_from: Literal["debugging", "app-run"] | None = Field(
        default=None, description="Filter by trigger source: debugging or app-run"
    )

    @field_validator("time_range")
    @classmethod
    def validate_time_range(cls, value: str | None) -> str | None:
        if value is None:
            return value
        return time_duration(value)


console_ns.schema_model(
    WorkflowRunListQuery.__name__, WorkflowRunListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
    WorkflowRunCountQuery.__name__,
    WorkflowRunCountQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
    return parser.parse_args()


@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
@ -152,7 +170,6 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
    @console_ns.doc(
        params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
    )
    @console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
    @console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model)
    @setup_required
    @login_required
@ -163,13 +180,12 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
        """
        Get advanced chat app workflow run list
        """
        args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
        args = args_model.model_dump(exclude_none=True)
        args = _parse_workflow_run_list_args()

        # Default to DEBUGGING if not specified
        triggered_from = (
            WorkflowRunTriggeredFrom(args_model.triggered_from)
            if args_model.triggered_from
            WorkflowRunTriggeredFrom(args.get("triggered_from"))
            if args.get("triggered_from")
            else WorkflowRunTriggeredFrom.DEBUGGING
        )

@ -201,7 +217,6 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
        params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
    )
    @console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
    @console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
    @setup_required
    @login_required
    @account_initialization_required
@ -211,13 +226,12 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
        """
        Get advanced chat workflow runs count statistics
        """
        args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
        args = args_model.model_dump(exclude_none=True)
        args = _parse_workflow_run_count_args()

        # Default to DEBUGGING if not specified
        triggered_from = (
            WorkflowRunTriggeredFrom(args_model.triggered_from)
            if args_model.triggered_from
            WorkflowRunTriggeredFrom(args.get("triggered_from"))
            if args.get("triggered_from")
            else WorkflowRunTriggeredFrom.DEBUGGING
        )

@ -245,7 +259,6 @@ class WorkflowRunListApi(Resource):
        params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
    )
    @console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
    @console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
    @setup_required
    @login_required
    @account_initialization_required
@ -255,13 +268,12 @@ class WorkflowRunListApi(Resource):
        """
        Get workflow run list
        """
        args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
        args = args_model.model_dump(exclude_none=True)
        args = _parse_workflow_run_list_args()

        # Default to DEBUGGING for workflow if not specified (backward compatibility)
        triggered_from = (
            WorkflowRunTriggeredFrom(args_model.triggered_from)
            if args_model.triggered_from
            WorkflowRunTriggeredFrom(args.get("triggered_from"))
            if args.get("triggered_from")
            else WorkflowRunTriggeredFrom.DEBUGGING
        )

@ -293,7 +305,6 @@ class WorkflowRunCountApi(Resource):
        params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
    )
    @console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
    @console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
    @setup_required
    @login_required
    @account_initialization_required
@ -303,13 +314,12 @@ class WorkflowRunCountApi(Resource):
        """
        Get workflow runs count statistics
        """
        args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
        args = args_model.model_dump(exclude_none=True)
        args = _parse_workflow_run_count_args()

        # Default to DEBUGGING for workflow if not specified (backward compatibility)
        triggered_from = (
            WorkflowRunTriggeredFrom(args_model.triggered_from)
            if args_model.triggered_from
            WorkflowRunTriggeredFrom(args.get("triggered_from"))
            if args.get("triggered_from")
            else WorkflowRunTriggeredFrom.DEBUGGING
        )


@ -1,6 +1,5 @@
from flask import abort, jsonify, request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from flask import abort, jsonify
from flask_restx import Resource, reqparse
from sqlalchemy.orm import sessionmaker

from controllers.console import console_ns
@ -8,31 +7,12 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from repositories.factory import DifyAPIRepositoryFactory

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


class WorkflowStatisticQuery(BaseModel):
    start: str | None = Field(default=None, description="Start date and time (YYYY-MM-DD HH:MM)")
    end: str | None = Field(default=None, description="End date and time (YYYY-MM-DD HH:MM)")

    @field_validator("start", "end", mode="before")
    @classmethod
    def blank_to_none(cls, value: str | None) -> str | None:
        if value == "":
            return None
        return value


console_ns.schema_model(
    WorkflowStatisticQuery.__name__,
    WorkflowStatisticQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)


@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
class WorkflowDailyRunsStatistic(Resource):
@ -44,7 +24,9 @@ class WorkflowDailyRunsStatistic(Resource):
    @console_ns.doc("get_workflow_daily_runs_statistic")
    @console_ns.doc(description="Get workflow daily runs statistics")
    @console_ns.doc(params={"app_id": "Application ID"})
    @console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
    @console_ns.doc(
        params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
    )
    @console_ns.response(200, "Daily runs statistics retrieved successfully")
    @get_app_model
    @setup_required
@ -53,12 +35,17 @@ class WorkflowDailyRunsStatistic(Resource):
    def get(self, app_model):
        account, _ = current_account_with_tenant()

        args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
        parser = (
            reqparse.RequestParser()
            .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
            .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
        )
        args = parser.parse_args()

        assert account.timezone is not None

        try:
            start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
            start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
        except ValueError as e:
            abort(400, description=str(e))
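DatetimeString here is Dify's own validator type from libs.helper; conceptually it is just a callable that rejects strings not matching a strftime format, which is all reqparse needs from a custom type. A rough, hypothetical stand-in:

from datetime import datetime

class DatetimeStringSketch:
    # Hypothetical equivalent of libs.helper.DatetimeString: reqparse calls
    # the type with the raw value and turns any exception into a 400.
    def __init__(self, fmt: str):
        self.fmt = fmt

    def __call__(self, value: str) -> str:
        datetime.strptime(value, self.fmt)  # raises ValueError on mismatch
        return value

validate = DatetimeStringSketch("%Y-%m-%d %H:%M")
print(validate("2024-06-01 12:30"))  # passes through unchanged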
@ -84,7 +71,9 @@ class WorkflowDailyTerminalsStatistic(Resource):
    @console_ns.doc("get_workflow_daily_terminals_statistic")
    @console_ns.doc(description="Get workflow daily terminals statistics")
    @console_ns.doc(params={"app_id": "Application ID"})
    @console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
    @console_ns.doc(
        params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
    )
    @console_ns.response(200, "Daily terminals statistics retrieved successfully")
    @get_app_model
    @setup_required
@ -93,12 +82,17 @@ class WorkflowDailyTerminalsStatistic(Resource):
    def get(self, app_model):
        account, _ = current_account_with_tenant()

        args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
        parser = (
            reqparse.RequestParser()
            .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
            .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
        )
        args = parser.parse_args()

        assert account.timezone is not None

        try:
            start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
            start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
        except ValueError as e:
            abort(400, description=str(e))

@ -124,7 +118,9 @@ class WorkflowDailyTokenCostStatistic(Resource):
    @console_ns.doc("get_workflow_daily_token_cost_statistic")
    @console_ns.doc(description="Get workflow daily token cost statistics")
    @console_ns.doc(params={"app_id": "Application ID"})
    @console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
    @console_ns.doc(
        params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
    )
    @console_ns.response(200, "Daily token cost statistics retrieved successfully")
    @get_app_model
    @setup_required
@ -133,12 +129,17 @@ class WorkflowDailyTokenCostStatistic(Resource):
    def get(self, app_model):
        account, _ = current_account_with_tenant()

        args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
        parser = (
            reqparse.RequestParser()
            .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
            .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
        )
        args = parser.parse_args()

        assert account.timezone is not None

        try:
            start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
            start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
        except ValueError as e:
            abort(400, description=str(e))

@ -164,7 +165,9 @@ class WorkflowAverageAppInteractionStatistic(Resource):
    @console_ns.doc("get_workflow_average_app_interaction_statistic")
    @console_ns.doc(description="Get workflow average app interaction statistics")
    @console_ns.doc(params={"app_id": "Application ID"})
    @console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
    @console_ns.doc(
        params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
    )
    @console_ns.response(200, "Average app interaction statistics retrieved successfully")
    @setup_required
    @login_required
@ -173,12 +176,17 @@ class WorkflowAverageAppInteractionStatistic(Resource):
    def get(self, app_model):
        account, _ = current_account_with_tenant()

        args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
        parser = (
            reqparse.RequestParser()
            .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
            .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
        )
        args = parser.parse_args()

        assert account.timezone is not None

        try:
            start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
            start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
        except ValueError as e:
            abort(400, description=str(e))


@ -58,7 +58,7 @@ class VersionApi(Resource):
            response = httpx.get(
                check_update_url,
                params={"current_version": args["current_version"]},
                timeout=httpx.Timeout(timeout=10.0, connect=3.0),
                timeout=httpx.Timeout(connect=3, read=10),
            )
        except Exception as error:
            logger.warning("Check update version error: %s.", str(error))
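One detail worth hedging on the replacement line: httpx.Timeout expects either a positional default or, to my understanding, all four phases (connect, read, write, pool) named explicitly, so the two spellings are not interchangeable. A sketch of the constructions httpx does accept:

import httpx

# A default of 10s for every phase, with a tighter 3s connect phase.
t1 = httpx.Timeout(10.0, connect=3.0)

# Without a default, httpx wants connect, read, write and pool all named;
# naming only a subset is rejected at construction time.
t2 = httpx.Timeout(connect=3.0, read=10.0, write=10.0, pool=10.0)

print(t1, t2)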
@ -174,25 +174,63 @@ class CheckEmailUniquePayload(BaseModel):
        return email(value)


def reg(cls: type[BaseModel]):
    console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))


reg(AccountInitPayload)
reg(AccountNamePayload)
reg(AccountAvatarPayload)
reg(AccountInterfaceLanguagePayload)
reg(AccountInterfaceThemePayload)
reg(AccountTimezonePayload)
reg(AccountPasswordPayload)
reg(AccountDeletePayload)
reg(AccountDeletionFeedbackPayload)
reg(EducationActivatePayload)
reg(EducationAutocompleteQuery)
reg(ChangeEmailSendPayload)
reg(ChangeEmailValidityPayload)
reg(ChangeEmailResetPayload)
reg(CheckEmailUniquePayload)
console_ns.schema_model(
    AccountInitPayload.__name__, AccountInitPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
    AccountNamePayload.__name__, AccountNamePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
    AccountAvatarPayload.__name__, AccountAvatarPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
    AccountInterfaceLanguagePayload.__name__,
    AccountInterfaceLanguagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    AccountInterfaceThemePayload.__name__,
    AccountInterfaceThemePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    AccountTimezonePayload.__name__,
    AccountTimezonePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    AccountPasswordPayload.__name__,
    AccountPasswordPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    AccountDeletePayload.__name__,
    AccountDeletePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    AccountDeletionFeedbackPayload.__name__,
    AccountDeletionFeedbackPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    EducationActivatePayload.__name__,
    EducationActivatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    EducationAutocompleteQuery.__name__,
    EducationAutocompleteQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    ChangeEmailSendPayload.__name__,
    ChangeEmailSendPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    ChangeEmailValidityPayload.__name__,
    ChangeEmailValidityPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    ChangeEmailResetPayload.__name__,
    ChangeEmailResetPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    CheckEmailUniquePayload.__name__,
    CheckEmailUniquePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)


@console_ns.route("/account/init")
@ -1,8 +1,4 @@
from typing import Any

from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from flask_restx import Resource, fields, reqparse

from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
@ -11,49 +7,21 @@ from core.plugin.impl.exc import PluginPermissionDeniedError
from libs.login import current_account_with_tenant, login_required
from services.plugin.endpoint_service import EndpointService

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


class EndpointCreatePayload(BaseModel):
    plugin_unique_identifier: str
    settings: dict[str, Any]
    name: str = Field(min_length=1)


class EndpointIdPayload(BaseModel):
    endpoint_id: str


class EndpointUpdatePayload(EndpointIdPayload):
    settings: dict[str, Any]
    name: str = Field(min_length=1)


class EndpointListQuery(BaseModel):
    page: int = Field(ge=1)
    page_size: int = Field(gt=0)


class EndpointListForPluginQuery(EndpointListQuery):
    plugin_id: str


def reg(cls: type[BaseModel]):
    console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))


reg(EndpointCreatePayload)
reg(EndpointIdPayload)
reg(EndpointUpdatePayload)
reg(EndpointListQuery)
reg(EndpointListForPluginQuery)


@console_ns.route("/workspaces/current/endpoints/create")
class EndpointCreateApi(Resource):
    @console_ns.doc("create_endpoint")
    @console_ns.doc(description="Create a new plugin endpoint")
    @console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
    @console_ns.expect(
        console_ns.model(
            "EndpointCreateRequest",
            {
                "plugin_unique_identifier": fields.String(required=True, description="Plugin unique identifier"),
                "settings": fields.Raw(required=True, description="Endpoint settings"),
                "name": fields.String(required=True, description="Endpoint name"),
            },
        )
    )
    @console_ns.response(
        200,
        "Endpoint created successfully",
@ -67,16 +35,26 @@ class EndpointCreateApi(Resource):
    def post(self):
        user, tenant_id = current_account_with_tenant()

        args = EndpointCreatePayload.model_validate(console_ns.payload)
        parser = (
            reqparse.RequestParser()
            .add_argument("plugin_unique_identifier", type=str, required=True)
            .add_argument("settings", type=dict, required=True)
            .add_argument("name", type=str, required=True)
        )
        args = parser.parse_args()

        plugin_unique_identifier = args["plugin_unique_identifier"]
        settings = args["settings"]
        name = args["name"]

        try:
            return {
                "success": EndpointService.create_endpoint(
                    tenant_id=tenant_id,
                    user_id=user.id,
                    plugin_unique_identifier=args.plugin_unique_identifier,
                    name=args.name,
                    settings=args.settings,
                    plugin_unique_identifier=plugin_unique_identifier,
                    name=name,
                    settings=settings,
                )
            }
        except PluginPermissionDeniedError as e:
@ -87,7 +65,11 @@ class EndpointCreateApi(Resource):
class EndpointListApi(Resource):
    @console_ns.doc("list_endpoints")
    @console_ns.doc(description="List plugin endpoints with pagination")
    @console_ns.expect(console_ns.models[EndpointListQuery.__name__])
    @console_ns.expect(
        console_ns.parser()
        .add_argument("page", type=int, required=True, location="args", help="Page number")
        .add_argument("page_size", type=int, required=True, location="args", help="Page size")
    )
    @console_ns.response(
        200,
        "Success",
@ -101,10 +83,15 @@ class EndpointListApi(Resource):
    def get(self):
        user, tenant_id = current_account_with_tenant()

        args = EndpointListQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
        parser = (
            reqparse.RequestParser()
            .add_argument("page", type=int, required=True, location="args")
            .add_argument("page_size", type=int, required=True, location="args")
        )
        args = parser.parse_args()

        page = args.page
        page_size = args.page_size
        page = args["page"]
        page_size = args["page_size"]

        return jsonable_encoder(
            {
@ -122,7 +109,12 @@ class EndpointListApi(Resource):
class EndpointListForSinglePluginApi(Resource):
    @console_ns.doc("list_plugin_endpoints")
    @console_ns.doc(description="List endpoints for a specific plugin")
    @console_ns.expect(console_ns.models[EndpointListForPluginQuery.__name__])
    @console_ns.expect(
        console_ns.parser()
        .add_argument("page", type=int, required=True, location="args", help="Page number")
        .add_argument("page_size", type=int, required=True, location="args", help="Page size")
        .add_argument("plugin_id", type=str, required=True, location="args", help="Plugin ID")
    )
    @console_ns.response(
        200,
        "Success",
@ -136,11 +128,17 @@ class EndpointListForSinglePluginApi(Resource):
    def get(self):
        user, tenant_id = current_account_with_tenant()

        args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
        parser = (
            reqparse.RequestParser()
            .add_argument("page", type=int, required=True, location="args")
            .add_argument("page_size", type=int, required=True, location="args")
            .add_argument("plugin_id", type=str, required=True, location="args")
        )
        args = parser.parse_args()

        page = args.page
        page_size = args.page_size
        plugin_id = args.plugin_id
        page = args["page"]
        page_size = args["page_size"]
        plugin_id = args["plugin_id"]

        return jsonable_encoder(
            {
@ -159,7 +157,11 @@ class EndpointListForSinglePluginApi(Resource):
class EndpointDeleteApi(Resource):
    @console_ns.doc("delete_endpoint")
    @console_ns.doc(description="Delete a plugin endpoint")
    @console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
    @console_ns.expect(
        console_ns.model(
            "EndpointDeleteRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
        )
    )
    @console_ns.response(
        200,
        "Endpoint deleted successfully",
@ -173,12 +175,13 @@ class EndpointDeleteApi(Resource):
    def post(self):
        user, tenant_id = current_account_with_tenant()

        args = EndpointIdPayload.model_validate(console_ns.payload)
        parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
        args = parser.parse_args()

        endpoint_id = args["endpoint_id"]

        return {
            "success": EndpointService.delete_endpoint(
                tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
            )
            "success": EndpointService.delete_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
        }


@ -186,7 +189,16 @@ class EndpointDeleteApi(Resource):
class EndpointUpdateApi(Resource):
    @console_ns.doc("update_endpoint")
    @console_ns.doc(description="Update a plugin endpoint")
    @console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
    @console_ns.expect(
        console_ns.model(
            "EndpointUpdateRequest",
            {
                "endpoint_id": fields.String(required=True, description="Endpoint ID"),
                "settings": fields.Raw(required=True, description="Updated settings"),
                "name": fields.String(required=True, description="Updated name"),
            },
        )
    )
    @console_ns.response(
        200,
        "Endpoint updated successfully",
@ -200,15 +212,25 @@ class EndpointUpdateApi(Resource):
    def post(self):
        user, tenant_id = current_account_with_tenant()

        args = EndpointUpdatePayload.model_validate(console_ns.payload)
        parser = (
            reqparse.RequestParser()
            .add_argument("endpoint_id", type=str, required=True)
            .add_argument("settings", type=dict, required=True)
            .add_argument("name", type=str, required=True)
        )
        args = parser.parse_args()

        endpoint_id = args["endpoint_id"]
        settings = args["settings"]
        name = args["name"]

        return {
            "success": EndpointService.update_endpoint(
                tenant_id=tenant_id,
                user_id=user.id,
                endpoint_id=args.endpoint_id,
                name=args.name,
                settings=args.settings,
                endpoint_id=endpoint_id,
                name=name,
                settings=settings,
            )
        }

@ -217,7 +239,11 @@ class EndpointUpdateApi(Resource):
class EndpointEnableApi(Resource):
    @console_ns.doc("enable_endpoint")
    @console_ns.doc(description="Enable a plugin endpoint")
    @console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
    @console_ns.expect(
        console_ns.model(
            "EndpointEnableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
        )
    )
    @console_ns.response(
        200,
        "Endpoint enabled successfully",
@ -231,12 +257,13 @@ class EndpointEnableApi(Resource):
    def post(self):
        user, tenant_id = current_account_with_tenant()

        args = EndpointIdPayload.model_validate(console_ns.payload)
        parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
        args = parser.parse_args()

        endpoint_id = args["endpoint_id"]

        return {
            "success": EndpointService.enable_endpoint(
                tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
            )
            "success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
        }


@ -244,7 +271,11 @@ class EndpointEnableApi(Resource):
class EndpointDisableApi(Resource):
    @console_ns.doc("disable_endpoint")
    @console_ns.doc(description="Disable a plugin endpoint")
    @console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
    @console_ns.expect(
        console_ns.model(
            "EndpointDisableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
        )
    )
    @console_ns.response(
        200,
        "Endpoint disabled successfully",
@ -258,10 +289,11 @@ class EndpointDisableApi(Resource):
    def post(self):
        user, tenant_id = current_account_with_tenant()

        args = EndpointIdPayload.model_validate(console_ns.payload)
        parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
        args = parser.parse_args()

        endpoint_id = args["endpoint_id"]

        return {
            "success": EndpointService.disable_endpoint(
                tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
            )
            "success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
        }
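As these hunks show, @console_ns.expect accepts either a registered model (for JSON bodies) or a RequestParser (for query strings), and flask-restx documents both in Swagger. A compact hypothetical sketch of the two forms side by side:

from flask import Flask
from flask_restx import Api, Resource, fields

app = Flask(__name__)
api = Api(app)
ns = api.namespace("endpoints")

body_model = ns.model("DemoBody", {"endpoint_id": fields.String(required=True)})
query_parser = ns.parser().add_argument("page", type=int, location="args")

@ns.route("/demo")
class Demo(Resource):
    @ns.expect(body_model)  # documented as a JSON request body
    def post(self):
        return {"ok": True}

    @ns.expect(query_parser)  # documented as query-string parameters
    def get(self):
        return query_parser.parse_args()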
|
||||
|
||||
@ -58,15 +58,26 @@ class OwnerTransferPayload(BaseModel):
|
||||
token: str
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(MemberInvitePayload)
|
||||
reg(MemberRoleUpdatePayload)
|
||||
reg(OwnerTransferEmailPayload)
|
||||
reg(OwnerTransferCheckPayload)
|
||||
reg(OwnerTransferPayload)
|
||||
console_ns.schema_model(
|
||||
MemberInvitePayload.__name__,
|
||||
MemberInvitePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
MemberRoleUpdatePayload.__name__,
|
||||
MemberRoleUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
OwnerTransferEmailPayload.__name__,
|
||||
OwnerTransferEmailPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
OwnerTransferCheckPayload.__name__,
|
||||
OwnerTransferCheckPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
OwnerTransferPayload.__name__,
|
||||
OwnerTransferPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members")
|
||||
|
||||
@ -75,18 +75,44 @@ class ParserPreferredProviderType(BaseModel):
|
||||
preferred_provider_type: Literal["system", "custom"]
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
console_ns.schema_model(
|
||||
ParserModelList.__name__, ParserModelList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialId.__name__,
|
||||
ParserCredentialId.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
reg(ParserModelList)
|
||||
reg(ParserCredentialId)
|
||||
reg(ParserCredentialCreate)
|
||||
reg(ParserCredentialUpdate)
|
||||
reg(ParserCredentialDelete)
|
||||
reg(ParserCredentialSwitch)
|
||||
reg(ParserCredentialValidate)
|
||||
reg(ParserPreferredProviderType)
|
||||
console_ns.schema_model(
|
||||
ParserCredentialCreate.__name__,
|
||||
ParserCredentialCreate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialUpdate.__name__,
|
||||
ParserCredentialUpdate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialDelete.__name__,
|
||||
ParserCredentialDelete.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialSwitch.__name__,
|
||||
ParserCredentialSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserCredentialValidate.__name__,
|
||||
ParserCredentialValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPreferredProviderType.__name__,
|
||||
ParserPreferredProviderType.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers")
|
||||
|
||||
@ -32,11 +32,25 @@ class ParserPostDefault(BaseModel):
|
||||
model_settings: list[Inner]
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserGetDefault.__name__, ParserGetDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserPostDefault.__name__, ParserPostDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
class ParserDeleteModels(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserDeleteModels.__name__, ParserDeleteModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
class LoadBalancingPayload(BaseModel):
|
||||
configs: list[dict[str, Any]] | None = None
|
||||
enabled: bool | None = None
|
||||
@ -105,19 +119,33 @@ class ParserParameter(BaseModel):
|
||||
model: str
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
console_ns.schema_model(
|
||||
ParserPostModels.__name__, ParserPostModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserGetCredentials.__name__,
|
||||
ParserGetCredentials.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
reg(ParserGetDefault)
|
||||
reg(ParserPostDefault)
|
||||
reg(ParserDeleteModels)
|
||||
reg(ParserPostModels)
|
||||
reg(ParserGetCredentials)
|
||||
reg(ParserCreateCredential)
|
||||
reg(ParserUpdateCredential)
|
||||
reg(ParserDeleteCredential)
|
||||
reg(ParserParameter)
|
||||
console_ns.schema_model(
|
||||
ParserCreateCredential.__name__,
|
||||
ParserCreateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserUpdateCredential.__name__,
|
||||
ParserUpdateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserDeleteCredential.__name__,
|
||||
ParserDeleteCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserParameter.__name__, ParserParameter.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/default-model")
|
||||
|
||||
@ -22,10 +22,6 @@ from services.plugin.plugin_service import PluginService

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


def reg(cls: type[BaseModel]):
    console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))


@console_ns.route("/workspaces/current/plugin/debugging-key")
class PluginDebuggingKeyApi(Resource):
    @setup_required
@ -50,7 +46,9 @@ class ParserList(BaseModel):
    page_size: int = Field(default=256)


reg(ParserList)
console_ns.schema_model(
    ParserList.__name__, ParserList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)


@console_ns.route("/workspaces/current/plugin/list")
@ -74,6 +72,11 @@ class ParserLatest(BaseModel):
    plugin_ids: list[str]


console_ns.schema_model(
    ParserLatest.__name__, ParserLatest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)


class ParserIcon(BaseModel):
    tenant_id: str
    filename: str
@ -170,22 +173,72 @@ class ParserReadme(BaseModel):
    language: str = Field(default="en-US")


reg(ParserLatest)
reg(ParserIcon)
reg(ParserAsset)
reg(ParserGithubUpload)
reg(ParserPluginIdentifiers)
reg(ParserGithubInstall)
reg(ParserPluginIdentifierQuery)
reg(ParserTasks)
reg(ParserMarketplaceUpgrade)
reg(ParserGithubUpgrade)
reg(ParserUninstall)
reg(ParserPermissionChange)
reg(ParserDynamicOptions)
reg(ParserPreferencesChange)
reg(ParserExcludePlugin)
reg(ParserReadme)
console_ns.schema_model(
    ParserIcon.__name__, ParserIcon.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)

console_ns.schema_model(
    ParserAsset.__name__, ParserAsset.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)

console_ns.schema_model(
    ParserGithubUpload.__name__, ParserGithubUpload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)

console_ns.schema_model(
    ParserPluginIdentifiers.__name__,
    ParserPluginIdentifiers.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)

console_ns.schema_model(
    ParserGithubInstall.__name__, ParserGithubInstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)

console_ns.schema_model(
    ParserPluginIdentifierQuery.__name__,
    ParserPluginIdentifierQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)

console_ns.schema_model(
    ParserTasks.__name__, ParserTasks.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)

console_ns.schema_model(
    ParserMarketplaceUpgrade.__name__,
    ParserMarketplaceUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)

console_ns.schema_model(
    ParserGithubUpgrade.__name__, ParserGithubUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)

console_ns.schema_model(
    ParserUninstall.__name__, ParserUninstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)

console_ns.schema_model(
    ParserPermissionChange.__name__,
    ParserPermissionChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)

console_ns.schema_model(
    ParserDynamicOptions.__name__,
    ParserDynamicOptions.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)

console_ns.schema_model(
    ParserPreferencesChange.__name__,
    ParserPreferencesChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)

console_ns.schema_model(
    ParserExcludePlugin.__name__,
    ParserExcludePlugin.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)

console_ns.schema_model(
    ParserReadme.__name__, ParserReadme.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)


@console_ns.route("/workspaces/current/plugin/list/latest-versions")

@ -54,14 +54,25 @@ class WorkspaceInfoPayload(BaseModel):
    name: str


def reg(cls: type[BaseModel]):
    console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
console_ns.schema_model(
    WorkspaceListQuery.__name__, WorkspaceListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)

console_ns.schema_model(
    SwitchWorkspacePayload.__name__,
    SwitchWorkspacePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)

console_ns.schema_model(
    WorkspaceCustomConfigPayload.__name__,
    WorkspaceCustomConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)

console_ns.schema_model(
    WorkspaceInfoPayload.__name__,
    WorkspaceInfoPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)

reg(WorkspaceListQuery)
reg(SwitchWorkspacePayload)
reg(WorkspaceCustomConfigPayload)
reg(WorkspaceInfoPayload)

provider_fields = {
    "provider_name": fields.String,

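A minimal sketch of the `reg` helper pattern in the hunks above, with illustrative names (`example_ns` and `ExamplePayload` are not from the diff): a Pydantic model's JSON schema is registered as a Swagger 2.0 definition on a flask-restx namespace so route decorators can reference it by name.

from flask_restx import Namespace
from pydantic import BaseModel, Field

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
example_ns = Namespace("example")  # hypothetical namespace


class ExamplePayload(BaseModel):  # hypothetical model
    page: int = Field(default=1)
    page_size: int = Field(default=256)


def reg(cls: type[BaseModel]) -> None:
    # schema_model() stores a raw JSON schema under the class name, making it
    # available to @example_ns.expect / @example_ns.response decorators.
    example_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))


reg(ExamplePayload)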
@ -316,16 +316,18 @@ def validate_and_get_api_token(scope: str | None = None):
                ApiToken.type == scope,
            )
            .values(last_used_at=current_time)
            .returning(ApiToken)
        )
        stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
        result = session.execute(update_stmt)
        api_token = session.scalar(stmt)

        if hasattr(result, "rowcount") and result.rowcount > 0:
            session.commit()
            api_token = result.scalar_one_or_none()

        if not api_token:
            raise Unauthorized("Access token is invalid")
            stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
            api_token = session.scalar(stmt)
            if not api_token:
                raise Unauthorized("Access token is invalid")
        else:
            session.commit()

    return api_token

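A hedged sketch contrasting the two token-validation variants in the hunk above, assuming a minimal ApiToken model (the real service has more columns and error handling):

from datetime import datetime

from sqlalchemy import DateTime, String, select, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class ApiToken(Base):  # stand-in for the real model
    __tablename__ = "api_tokens"

    token: Mapped[str] = mapped_column(String(64), primary_key=True)
    type: Mapped[str] = mapped_column(String(16))
    last_used_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)


def validate_atomic(session: Session, auth_token: str, scope: str) -> ApiToken | None:
    # One round trip: UPDATE ... RETURNING bumps last_used_at and fetches the
    # row together, so validation and the timestamp write cannot race.
    stmt = (
        update(ApiToken)
        .where(ApiToken.token == auth_token, ApiToken.type == scope)
        .values(last_used_at=datetime.utcnow())
        .returning(ApiToken)
    )
    token = session.execute(stmt).scalar_one_or_none()
    if token is not None:
        session.commit()
    return token


def validate_select_then_commit(session: Session, auth_token: str, scope: str) -> ApiToken | None:
    # Two steps: read, then commit; simpler and portable, but the row can
    # change between the SELECT and the commit.
    token = session.scalar(select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope))
    if token is not None:
        session.commit()
    return token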
@ -1,4 +1,3 @@
import logging
import time
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
@ -56,7 +55,6 @@ from models import Account, EndUser
from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator

NodeExecutionId = NewType("NodeExecutionId", str)
logger = logging.getLogger(__name__)


@dataclass(slots=True)

@ -291,30 +289,26 @@ class WorkflowResponseConverter:
            ),
        )

        try:
            if event.node_type == NodeType.TOOL:
                response.data.extras["icon"] = ToolManager.get_tool_icon(
                    tenant_id=self._application_generate_entity.app_config.tenant_id,
                    provider_type=ToolProviderType(event.provider_type),
                    provider_id=event.provider_id,
                )
            elif event.node_type == NodeType.DATASOURCE:
                manager = PluginDatasourceManager()
                provider_entity = manager.fetch_datasource_provider(
                    self._application_generate_entity.app_config.tenant_id,
                    event.provider_id,
                )
                response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
                    self._application_generate_entity.app_config.tenant_id
                )
            elif event.node_type == NodeType.TRIGGER_PLUGIN:
                response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon(
                    self._application_generate_entity.app_config.tenant_id,
                    event.provider_id,
                )
        except Exception:
            # metadata fetch may fail, for example, the plugin daemon is down or plugin is uninstalled.
            logger.warning("failed to fetch icon for %s", event.provider_id)
        if event.node_type == NodeType.TOOL:
            response.data.extras["icon"] = ToolManager.get_tool_icon(
                tenant_id=self._application_generate_entity.app_config.tenant_id,
                provider_type=ToolProviderType(event.provider_type),
                provider_id=event.provider_id,
            )
        elif event.node_type == NodeType.DATASOURCE:
            manager = PluginDatasourceManager()
            provider_entity = manager.fetch_datasource_provider(
                self._application_generate_entity.app_config.tenant_id,
                event.provider_id,
            )
            response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
                self._application_generate_entity.app_config.tenant_id
            )
        elif event.node_type == NodeType.TRIGGER_PLUGIN:
            response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon(
                self._application_generate_entity.app_config.tenant_id,
                event.provider_id,
            )

        return response

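A sketch of the best-effort enrichment pattern on the try/except side of the hunk above: icon metadata is cosmetic, so a fetch failure is logged and the response proceeds without it (`fetch_icon` is a hypothetical callable, not from the diff).

import logging
from collections.abc import Callable

logger = logging.getLogger(__name__)


def enrich_icon(extras: dict, provider_id: str, fetch_icon: Callable[[], str]) -> None:
    try:
        extras["icon"] = fetch_icon()
    except Exception:
        # e.g. the plugin daemon is down or the plugin was uninstalled;
        # never let cosmetic metadata break the event stream.
        logger.warning("failed to fetch icon for %s", provider_id)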
@ -156,82 +156,78 @@ class MessageBasedAppGenerator(BaseAppGenerator):
        query = application_generate_entity.query or "New conversation"
        conversation_name = (query[:20] + "…") if len(query) > 20 else query

        with db.session.begin():
            if not conversation:
                conversation = Conversation(
                    app_id=app_config.app_id,
                    app_model_config_id=app_model_config_id,
                    model_provider=model_provider,
                    model_id=model_id,
                    override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
                    mode=app_config.app_mode.value,
                    name=conversation_name,
                    inputs=application_generate_entity.inputs,
                    introduction=introduction,
                    system_instruction="",
                    system_instruction_tokens=0,
                    status="normal",
                    invoke_from=application_generate_entity.invoke_from.value,
                    from_source=from_source,
                    from_end_user_id=end_user_id,
                    from_account_id=account_id,
                )

                db.session.add(conversation)
                db.session.flush()
                db.session.refresh(conversation)
            else:
                conversation.updated_at = naive_utc_now()

            message = Message(
        if not conversation:
            conversation = Conversation(
                app_id=app_config.app_id,
                app_model_config_id=app_model_config_id,
                model_provider=model_provider,
                model_id=model_id,
                override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
                conversation_id=conversation.id,
                mode=app_config.app_mode.value,
                name=conversation_name,
                inputs=application_generate_entity.inputs,
                query=application_generate_entity.query,
                message="",
                message_tokens=0,
                message_unit_price=0,
                message_price_unit=0,
                answer="",
                answer_tokens=0,
                answer_unit_price=0,
                answer_price_unit=0,
                parent_message_id=getattr(application_generate_entity, "parent_message_id", None),
                provider_response_latency=0,
                total_price=0,
                currency="USD",
                introduction=introduction,
                system_instruction="",
                system_instruction_tokens=0,
                status="normal",
                invoke_from=application_generate_entity.invoke_from.value,
                from_source=from_source,
                from_end_user_id=end_user_id,
                from_account_id=account_id,
                app_mode=app_config.app_mode,
            )

            db.session.add(message)
            db.session.flush()
            db.session.refresh(message)

            message_files = []
            for file in application_generate_entity.files:
                message_file = MessageFile(
                    message_id=message.id,
                    type=file.type,
                    transfer_method=file.transfer_method,
                    belongs_to="user",
                    url=file.remote_url,
                    upload_file_id=file.related_id,
                    created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
                    created_by=account_id or end_user_id or "",
                )
                message_files.append(message_file)

            if message_files:
                db.session.add_all(message_files)

            db.session.add(conversation)
            db.session.commit()
            db.session.refresh(conversation)
        else:
            conversation.updated_at = naive_utc_now()
            db.session.commit()

        message = Message(
            app_id=app_config.app_id,
            model_provider=model_provider,
            model_id=model_id,
            override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
            conversation_id=conversation.id,
            inputs=application_generate_entity.inputs,
            query=application_generate_entity.query,
            message="",
            message_tokens=0,
            message_unit_price=0,
            message_price_unit=0,
            answer="",
            answer_tokens=0,
            answer_unit_price=0,
            answer_price_unit=0,
            parent_message_id=getattr(application_generate_entity, "parent_message_id", None),
            provider_response_latency=0,
            total_price=0,
            currency="USD",
            invoke_from=application_generate_entity.invoke_from.value,
            from_source=from_source,
            from_end_user_id=end_user_id,
            from_account_id=account_id,
            app_mode=app_config.app_mode,
        )

        db.session.add(message)
        db.session.commit()
        db.session.refresh(message)

        for file in application_generate_entity.files:
            message_file = MessageFile(
                message_id=message.id,
                type=file.type,
                transfer_method=file.transfer_method,
                belongs_to="user",
                url=file.remote_url,
                upload_file_id=file.related_id,
                created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
                created_by=account_id or end_user_id or "",
            )
            db.session.add(message_file)
            db.session.commit()

        return conversation, message

    def _get_conversation_introduction(self, application_generate_entity: AppGenerateEntity) -> str:

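A hedged sketch of the transactional variant on the `with db.session.begin():` side of the hunk above: one transaction block replaces the chain of intermediate commits, so conversation, message, and files persist or roll back together. The model objects are treated as opaque here.

from sqlalchemy import create_engine
from sqlalchemy.orm import Session

engine = create_engine("sqlite://")  # stand-in engine for illustration


def create_conversation_and_message(conversation, message, files) -> None:
    with Session(engine) as session:
        with session.begin():  # commits on success, rolls back on any exception
            session.add(conversation)
            session.flush()  # assigns conversation.id without committing
            message.conversation_id = conversation.id
            session.add(message)
            session.flush()
            for f in files:
                f.message_id = message.id
            session.add_all(files)
        # leaving the begin() block committed everything atomically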
@ -4,15 +4,15 @@ from typing import TYPE_CHECKING, Any, Optional

from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator

if TYPE_CHECKING:
    from core.ops.ops_trace_manager import TraceQueueManager

from constants import UUID_NIL
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
from core.file import File, FileUploadConfig
from core.model_runtime.entities.model_entities import AIModelEntity

if TYPE_CHECKING:
    from core.ops.ops_trace_manager import TraceQueueManager


class InvokeFrom(StrEnum):
    """
@ -275,8 +275,10 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
    start_node_id: str | None = None


# Import TraceQueueManager at runtime to resolve forward references
from core.ops.ops_trace_manager import TraceQueueManager

# Rebuild models that use forward references
AppGenerateEntity.model_rebuild()
EasyUIBasedAppGenerateEntity.model_rebuild()
ConversationAppGenerateEntity.model_rebuild()

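A sketch of the forward-reference pattern above, with illustrative names: the TYPE_CHECKING import keeps a heavy module out of import time, and model_rebuild() resolves the string annotation once the real type is loaded at runtime.

from typing import TYPE_CHECKING

from pydantic import BaseModel

if TYPE_CHECKING:
    from decimal import Decimal  # stands in for a heavy, runtime-only import


class Invoice(BaseModel):  # hypothetical model
    total: "Decimal | None" = None


# Import at runtime, then rebuild so Pydantic can resolve the "Decimal" annotation.
from decimal import Decimal  # noqa: E402

Invoice.model_rebuild()
print(Invoice(total=Decimal("9.99")))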
@ -29,7 +29,6 @@ class SimpleModelProviderEntity(BaseModel):
    provider: str
    label: I18nObject
    icon_small: I18nObject | None = None
    icon_small_dark: I18nObject | None = None
    icon_large: I18nObject | None = None
    supported_model_types: list[ModelType]

@ -43,7 +42,6 @@ class SimpleModelProviderEntity(BaseModel):
            provider=provider_entity.provider,
            label=provider_entity.label,
            icon_small=provider_entity.icon_small,
            icon_small_dark=provider_entity.icon_small_dark,
            icon_large=provider_entity.icon_large,
            supported_model_types=provider_entity.supported_model_types,
        )

@ -99,7 +99,6 @@ class SimpleProviderEntity(BaseModel):
    provider: str
    label: I18nObject
    icon_small: I18nObject | None = None
    icon_small_dark: I18nObject | None = None
    icon_large: I18nObject | None = None
    supported_model_types: Sequence[ModelType]
    models: list[AIModelEntity] = []
@ -125,6 +124,7 @@ class ProviderEntity(BaseModel):
    icon_small: I18nObject | None = None
    icon_large: I18nObject | None = None
    icon_small_dark: I18nObject | None = None
    icon_large_dark: I18nObject | None = None
    background: str | None = None
    help: ProviderHelpEntity | None = None
    supported_model_types: Sequence[ModelType]

@ -300,14 +300,6 @@ class ModelProviderFactory:
                file_name = provider_schema.icon_small.zh_Hans
            else:
                file_name = provider_schema.icon_small.en_US
        elif icon_type.lower() == "icon_small_dark":
            if not provider_schema.icon_small_dark:
                raise ValueError(f"Provider {provider} does not have small dark icon.")

            if lang.lower() == "zh_hans":
                file_name = provider_schema.icon_small_dark.zh_Hans
            else:
                file_name = provider_schema.icon_small_dark.en_US
        else:
            if not provider_schema.icon_large:
                raise ValueError(f"Provider {provider} does not have large icon.")

@ -58,39 +58,11 @@ class OceanBaseVector(BaseVector):
            password=self._config.password,
            db_name=self._config.database,
        )
        self._fields: list[str] = []  # List of fields in the collection
        if self._client.check_table_exists(collection_name):
            self._load_collection_fields()
        self._hybrid_search_enabled = self._check_hybrid_search_support()  # Check if hybrid search is supported

    def get_type(self) -> str:
        return VectorType.OCEANBASE

    def _load_collection_fields(self):
        """
        Load collection fields from the database table.
        This method populates the _fields list with column names from the table.
        """
        try:
            if self._collection_name in self._client.metadata_obj.tables:
                table = self._client.metadata_obj.tables[self._collection_name]
                # Store all column names except 'id' (primary key)
                self._fields = [column.name for column in table.columns if column.name != "id"]
                logger.debug("Loaded fields for collection '%s': %s", self._collection_name, self._fields)
            else:
                logger.warning("Collection '%s' not found in metadata", self._collection_name)
        except Exception as e:
            logger.warning("Failed to load collection fields for '%s': %s", self._collection_name, str(e))

    def field_exists(self, field: str) -> bool:
        """
        Check if a field exists in the collection.

        :param field: Field name to check
        :return: True if field exists, False otherwise
        """
        return field in self._fields

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        self._vec_dim = len(embeddings[0])
        self._create_collection()

@ -179,7 +151,6 @@ class OceanBaseVector(BaseVector):
            logger.debug("DEBUG: Hybrid search is NOT enabled for '%s'", self._collection_name)

        self._client.refresh_metadata([self._collection_name])
        self._load_collection_fields()
        redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def _check_hybrid_search_support(self) -> bool:

@ -206,134 +177,42 @@ class OceanBaseVector(BaseVector):
    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        ids = self._get_uuids(documents)
        for id, doc, emb in zip(ids, documents, embeddings):
            try:
                self._client.insert(
                    table_name=self._collection_name,
                    data={
                        "id": id,
                        "vector": emb,
                        "text": doc.page_content,
                        "metadata": doc.metadata,
                    },
                )
            except Exception as e:
                logger.exception(
                    "Failed to insert document with id '%s' in collection '%s'",
                    id,
                    self._collection_name,
                )
                raise Exception(f"Failed to insert document with id '{id}'") from e
            self._client.insert(
                table_name=self._collection_name,
                data={
                    "id": id,
                    "vector": emb,
                    "text": doc.page_content,
                    "metadata": doc.metadata,
                },
            )

    def text_exists(self, id: str) -> bool:
        try:
            cur = self._client.get(table_name=self._collection_name, ids=id)
            return bool(cur.rowcount != 0)
        except Exception as e:
            logger.exception(
                "Failed to check if text exists with id '%s' in collection '%s'",
                id,
                self._collection_name,
            )
            raise Exception(f"Failed to check text existence for id '{id}'") from e
        cur = self._client.get(table_name=self._collection_name, ids=id)
        return bool(cur.rowcount != 0)

    def delete_by_ids(self, ids: list[str]):
        if not ids:
            return
        try:
            self._client.delete(table_name=self._collection_name, ids=ids)
            logger.debug("Deleted %d documents from collection '%s'", len(ids), self._collection_name)
        except Exception as e:
            logger.exception(
                "Failed to delete %d documents from collection '%s'",
                len(ids),
                self._collection_name,
            )
            raise Exception(f"Failed to delete documents from collection '{self._collection_name}'") from e
        self._client.delete(table_name=self._collection_name, ids=ids)

    def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
        try:
            import re
            from sqlalchemy import text

        from sqlalchemy import text

            # Validate key to prevent injection in JSON path
            if not re.match(r"^[a-zA-Z0-9_.]+$", key):
                raise ValueError(f"Invalid characters in metadata key: {key}")

            # Use parameterized query to prevent SQL injection
            sql = text(f"SELECT id FROM `{self._collection_name}` WHERE metadata->>'$.{key}' = :value")

            with self._client.engine.connect() as conn:
                result = conn.execute(sql, {"value": value})
                ids = [row[0] for row in result]

            logger.debug(
                "Found %d documents with metadata field '%s'='%s' in collection '%s'",
                len(ids),
                key,
                value,
                self._collection_name,
            )
            return ids
        except Exception as e:
            logger.exception(
                "Failed to get IDs by metadata field '%s'='%s' in collection '%s'",
                key,
                value,
                self._collection_name,
            )
            raise Exception(f"Failed to query documents by metadata field '{key}'") from e
        cur = self._client.get(
            table_name=self._collection_name,
            ids=None,
            where_clause=[text(f"metadata->>'$.{key}' = '{value}'")],
            output_column_name=["id"],
        )
        return [row[0] for row in cur]

    def delete_by_metadata_field(self, key: str, value: str):
        ids = self.get_ids_by_metadata_field(key, value)
        if ids:
            self.delete_by_ids(ids)
        else:
            logger.debug("No documents found to delete with metadata field '%s'='%s'", key, value)

    def _process_search_results(
        self, results: list[tuple], score_threshold: float = 0.0, score_key: str = "score"
    ) -> list[Document]:
        """
        Common method to process search results

        :param results: Search results as list of tuples (text, metadata, score)
        :param score_threshold: Score threshold for filtering
        :param score_key: Key name for score in metadata
        :return: List of documents
        """
        docs = []
        for row in results:
            text, metadata_str, score = row[0], row[1], row[2]

            # Parse metadata JSON
            try:
                metadata = json.loads(metadata_str) if isinstance(metadata_str, str) else metadata_str
            except json.JSONDecodeError:
                logger.warning("Invalid JSON metadata: %s", metadata_str)
                metadata = {}

            # Add score to metadata
            metadata[score_key] = score

            # Filter by score threshold
            if score >= score_threshold:
                docs.append(Document(page_content=text, metadata=metadata))

        return docs
            self.delete_by_ids(ids)

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        if not self._hybrid_search_enabled:
            logger.warning(
                "Full-text search is disabled: set OCEANBASE_ENABLE_HYBRID_SEARCH=true (requires OceanBase >= 4.3.5.1)."
            )
            return []
        if not self.field_exists("text"):
            logger.warning(
                "Full-text search unavailable: collection '%s' missing 'text' field; "
                "recreate the collection after enabling OCEANBASE_ENABLE_HYBRID_SEARCH to add fulltext index.",
                self._collection_name,
            )
            return []

        try:

@ -341,24 +220,13 @@ class OceanBaseVector(BaseVector):
            if not isinstance(top_k, int) or top_k <= 0:
                raise ValueError("top_k must be a positive integer")

            score_threshold = float(kwargs.get("score_threshold") or 0.0)

            # Build parameterized query to prevent SQL injection
            from sqlalchemy import text

            document_ids_filter = kwargs.get("document_ids_filter")
            params = {"query": query}
            where_clause = ""

            if document_ids_filter:
                # Create parameterized placeholders for document IDs
                placeholders = ", ".join(f":doc_id_{i}" for i in range(len(document_ids_filter)))
                where_clause = f" AND metadata->>'$.document_id' IN ({placeholders})"
                # Add document IDs to parameters
                for i, doc_id in enumerate(document_ids_filter):
                    params[f"doc_id_{i}"] = doc_id
                document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
                where_clause = f" AND metadata->>'$.document_id' IN ({document_ids})"

            full_sql = f"""SELECT text, metadata, MATCH (text) AGAINST (:query) AS score
            full_sql = f"""SELECT metadata, text, MATCH (text) AGAINST (:query) AS score
            FROM {self._collection_name}
            WHERE MATCH (text) AGAINST (:query) > 0
            {where_clause}

@ -367,45 +235,41 @@ class OceanBaseVector(BaseVector):

            with self._client.engine.connect() as conn:
                with conn.begin():
                    result = conn.execute(text(full_sql), params)
                from sqlalchemy import text

                result = conn.execute(text(full_sql), {"query": query})
                rows = result.fetchall()

            return self._process_search_results(rows, score_threshold=score_threshold)
                docs = []
                for row in rows:
                    metadata_str, _text, score = row
                    try:
                        metadata = json.loads(metadata_str)
                    except json.JSONDecodeError:
                        logger.warning("Invalid JSON metadata: %s", metadata_str)
                        metadata = {}
                    metadata["score"] = score
                    docs.append(Document(page_content=_text, metadata=metadata))

                return docs
        except Exception as e:
            logger.exception(
                "Failed to perform full-text search on collection '%s' with query '%s'",
                self._collection_name,
                query,
            )
            raise Exception(f"Full-text search failed for collection '{self._collection_name}'") from e
            logger.warning("Failed to fulltext search: %s.", str(e))
            return []

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        from sqlalchemy import text

        document_ids_filter = kwargs.get("document_ids_filter")
        _where_clause = None
        if document_ids_filter:
            # Validate document IDs to prevent SQL injection
            # Document IDs should be alphanumeric with hyphens and underscores
            import re

            for doc_id in document_ids_filter:
                if not isinstance(doc_id, str) or not re.match(r"^[a-zA-Z0-9_-]+$", doc_id):
                    raise ValueError(f"Invalid document ID format: {doc_id}")

            # Safe to use in query after validation
            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
            where_clause = f"metadata->>'$.document_id' in ({document_ids})"
            from sqlalchemy import text

            _where_clause = [text(where_clause)]
        ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
        if ef_search != self._hnsw_ef_search:
            self._client.set_ob_hnsw_ef_search(ef_search)
            self._hnsw_ef_search = ef_search
        topk = kwargs.get("top_k", 10)
        try:
            score_threshold = float(val) if (val := kwargs.get("score_threshold")) is not None else 0.0
        except (ValueError, TypeError) as e:
            raise ValueError(f"Invalid score_threshold parameter: {e}") from e
        try:
            cur = self._client.ann_search(
                table_name=self._collection_name,

@ -418,27 +282,21 @@ class OceanBaseVector(BaseVector):
                where_clause=_where_clause,
            )
        except Exception as e:
            logger.exception(
                "Failed to perform vector search on collection '%s'",
                self._collection_name,
            raise Exception("Failed to search by vector. ", e)
        docs = []
        for _text, metadata, distance in cur:
            metadata = json.loads(metadata)
            metadata["score"] = 1 - distance / math.sqrt(2)
            docs.append(
                Document(
                    page_content=_text,
                    metadata=metadata,
                )
            )
            raise Exception(f"Vector search failed for collection '{self._collection_name}'") from e

        # Convert distance to score and prepare results for processing
        results = []
        for _text, metadata_str, distance in cur:
            score = 1 - distance / math.sqrt(2)
            results.append((_text, metadata_str, score))

        return self._process_search_results(results, score_threshold=score_threshold)
        return docs

    def delete(self):
        try:
            self._client.drop_table_if_exist(self._collection_name)
            logger.debug("Dropped collection '%s'", self._collection_name)
        except Exception as e:
            logger.exception("Failed to delete collection '%s'", self._collection_name)
            raise Exception(f"Failed to delete collection '{self._collection_name}'") from e
        self._client.drop_table_if_exist(self._collection_name)


class OceanBaseVectorFactory(AbstractVectorFactory):

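A hedged sketch of the parameterized-query pattern from the hunks above, assuming a MySQL-compatible engine (OceanBase speaks the MySQL dialect): the metadata key is structurally validated because it becomes part of the JSON path, while the value travels as a bind parameter and needs no escaping.

import re

from sqlalchemy import Engine, text


def ids_by_metadata(engine: Engine, table: str, key: str, value: str) -> list[str]:
    # Keys are interpolated into the JSON path, so restrict them to a safe charset.
    if not re.match(r"^[a-zA-Z0-9_.]+$", key):
        raise ValueError(f"Invalid characters in metadata key: {key}")
    # The value is bound as :value, never interpolated into the SQL string.
    sql = text(f"SELECT id FROM `{table}` WHERE metadata->>'$.{key}' = :value")
    with engine.connect() as conn:
        return [row[0] for row in conn.execute(sql, {"value": value})]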
@ -54,8 +54,6 @@ class ToolProviderApiEntity(BaseModel):
    configuration: MCPConfiguration | None = Field(
        default=None, description="The timeout and sse_read_timeout of the MCP tool"
    )
    # Workflow
    workflow_app_id: str | None = Field(default=None, description="The app id of the workflow tool")

    @field_validator("tools", mode="before")
    @classmethod
@ -89,8 +87,6 @@ class ToolProviderApiEntity(BaseModel):
            optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
            optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
            optional_fields.update(self.optional_field("original_headers", self.original_headers))
        elif self.type == ToolProviderType.WORKFLOW:
            optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id))
        return {
            "id": self.id,
            "author": self.author,

@ -203,7 +203,7 @@ class WorkflowTool(Tool):
        Resolve user object in both HTTP and worker contexts.

        In HTTP context: dereference the current_user LocalProxy (can return Account or EndUser).
        In worker context: load Account(knowledge pipeline) or EndUser(trigger) from database by user_id.
        In worker context: load Account from database by user_id (only returns Account, never EndUser).

        Returns:
            Account | EndUser | None: The resolved user object, or None if resolution fails.
@ -224,28 +224,24 @@ class WorkflowTool(Tool):
            logger.warning("Failed to resolve user from request context: %s", e)
            return None

    def _resolve_user_from_database(self, user_id: str) -> Account | EndUser | None:
    def _resolve_user_from_database(self, user_id: str) -> Account | None:
        """
        Resolve user from database (worker/Celery context).
        """

        user_stmt = select(Account).where(Account.id == user_id)
        user = db.session.scalar(user_stmt)
        if not user:
            return None

        tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
        tenant = db.session.scalar(tenant_stmt)
        if not tenant:
            return None

        user_stmt = select(Account).where(Account.id == user_id)
        user = db.session.scalar(user_stmt)
        if user:
            user.current_tenant = tenant
            return user
        user.current_tenant = tenant

        end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
        end_user = db.session.scalar(end_user_stmt)
        if end_user:
            return end_user

        return None
        return user

    def _get_workflow(self, app_id: str, version: str) -> Workflow:
        """

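An illustrative sketch (not the app's actual code) of the two resolution paths the docstring above contrasts: proxy dereference in an HTTP context, database lookup in a worker.

from flask import Flask, g
from werkzeug.local import LocalProxy

app = Flask(__name__)
current_user = LocalProxy(lambda: g.user)  # hypothetical request-scoped proxy


def load_account_from_db(user_id: str):
    return {"id": user_id}  # placeholder for a real SQLAlchemy lookup


def resolve_user(user_id: str):
    try:
        # _get_current_object() unwraps the proxy to the concrete user object.
        return current_user._get_current_object()
    except (RuntimeError, AttributeError):
        # No request context (e.g. a Celery worker): fall back to the database.
        return load_account_from_db(user_id)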
@ -1,11 +1,7 @@
import importlib
import logging
import operator
import pkgutil
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from functools import singledispatchmethod
from types import MappingProxyType
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
from uuid import uuid4

@ -138,34 +134,6 @@ class Node(Generic[NodeDataT]):

        cls._node_data_type = node_data_type

        # Skip base class itself
        if cls is Node:
            return
        # Only register production node implementations defined under core.workflow.nodes.*
        # This prevents test helper subclasses from polluting the global registry and
        # accidentally overriding real node types (e.g., a test Answer node).
        module_name = getattr(cls, "__module__", "")
        # Only register concrete subclasses that define node_type and version()
        node_type = cls.node_type
        version = cls.version()
        bucket = Node._registry.setdefault(node_type, {})
        if module_name.startswith("core.workflow.nodes."):
            # Production node definitions take precedence and may override
            bucket[version] = cls  # type: ignore[index]
        else:
            # External/test subclasses may register but must not override production
            bucket.setdefault(version, cls)  # type: ignore[index]
        # Maintain a "latest" pointer preferring numeric versions; fallback to lexicographic
        version_keys = [v for v in bucket if v != "latest"]
        numeric_pairs: list[tuple[str, int]] = []
        for v in version_keys:
            numeric_pairs.append((v, int(v)))
        if numeric_pairs:
            latest_key = max(numeric_pairs, key=operator.itemgetter(1))[0]
        else:
            latest_key = max(version_keys) if version_keys else version
        bucket["latest"] = bucket[latest_key]

    @classmethod
    def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
        """

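A minimal sketch of the __init_subclass__ registry removed in the hunk above, with illustrative class names and string versions:

import operator
from typing import ClassVar


class Node:
    node_type: ClassVar[str] = ""
    _registry: ClassVar[dict[str, dict[str, type["Node"]]]] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if cls is Node or not cls.node_type:
            return  # skip the base class and abstract helpers
        bucket = Node._registry.setdefault(cls.node_type, {})
        bucket[cls.version()] = cls
        # Maintain a "latest" pointer preferring numeric versions.
        keys = [v for v in bucket if v != "latest"]
        numeric = [(v, int(v)) for v in keys if v.isdigit()]
        latest = max(numeric, key=operator.itemgetter(1))[0] if numeric else max(keys)
        bucket["latest"] = bucket[latest]

    @classmethod
    def version(cls) -> str:
        return "1"


class LLMNodeV1(Node):
    node_type = "llm"


class LLMNodeV2(Node):
    node_type = "llm"

    @classmethod
    def version(cls) -> str:
        return "2"


assert Node._registry["llm"]["latest"] is LLMNodeV2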
@ -197,9 +165,6 @@ class Node(Generic[NodeDataT]):

        return None

    # Global registry populated via __init_subclass__
    _registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}

    def __init__(
        self,
        id: str,

@ -275,23 +240,23 @@ class Node(Generic[NodeDataT]):
        from core.workflow.nodes.tool.tool_node import ToolNode

        if isinstance(self, ToolNode):
            start_event.provider_id = getattr(self.node_data, "provider_id", "")
            start_event.provider_type = getattr(self.node_data, "provider_type", "")
            start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "")
            start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")

        from core.workflow.nodes.datasource.datasource_node import DatasourceNode

        if isinstance(self, DatasourceNode):
            plugin_id = getattr(self.node_data, "plugin_id", "")
            provider_name = getattr(self.node_data, "provider_name", "")
            plugin_id = getattr(self.get_base_node_data(), "plugin_id", "")
            provider_name = getattr(self.get_base_node_data(), "provider_name", "")

            start_event.provider_id = f"{plugin_id}/{provider_name}"
            start_event.provider_type = getattr(self.node_data, "provider_type", "")
            start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")

        from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode

        if isinstance(self, TriggerEventNode):
            start_event.provider_id = getattr(self.node_data, "provider_id", "")
            start_event.provider_type = getattr(self.node_data, "provider_type", "")
            start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "")
            start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "")

        from typing import cast

@ -300,7 +265,7 @@ class Node(Generic[NodeDataT]):

        if isinstance(self, AgentNode):
            start_event.agent_strategy = AgentNodeStrategyInit(
                name=cast(AgentNodeData, self.node_data).agent_strategy_name,
                name=cast(AgentNodeData, self.get_base_node_data()).agent_strategy_name,
                icon=self.agent_strategy_icon,
            )

@ -430,29 +395,6 @@ class Node(Generic[NodeDataT]):
        # in `api/core/workflow/nodes/__init__.py`.
        raise NotImplementedError("subclasses of BaseNode must implement `version` method.")

    @classmethod
    def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]:
        """Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.

        Import all modules under core.workflow.nodes so subclasses register themselves on import.
        Then we return a readonly view of the registry to avoid accidental mutation.
        """
        # Import all node modules to ensure they are loaded (thus registered)
        import core.workflow.nodes as _nodes_pkg

        for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."):
            # Avoid importing modules that depend on the registry to prevent circular imports
            # e.g. node_factory imports node_mapping which builds the mapping here.
            if _modname in {
                "core.workflow.nodes.node_factory",
                "core.workflow.nodes.node_mapping",
            }:
                continue
            importlib.import_module(_modname)

        # Return a readonly view so callers can't mutate the registry by accident
        return {nt: MappingProxyType(ver_map) for nt, ver_map in cls._registry.items()}

    @property
    def retry(self) -> bool:
        return False

@ -477,6 +419,10 @@ class Node(Generic[NodeDataT]):
        """Get the default values dictionary for this node."""
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        """Get the BaseNodeData object for this node."""
        return self._node_data

    # Public interface properties that delegate to abstract methods
    @property
    def error_strategy(self) -> ErrorStrategy | None:

@ -602,7 +548,7 @@ class Node(Generic[NodeDataT]):
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.node_data.title,
            node_title=self.get_base_node_data().title,
            start_at=event.start_at,
            inputs=event.inputs,
            metadata=event.metadata,
@ -615,7 +561,7 @@ class Node(Generic[NodeDataT]):
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.node_data.title,
            node_title=self.get_base_node_data().title,
            index=event.index,
            pre_loop_output=event.pre_loop_output,
        )
@ -626,7 +572,7 @@ class Node(Generic[NodeDataT]):
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.node_data.title,
            node_title=self.get_base_node_data().title,
            start_at=event.start_at,
            inputs=event.inputs,
            outputs=event.outputs,
@ -640,7 +586,7 @@ class Node(Generic[NodeDataT]):
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.node_data.title,
            node_title=self.get_base_node_data().title,
            start_at=event.start_at,
            inputs=event.inputs,
            outputs=event.outputs,
@ -655,7 +601,7 @@ class Node(Generic[NodeDataT]):
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.node_data.title,
            node_title=self.get_base_node_data().title,
            start_at=event.start_at,
            inputs=event.inputs,
            metadata=event.metadata,
@ -668,7 +614,7 @@ class Node(Generic[NodeDataT]):
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.node_data.title,
            node_title=self.get_base_node_data().title,
            index=event.index,
            pre_iteration_output=event.pre_iteration_output,
        )
@ -679,7 +625,7 @@ class Node(Generic[NodeDataT]):
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.node_data.title,
            node_title=self.get_base_node_data().title,
            start_at=event.start_at,
            inputs=event.inputs,
            outputs=event.outputs,
@ -693,7 +639,7 @@ class Node(Generic[NodeDataT]):
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_title=self.node_data.title,
            node_title=self.get_base_node_data().title,
            start_at=event.start_at,
            inputs=event.inputs,
            outputs=event.outputs,

@ -1,9 +1,165 @@
from collections.abc import Mapping

from core.workflow.enums import NodeType
from core.workflow.nodes.agent.agent_node import AgentNode
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code import CodeNode
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request import HttpRequestNode
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.if_else import IfElseNode
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
from core.workflow.nodes.list_operator import ListOperatorNode
from core.workflow.nodes.llm import LLMNode
from core.workflow.nodes.loop import LoopEndNode, LoopNode, LoopStartNode
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from core.workflow.nodes.question_classifier import QuestionClassifierNode
from core.workflow.nodes.start import StartNode
from core.workflow.nodes.template_transform import TemplateTransformNode
from core.workflow.nodes.tool import ToolNode
from core.workflow.nodes.trigger_plugin import TriggerEventNode
from core.workflow.nodes.trigger_schedule import TriggerScheduleNode
from core.workflow.nodes.trigger_webhook import TriggerWebhookNode
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2

LATEST_VERSION = "latest"

# Mapping is built by Node.get_node_type_classes_mapping(), which imports and walks core.workflow.nodes
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode.
# Specifically, if you have introduced new node types, you should add them here.
#
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
# hook. Try to avoid duplication of node information.
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
    NodeType.START: {
        LATEST_VERSION: StartNode,
        "1": StartNode,
    },
    NodeType.END: {
        LATEST_VERSION: EndNode,
        "1": EndNode,
    },
    NodeType.ANSWER: {
        LATEST_VERSION: AnswerNode,
        "1": AnswerNode,
    },
    NodeType.LLM: {
        LATEST_VERSION: LLMNode,
        "1": LLMNode,
    },
    NodeType.KNOWLEDGE_RETRIEVAL: {
        LATEST_VERSION: KnowledgeRetrievalNode,
        "1": KnowledgeRetrievalNode,
    },
    NodeType.IF_ELSE: {
        LATEST_VERSION: IfElseNode,
        "1": IfElseNode,
    },
    NodeType.CODE: {
        LATEST_VERSION: CodeNode,
        "1": CodeNode,
    },
    NodeType.TEMPLATE_TRANSFORM: {
        LATEST_VERSION: TemplateTransformNode,
        "1": TemplateTransformNode,
    },
    NodeType.QUESTION_CLASSIFIER: {
        LATEST_VERSION: QuestionClassifierNode,
        "1": QuestionClassifierNode,
    },
    NodeType.HTTP_REQUEST: {
        LATEST_VERSION: HttpRequestNode,
        "1": HttpRequestNode,
    },
    NodeType.TOOL: {
        LATEST_VERSION: ToolNode,
        # This is an issue that caused problems before.
        # Logically, we shouldn't use two different versions to point to the same class here,
        # but in order to maintain compatibility with historical data, this approach has been retained.
        "2": ToolNode,
        "1": ToolNode,
    },
    NodeType.VARIABLE_AGGREGATOR: {
        LATEST_VERSION: VariableAggregatorNode,
        "1": VariableAggregatorNode,
    },
    NodeType.LEGACY_VARIABLE_AGGREGATOR: {
        LATEST_VERSION: VariableAggregatorNode,
        "1": VariableAggregatorNode,
    },  # original name of VARIABLE_AGGREGATOR
    NodeType.ITERATION: {
        LATEST_VERSION: IterationNode,
        "1": IterationNode,
    },
    NodeType.ITERATION_START: {
        LATEST_VERSION: IterationStartNode,
        "1": IterationStartNode,
    },
    NodeType.LOOP: {
        LATEST_VERSION: LoopNode,
        "1": LoopNode,
    },
    NodeType.LOOP_START: {
        LATEST_VERSION: LoopStartNode,
        "1": LoopStartNode,
    },
    NodeType.LOOP_END: {
        LATEST_VERSION: LoopEndNode,
        "1": LoopEndNode,
    },
    NodeType.PARAMETER_EXTRACTOR: {
        LATEST_VERSION: ParameterExtractorNode,
        "1": ParameterExtractorNode,
    },
    NodeType.VARIABLE_ASSIGNER: {
        LATEST_VERSION: VariableAssignerNodeV2,
        "1": VariableAssignerNodeV1,
        "2": VariableAssignerNodeV2,
    },
    NodeType.DOCUMENT_EXTRACTOR: {
        LATEST_VERSION: DocumentExtractorNode,
        "1": DocumentExtractorNode,
    },
    NodeType.LIST_OPERATOR: {
        LATEST_VERSION: ListOperatorNode,
        "1": ListOperatorNode,
    },
    NodeType.AGENT: {
        LATEST_VERSION: AgentNode,
        # This is an issue that caused problems before.
        # Logically, we shouldn't use two different versions to point to the same class here,
        # but in order to maintain compatibility with historical data, this approach has been retained.
        "2": AgentNode,
        "1": AgentNode,
    },
    NodeType.HUMAN_INPUT: {
        LATEST_VERSION: HumanInputNode,
        "1": HumanInputNode,
    },
    NodeType.DATASOURCE: {
        LATEST_VERSION: DatasourceNode,
        "1": DatasourceNode,
    },
    NodeType.KNOWLEDGE_INDEX: {
        LATEST_VERSION: KnowledgeIndexNode,
        "1": KnowledgeIndexNode,
    },
    NodeType.TRIGGER_WEBHOOK: {
        LATEST_VERSION: TriggerWebhookNode,
        "1": TriggerWebhookNode,
    },
    NodeType.TRIGGER_PLUGIN: {
        LATEST_VERSION: TriggerEventNode,
        "1": TriggerEventNode,
    },
    NodeType.TRIGGER_SCHEDULE: {
        LATEST_VERSION: TriggerScheduleNode,
        "1": TriggerScheduleNode,
    },
}

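A short hypothetical lookup against a mapping shaped like NODE_TYPE_CLASSES_MAPPING: prefer an exact version, then fall back to the "latest" pointer.

def resolve_node_class(mapping, node_type, version: str | None):
    versions = mapping[node_type]
    # Unknown or missing versions resolve to the "latest" entry.
    return versions.get(version or "latest", versions["latest"])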
@ -12,6 +12,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.tools.workflow_as_tool.tool import WorkflowTool
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.enums import (
@ -429,7 +430,7 @@ class ToolNode(Node[ToolNodeData]):
        metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
            WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
        }
        if isinstance(usage.total_tokens, int) and usage.total_tokens > 0:
        if usage.total_tokens > 0:
            metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
            metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
            metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
@ -448,17 +449,8 @@ class ToolNode(Node[ToolNodeData]):

    @staticmethod
    def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
        # Avoid importing WorkflowTool at module import time; rely on duck typing
        # Some runtimes expose `latest_usage`; mocks may synthesize arbitrary attributes.
        latest = getattr(tool_runtime, "latest_usage", None)
        # Normalize into a concrete LLMUsage. MagicMock returns truthy attribute objects
        # for any name, so we must type-check here.
        if isinstance(latest, LLMUsage):
            return latest
        if isinstance(latest, dict):
            # Allow dict payloads from external runtimes
            return LLMUsage.model_validate(latest)
        # Fallback to empty usage when attribute is missing or not a valid payload
        if isinstance(tool_runtime, WorkflowTool):
            return tool_runtime.latest_usage
        return LLMUsage.empty_usage()

    @classmethod

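A sketch of the duck-typed normalization on the removed side of the hunk above, assuming a minimal Pydantic LLMUsage model: `latest_usage` may be absent, a model instance, or a plain dict depending on the runtime (or a MagicMock in tests), so it is type-checked before use.

from pydantic import BaseModel


class LLMUsage(BaseModel):  # minimal stand-in for the real entity
    total_tokens: int = 0

    @classmethod
    def empty_usage(cls) -> "LLMUsage":
        return cls()


def extract_usage(runtime: object) -> LLMUsage:
    latest = getattr(runtime, "latest_usage", None)
    if isinstance(latest, LLMUsage):
        return latest
    if isinstance(latest, dict):
        return LLMUsage.model_validate(latest)  # allow dict payloads
    return LLMUsage.empty_usage()  # missing or unrecognized payload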
@ -6,7 +6,6 @@ BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEAD
SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization")
AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN)
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")


def init_app(app: DifyApp):
@ -26,7 +25,6 @@ def init_app(app: DifyApp):
        service_api_bp,
        allow_headers=list(SERVICE_API_HEADERS),
        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
        expose_headers=list(EXPOSED_HEADERS),
    )
    app.register_blueprint(service_api_bp)

@ -36,7 +34,7 @@ def init_app(app: DifyApp):
        supports_credentials=True,
        allow_headers=list(AUTHENTICATED_HEADERS),
        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
        expose_headers=list(EXPOSED_HEADERS),
        expose_headers=["X-Version", "X-Env"],
    )
    app.register_blueprint(web_bp)

@ -46,7 +44,7 @@ def init_app(app: DifyApp):
        supports_credentials=True,
        allow_headers=list(AUTHENTICATED_HEADERS),
        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
        expose_headers=list(EXPOSED_HEADERS),
        expose_headers=["X-Version", "X-Env"],
    )
    app.register_blueprint(console_app_bp)

@ -54,7 +52,6 @@ def init_app(app: DifyApp):
        files_bp,
        allow_headers=list(FILES_HEADERS),
        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
        expose_headers=list(EXPOSED_HEADERS),
    )
    app.register_blueprint(files_bp)

@ -66,6 +63,5 @@ def init_app(app: DifyApp):
        trigger_bp,
        allow_headers=["Content-Type", "Authorization", "X-App-Code"],
        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH", "HEAD"],
        expose_headers=list(EXPOSED_HEADERS),
    )
    app.register_blueprint(trigger_bp)

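A minimal flask-cors sketch of the shared-constant pattern above; the blueprint name is illustrative. expose_headers lets browsers read these response headers cross-origin, and centralizing the tuple keeps every blueprint's CORS policy consistent.

from flask import Blueprint, Flask
from flask_cors import CORS

EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")

app = Flask(__name__)
api_bp = Blueprint("api", __name__, url_prefix="/api")  # hypothetical blueprint

CORS(api_bp, expose_headers=list(EXPOSED_HEADERS), methods=["GET", "POST", "OPTIONS"])
app.register_blueprint(api_bp)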
@ -1,49 +0,0 @@
import logging

from dify_app import DifyApp


def is_enabled() -> bool:
    return True


def init_app(app: DifyApp):
    """Resolve Pydantic forward refs that would otherwise cause circular imports.

    Rebuilds models in core.app.entities.app_invoke_entities with the real TraceQueueManager type.
    Safe to run multiple times.
    """
    logger = logging.getLogger(__name__)
    try:
        from core.app.entities.app_invoke_entities import (
            AdvancedChatAppGenerateEntity,
            AgentChatAppGenerateEntity,
            AppGenerateEntity,
            ChatAppGenerateEntity,
            CompletionAppGenerateEntity,
            ConversationAppGenerateEntity,
            EasyUIBasedAppGenerateEntity,
            RagPipelineGenerateEntity,
            WorkflowAppGenerateEntity,
        )
        from core.ops.ops_trace_manager import TraceQueueManager  # heavy import, do it at startup only

        ns = {"TraceQueueManager": TraceQueueManager}
        for Model in (
            AppGenerateEntity,
            EasyUIBasedAppGenerateEntity,
            ConversationAppGenerateEntity,
            ChatAppGenerateEntity,
            CompletionAppGenerateEntity,
            AgentChatAppGenerateEntity,
            AdvancedChatAppGenerateEntity,
            WorkflowAppGenerateEntity,
            RagPipelineGenerateEntity,
        ):
            try:
                Model.model_rebuild(_types_namespace=ns)
            except Exception as e:
                logger.debug("model_rebuild skipped for %s: %s", Model.__name__, e)
    except Exception as e:
        # Don't block app startup; just log at debug level.
        logger.debug("ext_forward_refs init skipped: %s", e)

@ -7,7 +7,6 @@ from logging.handlers import RotatingFileHandler
import flask

from configs import dify_config
from core.helper.trace_id_helper import get_trace_id_from_otel_context
from dify_app import DifyApp


@ -77,9 +76,7 @@ class RequestIdFilter(logging.Filter):
    # the logging format. Note that we're checking if we're in a request
    # context, as we may want to log things before Flask is fully loaded.
    def filter(self, record):
        trace_id = get_trace_id_from_otel_context() or ""
        record.req_id = get_request_id() if flask.has_request_context() else ""
        record.trace_id = trace_id
        return True


@ -87,8 +84,6 @@ class RequestIdFormatter(logging.Formatter):
    def format(self, record):
        if not hasattr(record, "req_id"):
            record.req_id = ""
        if not hasattr(record, "trace_id"):
            record.trace_id = ""
        return super().format(record)

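A sketch of the record-enrichment pattern above: a logging.Filter stamps every record with a trace id so the format string can reference %(trace_id)s. get_trace_id() is a hypothetical stand-in for the OpenTelemetry context helper.

import logging


def get_trace_id() -> str:
    return "0" * 32  # placeholder; real code reads the active span context


class TraceIdFilter(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        record.trace_id = get_trace_id() or ""
        return True  # never drop records; only annotate them


handler = logging.StreamHandler()
handler.addFilter(TraceIdFilter())
handler.setFormatter(logging.Formatter("%(levelname)s %(trace_id)s - %(message)s"))
logging.getLogger("demo").addHandler(handler)
logging.getLogger("demo").warning("hello")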
@ -1,14 +1,12 @@
import json
import logging
import time

import flask
import werkzeug.http
from flask import Flask, g
from flask import Flask
from flask.signals import request_finished, request_started

from configs import dify_config
from core.helper.trace_id_helper import get_trace_id_from_otel_context

logger = logging.getLogger(__name__)

@ -22,9 +20,6 @@ def _is_content_type_json(content_type: str) -> bool:

def _log_request_started(_sender, **_extra):
    """Log the start of a request."""
    # Record start time for access logging
    g.__request_started_ts = time.perf_counter()

    if not logger.isEnabledFor(logging.DEBUG):
        return

@ -47,39 +42,8 @@ def _log_request_started(_sender, **_extra):


def _log_request_finished(_sender, response, **_extra):
    """Log the end of a request.

    Safe to call with or without an active Flask request context.
    """
    if response is None:
        return

    # Always emit a compact access line at INFO with trace_id so it can be grepped
    has_ctx = flask.has_request_context()
    start_ts = getattr(g, "__request_started_ts", None) if has_ctx else None
    duration_ms = None
    if start_ts is not None:
        duration_ms = round((time.perf_counter() - start_ts) * 1000, 3)

    # Request attributes are available only when a request context exists
    if has_ctx:
        req_method = flask.request.method
        req_path = flask.request.path
    else:
        req_method = "-"
        req_path = "-"

    trace_id = get_trace_id_from_otel_context() or response.headers.get("X-Trace-Id") or ""
    logger.info(
        "%s %s %s %s %s",
        req_method,
        req_path,
        getattr(response, "status_code", "-"),
        duration_ms if duration_ms is not None else "-",
        trace_id,
    )

    if not logger.isEnabledFor(logging.DEBUG):
    """Log the end of a request."""
    if not logger.isEnabledFor(logging.DEBUG) or response is None:
        return

    if not _is_content_type_json(response.content_type):

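A sketch of the access-log timing on the removed side of the hunk above: the request start time is stashed on flask.g by a request_started receiver, and the duration is computed when request_finished fires.

import logging
import time

from flask import Flask, g, request
from flask.signals import request_finished, request_started

logger = logging.getLogger(__name__)
app = Flask(__name__)


def _on_started(_sender, **_extra):
    g.__request_started_ts = time.perf_counter()


def _on_finished(_sender, response, **_extra):
    start = getattr(g, "__request_started_ts", None)
    duration_ms = round((time.perf_counter() - start) * 1000, 3) if start else "-"
    logger.info("%s %s -> %s in %s ms", request.method, request.path, response.status_code, duration_ms)


request_started.connect(_on_started, app)
request_finished.connect(_on_finished, app)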
@ -19,7 +19,7 @@ class StringUUID(TypeDecorator[uuid.UUID | str | None]):
    def process_bind_param(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None:
        if value is None:
            return value
        elif dialect.name in ["postgresql", "mysql"]:
        elif dialect.name == "postgresql":
            return str(value)
        else:
            if isinstance(value, uuid.UUID):

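A hedged sketch of a dialect-aware UUID TypeDecorator like the one in the hunk above (simplified; the real class also handles result processing):

import uuid

from sqlalchemy import CHAR
from sqlalchemy.engine import Dialect
from sqlalchemy.types import TypeDecorator


class StringUUID(TypeDecorator):
    impl = CHAR(36)
    cache_ok = True

    def process_bind_param(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None:
        if value is None:
            return None
        if dialect.name in ("postgresql", "mysql"):
            # Both dialects accept the canonical string form directly.
            return str(value)
        # Fallback: normalize to hex for storage-agnostic backends.
        return value.hex if isinstance(value, uuid.UUID) else uuid.UUID(value).hex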
@ -111,7 +111,7 @@ package = false
dev = [
    "coverage~=7.2.4",
    "dotenv-linter~=0.5.0",
    "faker~=38.2.0",
    "faker~=32.1.0",
    "lxml-stubs~=0.5.1",
    "ty~=0.0.1a19",
    "basedpyright~=1.31.0",

@ -10,7 +10,6 @@ from collections.abc import Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
import sqlalchemy as sa
|
||||
from redis.exceptions import LockNotOwnedError
|
||||
from sqlalchemy import exists, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
@ -1594,176 +1593,173 @@ class DocumentService:
db.session.add(dataset_process_rule)
db.session.flush()
lock_name = f"add_document_lock_dataset_id_{dataset.id}"
try:
with redis_client.lock(lock_name, timeout=600):
assert dataset_process_rule
position = DocumentService.get_documents_position(dataset.id)
document_ids = []
duplicate_document_ids = []
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
if not knowledge_config.data_source.info_list.file_info_list:
raise ValueError("File source info is required")
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
with redis_client.lock(lock_name, timeout=600):
assert dataset_process_rule
position = DocumentService.get_documents_position(dataset.id)
document_ids = []
duplicate_document_ids = []
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
if not knowledge_config.data_source.info_list.file_info_list:
raise ValueError("File source info is required")
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
)

# raise error if file not found
if not file:
raise FileNotExistsError()

file_name = file.name
data_source_info: dict[str, str | bool] = {
"upload_file_id": file_id,
}
# check duplicate
if knowledge_config.duplicate:
document = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="upload_file",
enabled=True,
name=file_name,
)
.first()
)

# raise error if file not found
if not file:
raise FileNotExistsError()

file_name = file.name
data_source_info: dict[str, str | bool] = {
"upload_file_id": file_id,
}
# check duplicate
if knowledge_config.duplicate:
document = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="upload_file",
enabled=True,
name=file_name,
)
.first()
)
if document:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
file_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list  # type: ignore
if not notion_info_list:
raise ValueError("No notion info list found.")
exist_page_ids = []
exist_document = {}
documents = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
)
.all()
if document:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
file_name,
batch,
)
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
exist_document[data_source_info["notion_page_id"]] = document.id
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
for page in notion_info.pages:
if page.page_id not in exist_page_ids:
data_source_info = {
"credential_id": notion_info.credential_id,
"notion_workspace_id": workspace_id,
"notion_page_id": page.page_id,
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,  # type: ignore
"type": page.type,
}
# Truncate page name to 255 characters to prevent DB field length errors
truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
truncated_page_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
else:
exist_document.pop(page.page_id)
# delete not selected documents
if len(exist_document) > 0:
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
website_info = knowledge_config.data_source.info_list.website_info_list
if not website_info:
raise ValueError("No website info list found.")
urls = website_info.urls
for url in urls:
data_source_info = {
"url": url,
"provider": website_info.provider,
"job_id": website_info.job_id,
"only_main_content": website_info.only_main_content,
"mode": "crawl",
}
if len(url) > 255:
document_name = url[:200] + "..."
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list  # type: ignore
if not notion_info_list:
raise ValueError("No notion info list found.")
exist_page_ids = []
exist_document = {}
documents = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
)
.all()
)
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
exist_document[data_source_info["notion_page_id"]] = document.id
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
for page in notion_info.pages:
if page.page_id not in exist_page_ids:
data_source_info = {
"credential_id": notion_info.credential_id,
"notion_workspace_id": workspace_id,
"notion_page_id": page.page_id,
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,  # type: ignore
"type": page.type,
}
# Truncate page name to 255 characters to prevent DB field length errors
truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
truncated_page_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
else:
document_name = url
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
document_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
db.session.commit()
exist_document.pop(page.page_id)
# delete not selected documents
if len(exist_document) > 0:
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
website_info = knowledge_config.data_source.info_list.website_info_list
if not website_info:
raise ValueError("No website info list found.")
urls = website_info.urls
for url in urls:
data_source_info = {
"url": url,
"provider": website_info.provider,
"job_id": website_info.job_id,
"only_main_content": website_info.only_main_content,
"mode": "crawl",
}
if len(url) > 255:
document_name = url[:200] + "..."
else:
document_name = url
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
document_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
db.session.commit()

# trigger async task
if document_ids:
DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay()
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
except LockNotOwnedError:
pass
# trigger async task
if document_ids:
DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay()
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)

return documents, batch
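The hunk above removes a try/except wrapper around a Redis-based lock. As a standalone sketch, the locking idiom in question looks like the following (redis-py API; the function and callback names are illustrative):

from redis.exceptions import LockNotOwnedError


def run_under_lock(redis_client, dataset_id: str, work) -> None:
    lock_name = f"add_document_lock_dataset_id_{dataset_id}"
    try:
        # timeout=600 lets the lock auto-expire if the holder dies mid-task
        with redis_client.lock(lock_name, timeout=600):
            work()
    except LockNotOwnedError:
        # Raised on release when the lock expired and was re-acquired elsewhere;
        # the removed code swallowed it, the new code drops the wrapper entirely.
        pass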
@ -2703,55 +2699,50 @@ class SegmentService:
# calc embedding use tokens
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
lock_name = f"add_segment_lock_document_id_{document.id}"
try:
with redis_client.lock(lock_name, timeout=600):
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document.id)
.scalar()
)
segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
status="completed",
indexing_at=naive_utc_now(),
completed_at=naive_utc_now(),
created_by=current_user.id,
)
if document.doc_form == "qa_model":
segment_document.word_count += len(args["answer"])
segment_document.answer = args["answer"]
with redis_client.lock(lock_name, timeout=600):
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document.id)
.scalar()
)
segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
status="completed",
indexing_at=naive_utc_now(),
completed_at=naive_utc_now(),
created_by=current_user.id,
)
if document.doc_form == "qa_model":
segment_document.word_count += len(args["answer"])
segment_document.answer = args["answer"]

db.session.add(segment_document)
# update document word count
assert document.word_count is not None
document.word_count += segment_document.word_count
db.session.add(document)
db.session.add(segment_document)
# update document word count
assert document.word_count is not None
document.word_count += segment_document.word_count
db.session.add(document)
db.session.commit()

# save vector index
try:
VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset, document.doc_form)
except Exception as e:
logger.exception("create segment index failed")
segment_document.enabled = False
segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
segment_document.error = str(e)
db.session.commit()

# save vector index
try:
VectorService.create_segments_vector(
[args["keywords"]], [segment_document], dataset, document.doc_form
)
except Exception as e:
logger.exception("create segment index failed")
segment_document.enabled = False
segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
segment_document.error = str(e)
db.session.commit()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
return segment
except LockNotOwnedError:
pass
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
return segment

@classmethod
def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
@ -2760,89 +2751,84 @@ class SegmentService:

lock_name = f"multi_add_segment_lock_document_id_{document.id}"
increment_word_count = 0
try:
with redis_client.lock(lock_name, timeout=600):
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document.id)
.scalar()
with redis_client.lock(lock_name, timeout=600):
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
pre_segment_data_list = []
segment_data_list = []
keywords_list = []
position = max_position + 1 if max_position else 1
for segment_item in segments:
content = segment_item["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == "high_quality" and embedding_model:
# calc embedding use tokens
if document.doc_form == "qa_model":
tokens = embedding_model.get_text_embedding_num_tokens(
texts=[content + segment_item["answer"]]
)[0]
else:
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]

segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=position,
content=content,
word_count=len(content),
tokens=tokens,
keywords=segment_item.get("keywords", []),
status="completed",
indexing_at=naive_utc_now(),
completed_at=naive_utc_now(),
created_by=current_user.id,
)
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document.id)
.scalar()
)
pre_segment_data_list = []
segment_data_list = []
keywords_list = []
position = max_position + 1 if max_position else 1
for segment_item in segments:
content = segment_item["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
tokens = 0
if dataset.indexing_technique == "high_quality" and embedding_model:
# calc embedding use tokens
if document.doc_form == "qa_model":
segment_document.answer = segment_item["answer"]
segment_document.word_count += len(segment_item["answer"])
increment_word_count += segment_document.word_count
db.session.add(segment_document)
segment_data_list.append(segment_document)
position += 1

pre_segment_data_list.append(segment_document)
if "keywords" in segment_item:
keywords_list.append(segment_item["keywords"])
tokens = embedding_model.get_text_embedding_num_tokens(
texts=[content + segment_item["answer"]]
)[0]
else:
keywords_list.append(None)
# update document word count
assert document.word_count is not None
document.word_count += increment_word_count
db.session.add(document)
try:
# save vector index
VectorService.create_segments_vector(
keywords_list, pre_segment_data_list, dataset, document.doc_form
)
except Exception as e:
logger.exception("create segment index failed")
for segment_document in segment_data_list:
segment_document.enabled = False
segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
segment_document.error = str(e)
db.session.commit()
return segment_data_list
except LockNotOwnedError:
pass
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]

segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
dataset_id=document.dataset_id,
document_id=document.id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=position,
content=content,
word_count=len(content),
tokens=tokens,
keywords=segment_item.get("keywords", []),
status="completed",
indexing_at=naive_utc_now(),
completed_at=naive_utc_now(),
created_by=current_user.id,
)
if document.doc_form == "qa_model":
segment_document.answer = segment_item["answer"]
segment_document.word_count += len(segment_item["answer"])
increment_word_count += segment_document.word_count
db.session.add(segment_document)
segment_data_list.append(segment_document)
position += 1

pre_segment_data_list.append(segment_document)
if "keywords" in segment_item:
keywords_list.append(segment_item["keywords"])
else:
keywords_list.append(None)
# update document word count
assert document.word_count is not None
document.word_count += increment_word_count
db.session.add(document)
try:
# save vector index
VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset, document.doc_form)
except Exception as e:
logger.exception("create segment index failed")
for segment_document in segment_data_list:
segment_document.enabled = False
segment_document.disabled_at = naive_utc_now()
segment_document.status = "error"
segment_document.error = str(e)
db.session.commit()
return segment_data_list

@classmethod
def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset):
@ -69,7 +69,6 @@ class ProviderResponse(BaseModel):
    label: I18nObject
    description: I18nObject | None = None
    icon_small: I18nObject | None = None
    icon_small_dark: I18nObject | None = None
    icon_large: I18nObject | None = None
    background: str | None = None
    help: ProviderHelpEntity | None = None
@ -93,11 +92,6 @@ class ProviderResponse(BaseModel):
            self.icon_small = I18nObject(
                en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
            )
            if self.icon_small_dark is not None:
                self.icon_small_dark = I18nObject(
                    en_US=f"{url_prefix}/icon_small_dark/en_US",
                    zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans",
                )

        if self.icon_large is not None:
            self.icon_large = I18nObject(
@ -115,7 +109,6 @@ class ProviderWithModelsResponse(BaseModel):
    provider: str
    label: I18nObject
    icon_small: I18nObject | None = None
    icon_small_dark: I18nObject | None = None
    icon_large: I18nObject | None = None
    status: CustomConfigurationStatus
    models: list[ProviderModelWithStatusEntity]
@ -130,11 +123,6 @@ class ProviderWithModelsResponse(BaseModel):
                en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
            )

            if self.icon_small_dark is not None:
                self.icon_small_dark = I18nObject(
                    en_US=f"{url_prefix}/icon_small_dark/en_US", zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans"
                )

        if self.icon_large is not None:
            self.icon_large = I18nObject(
                en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
@ -159,11 +147,6 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
                en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
            )

            if self.icon_small_dark is not None:
                self.icon_small_dark = I18nObject(
                    en_US=f"{url_prefix}/icon_small_dark/en_US", zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans"
                )

        if self.icon_large is not None:
            self.icon_large = I18nObject(
                en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
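The repeated blocks above all rewrite icon fields into per-language URL objects. A self-contained sketch of the shared shape (the helper and the stand-in dataclass are hypothetical; the real code uses the I18nObject entity seen throughout this diff):

from dataclasses import dataclass


@dataclass
class I18nIcon:  # stand-in for the I18nObject used above
    en_US: str
    zh_Hans: str


def icon_urls(url_prefix: str, field: str) -> I18nIcon:
    # e.g. icon_urls(prefix, "icon_small") yields {prefix}/icon_small/en_US
    # and {prefix}/icon_small/zh_Hans
    return I18nIcon(
        en_US=f"{url_prefix}/{field}/en_US",
        zh_Hans=f"{url_prefix}/{field}/zh_Hans",
    )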
@ -79,7 +79,6 @@ class ModelProviderService:
                label=provider_configuration.provider.label,
                description=provider_configuration.provider.description,
                icon_small=provider_configuration.provider.icon_small,
                icon_small_dark=provider_configuration.provider.icon_small_dark,
                icon_large=provider_configuration.provider.icon_large,
                background=provider_configuration.provider.background,
                help=provider_configuration.provider.help,
@ -403,7 +402,6 @@ class ModelProviderService:
                provider=provider,
                label=first_model.provider.label,
                icon_small=first_model.provider.icon_small,
                icon_small_dark=first_model.provider.icon_small_dark,
                icon_large=first_model.provider.icon_large,
                status=CustomConfigurationStatus.ACTIVE,
                models=[
@ -201,9 +201,7 @@ class ToolTransformService:

    @staticmethod
    def workflow_provider_to_user_provider(
        provider_controller: WorkflowToolProviderController,
        labels: list[str] | None = None,
        workflow_app_id: str | None = None,
        provider_controller: WorkflowToolProviderController, labels: list[str] | None = None
    ):
        """
        convert provider controller to user provider
@ -223,7 +221,6 @@ class ToolTransformService:
            plugin_unique_identifier=None,
            tools=[],
            labels=labels or [],
            workflow_app_id=workflow_app_id,
        )

    @staticmethod
@ -189,9 +189,6 @@ class WorkflowToolManageService:
            select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
        ).all()

        # Create a mapping from provider_id to app_id
        provider_id_to_app_id = {provider.id: provider.app_id for provider in db_tools}

        tools: list[WorkflowToolProviderController] = []
        for provider in db_tools:
            try:
@ -205,11 +202,8 @@ class WorkflowToolManageService:
        result = []

        for tool in tools:
            workflow_app_id = provider_id_to_app_id.get(tool.provider_id)
            user_tool_provider = ToolTransformService.workflow_provider_to_user_provider(
                provider_controller=tool,
                labels=labels.get(tool.provider_id, []),
                workflow_app_id=workflow_app_id,
                provider_controller=tool, labels=labels.get(tool.provider_id, [])
            )
            ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_tool_provider)
            user_tool_provider.tools = [
@ -233,7 +233,7 @@ workflow:
        - value_selector:
            - iteration_node
            - output
          value_type: array[number]
          value_type: array[array[number]]
          variable: output
        selected: false
        title: End
@ -227,7 +227,6 @@ class TestModelProviderService:
        mock_provider_entity.label = {"en_US": "OpenAI", "zh_Hans": "OpenAI"}
        mock_provider_entity.description = {"en_US": "OpenAI provider", "zh_Hans": "OpenAI 提供商"}
        mock_provider_entity.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"}
        mock_provider_entity.icon_small_dark = None
        mock_provider_entity.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"}
        mock_provider_entity.background = "#FF6B6B"
        mock_provider_entity.help = None
@ -301,7 +300,6 @@ class TestModelProviderService:
        mock_provider_entity_llm.label = {"en_US": "OpenAI", "zh_Hans": "OpenAI"}
        mock_provider_entity_llm.description = {"en_US": "OpenAI provider", "zh_Hans": "OpenAI 提供商"}
        mock_provider_entity_llm.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"}
        mock_provider_entity_llm.icon_small_dark = None
        mock_provider_entity_llm.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"}
        mock_provider_entity_llm.background = "#FF6B6B"
        mock_provider_entity_llm.help = None
@ -315,7 +313,6 @@ class TestModelProviderService:
        mock_provider_entity_embedding.label = {"en_US": "Cohere", "zh_Hans": "Cohere"}
        mock_provider_entity_embedding.description = {"en_US": "Cohere provider", "zh_Hans": "Cohere 提供商"}
        mock_provider_entity_embedding.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"}
        mock_provider_entity_embedding.icon_small_dark = None
        mock_provider_entity_embedding.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"}
        mock_provider_entity_embedding.background = "#4ECDC4"
        mock_provider_entity_embedding.help = None
@ -1026,7 +1023,6 @@ class TestModelProviderService:
                    provider="openai",
                    label={"en_US": "OpenAI", "zh_Hans": "OpenAI"},
                    icon_small={"en_US": "icon_small.png", "zh_Hans": "icon_small.png"},
                    icon_small_dark=None,
                    icon_large={"en_US": "icon_large.png", "zh_Hans": "icon_large.png"},
                ),
                model="gpt-3.5-turbo",
@ -1044,7 +1040,6 @@ class TestModelProviderService:
                    provider="openai",
                    label={"en_US": "OpenAI", "zh_Hans": "OpenAI"},
                    icon_small={"en_US": "icon_small.png", "zh_Hans": "icon_small.png"},
                    icon_small_dark=None,
                    icon_large={"en_US": "icon_large.png", "zh_Hans": "icon_large.png"},
                ),
                model="gpt-4",
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
@ -1,100 +0,0 @@
"""
Unit tests for ToolProviderApiEntity workflow_app_id field.

This test suite covers:
- ToolProviderApiEntity workflow_app_id field creation and default value
- ToolProviderApiEntity.to_dict() method behavior with workflow_app_id
"""

from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType


class TestToolProviderApiEntityWorkflowAppId:
    """Test suite for ToolProviderApiEntity workflow_app_id field."""

    def test_workflow_app_id_field_default_none(self):
        """Test that workflow_app_id defaults to None when not provided."""
        entity = ToolProviderApiEntity(
            id="test_id",
            author="test_author",
            name="test_name",
            description=I18nObject(en_US="Test description"),
            icon="test_icon",
            label=I18nObject(en_US="Test label"),
            type=ToolProviderType.WORKFLOW,
        )

        assert entity.workflow_app_id is None

    def test_to_dict_includes_workflow_app_id_when_workflow_type_and_has_value(self):
        """Test that to_dict() includes workflow_app_id when type is WORKFLOW and value is set."""
        workflow_app_id = "app_123"
        entity = ToolProviderApiEntity(
            id="test_id",
            author="test_author",
            name="test_name",
            description=I18nObject(en_US="Test description"),
            icon="test_icon",
            label=I18nObject(en_US="Test label"),
            type=ToolProviderType.WORKFLOW,
            workflow_app_id=workflow_app_id,
        )

        result = entity.to_dict()

        assert "workflow_app_id" in result
        assert result["workflow_app_id"] == workflow_app_id

    def test_to_dict_excludes_workflow_app_id_when_workflow_type_and_none(self):
        """Test that to_dict() excludes workflow_app_id when type is WORKFLOW but value is None."""
        entity = ToolProviderApiEntity(
            id="test_id",
            author="test_author",
            name="test_name",
            description=I18nObject(en_US="Test description"),
            icon="test_icon",
            label=I18nObject(en_US="Test label"),
            type=ToolProviderType.WORKFLOW,
            workflow_app_id=None,
        )

        result = entity.to_dict()

        assert "workflow_app_id" not in result

    def test_to_dict_excludes_workflow_app_id_when_not_workflow_type(self):
        """Test that to_dict() excludes workflow_app_id when type is not WORKFLOW."""
        workflow_app_id = "app_123"
        entity = ToolProviderApiEntity(
            id="test_id",
            author="test_author",
            name="test_name",
            description=I18nObject(en_US="Test description"),
            icon="test_icon",
            label=I18nObject(en_US="Test label"),
            type=ToolProviderType.BUILT_IN,
            workflow_app_id=workflow_app_id,
        )

        result = entity.to_dict()

        assert "workflow_app_id" not in result

    def test_to_dict_includes_workflow_app_id_for_workflow_type_with_empty_string(self):
        """Test that to_dict() excludes workflow_app_id when value is empty string (falsy)."""
        entity = ToolProviderApiEntity(
            id="test_id",
            author="test_author",
            name="test_name",
            description=I18nObject(en_US="Test description"),
            icon="test_icon",
            label=I18nObject(en_US="Test label"),
            type=ToolProviderType.WORKFLOW,
            workflow_app_id="",
        )

        result = entity.to_dict()

        assert "workflow_app_id" not in result
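The deleted tests above pin down one serialization rule. A standalone sketch of that rule (a hypothetical dataclass, not the real ToolProviderApiEntity):

from dataclasses import dataclass
from typing import Any


@dataclass
class ProviderSketch:
    type: str
    workflow_app_id: str | None = None

    def to_dict(self) -> dict[str, Any]:
        data: dict[str, Any] = {"type": self.type}
        # Emit workflow_app_id only for workflow providers with a truthy value,
        # so both None and the empty string are omitted, matching the assertions above
        if self.type == "workflow" and self.workflow_app_id:
            data["workflow_app_id"] = self.workflow_app_id
        return data


assert "workflow_app_id" not in ProviderSketch(type="workflow", workflow_app_id="").to_dict()
assert ProviderSketch(type="workflow", workflow_app_id="app_123").to_dict()["workflow_app_id"] == "app_123"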
@ -1,5 +1,3 @@
from types import SimpleNamespace

import pytest

from core.app.entities.app_invoke_entities import InvokeFrom
@ -216,76 +214,3 @@ def test_create_variable_message():
    assert message.message.variable_name == var_name
    assert message.message.variable_value == var_value
    assert message.message.stream is False


def test_resolve_user_from_database_falls_back_to_end_user(monkeypatch: pytest.MonkeyPatch):
    """Ensure worker context can resolve EndUser when Account is missing."""

    class StubSession:
        def __init__(self, results: list):
            self.results = results

        def scalar(self, _stmt):
            return self.results.pop(0)

    tenant = SimpleNamespace(id="tenant_id")
    end_user = SimpleNamespace(id="end_user_id", tenant_id="tenant_id")
    db_stub = SimpleNamespace(session=StubSession([tenant, None, end_user]))

    monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub)

    entity = ToolEntity(
        identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
        parameters=[],
        description=None,
        has_runtime_parameters=False,
    )
    runtime = ToolRuntime(tenant_id="tenant_id", invoke_from=InvokeFrom.SERVICE_API)
    tool = WorkflowTool(
        workflow_app_id="",
        workflow_as_tool_id="",
        version="1",
        workflow_entities={},
        workflow_call_depth=1,
        entity=entity,
        runtime=runtime,
    )

    resolved_user = tool._resolve_user_from_database(user_id=end_user.id)

    assert resolved_user is end_user


def test_resolve_user_from_database_returns_none_when_no_tenant(monkeypatch: pytest.MonkeyPatch):
    """Return None if tenant cannot be found in worker context."""

    class StubSession:
        def __init__(self, results: list):
            self.results = results

        def scalar(self, _stmt):
            return self.results.pop(0)

    db_stub = SimpleNamespace(session=StubSession([None]))
    monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub)

    entity = ToolEntity(
        identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
        parameters=[],
        description=None,
        has_runtime_parameters=False,
    )
    runtime = ToolRuntime(tenant_id="missing_tenant", invoke_from=InvokeFrom.SERVICE_API)
    tool = WorkflowTool(
        workflow_app_id="",
        workflow_as_tool_id="",
        version="1",
        workflow_entities={},
        workflow_call_depth=1,
        entity=entity,
        runtime=runtime,
    )

    resolved_user = tool._resolve_user_from_database(user_id="any")

    assert resolved_user is None
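The StubSession used above is a small scripted stand-in for db.session. A self-contained sketch of the same pattern (names illustrative):

from types import SimpleNamespace


def test_stub_session_pops_results_in_order():
    class StubSession:
        def __init__(self, results: list):
            self.results = results

        def scalar(self, _stmt):
            # Each query consumes the next scripted result, in call order
            return self.results.pop(0)

    db_stub = SimpleNamespace(session=StubSession(["tenant", None]))
    assert db_stub.session.scalar(object()) == "tenant"  # first lookup hits
    assert db_stub.session.scalar(object()) is None  # second lookup misses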
@ -29,7 +29,7 @@ class _TestNode(Node[_TestNodeData]):

    @classmethod
    def version(cls) -> str:
        return "1"
        return "test"

    def __init__(
        self,

@ -744,7 +744,7 @@ def test_graph_run_emits_partial_success_when_node_failure_recovered():
    )

    llm_node = graph.nodes["llm"]
    base_node_data = llm_node.node_data
    base_node_data = llm_node.get_base_node_data()
    base_node_data.error_strategy = ErrorStrategy.DEFAULT_VALUE
    base_node_data.default_value = [DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING)]
@ -7,31 +7,9 @@ This module tests the iteration node's ability to:
"""

from .test_database_utils import skip_if_database_unavailable
from .test_mock_config import MockConfigBuilder, NodeMockConfig
from .test_table_runner import TableTestRunner, WorkflowTestCase


def _create_iteration_mock_config():
    """Helper to create a mock config for iteration tests."""

    def code_inner_handler(node):
        pool = node.graph_runtime_state.variable_pool
        item_seg = pool.get(["iteration_node", "item"])
        if item_seg is not None:
            item = item_seg.to_object()
            return {"result": [item, item * 2]}
        # This fallback is likely unreachable, but if it is,
        # it doesn't simulate iteration with different values as the comment suggests.
        return {"result": [1, 2]}

    return (
        MockConfigBuilder()
        .with_node_output("code_node", {"result": [1, 2, 3]})
        .with_node_config(NodeMockConfig(node_id="code_inner_node", custom_handler=code_inner_handler))
        .build()
    )


@skip_if_database_unavailable()
def test_iteration_with_flatten_output_enabled():
    """
@ -49,8 +27,7 @@ def test_iteration_with_flatten_output_enabled():
        inputs={},
        expected_outputs={"output": [1, 2, 2, 4, 3, 6]},
        description="Iteration with flatten_output=True flattens nested arrays",
        use_auto_mock=True,  # Use auto-mock to avoid sandbox service
        mock_config=_create_iteration_mock_config(),
        use_auto_mock=False,  # Run code nodes directly
    )

    result = runner.run_test_case(test_case)
@ -79,8 +56,7 @@ def test_iteration_with_flatten_output_disabled():
        inputs={},
        expected_outputs={"output": [[1, 2], [2, 4], [3, 6]]},
        description="Iteration with flatten_output=False preserves nested structure",
        use_auto_mock=True,  # Use auto-mock to avoid sandbox service
        mock_config=_create_iteration_mock_config(),
        use_auto_mock=False,  # Run code nodes directly
    )

    result = runner.run_test_case(test_case)
@ -105,16 +81,14 @@ def test_iteration_flatten_output_comparison():
            inputs={},
            expected_outputs={"output": [1, 2, 2, 4, 3, 6]},
            description="flatten_output=True: Flattened output",
            use_auto_mock=True,  # Use auto-mock to avoid sandbox service
            mock_config=_create_iteration_mock_config(),
            use_auto_mock=False,  # Run code nodes directly
        ),
        WorkflowTestCase(
            fixture_path="iteration_flatten_output_disabled_workflow",
            inputs={},
            expected_outputs={"output": [[1, 2], [2, 4], [3, 6]]},
            description="flatten_output=False: Nested output",
            use_auto_mock=True,  # Use auto-mock to avoid sandbox service
            mock_config=_create_iteration_mock_config(),
            use_auto_mock=False,  # Run code nodes directly
        ),
    ]
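The expected outputs above differ only in one step. A sketch of the flatten behavior being toggled, in plain Python independent of the workflow engine:

def flatten_once(nested: list[list], flatten_output: bool) -> list:
    if flatten_output:
        # One level of flattening: [[1, 2], [2, 4], [3, 6]] -> [1, 2, 2, 4, 3, 6]
        return [item for sub in nested for item in sub]
    return nested


assert flatten_once([[1, 2], [2, 4], [3, 6]], True) == [1, 2, 2, 4, 3, 6]
assert flatten_once([[1, 2], [2, 4], [3, 6]], False) == [[1, 2], [2, 4], [3, 6]]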
@ -92,7 +92,7 @@ class MockLLMNode(MockNodeMixin, LLMNode):
    @classmethod
    def version(cls) -> str:
        """Return the version of this mock node."""
        return "1"
        return "mock-1"

    def _run(self) -> Generator:
        """Execute mock LLM node."""
@ -189,7 +189,7 @@ class MockAgentNode(MockNodeMixin, AgentNode):
    @classmethod
    def version(cls) -> str:
        """Return the version of this mock node."""
        return "1"
        return "mock-1"

    def _run(self) -> Generator:
        """Execute mock agent node."""
@ -241,7 +241,7 @@ class MockToolNode(MockNodeMixin, ToolNode):
    @classmethod
    def version(cls) -> str:
        """Return the version of this mock node."""
        return "1"
        return "mock-1"

    def _run(self) -> Generator:
        """Execute mock tool node."""
@ -294,7 +294,7 @@ class MockKnowledgeRetrievalNode(MockNodeMixin, KnowledgeRetrievalNode):
    @classmethod
    def version(cls) -> str:
        """Return the version of this mock node."""
        return "1"
        return "mock-1"

    def _run(self) -> Generator:
        """Execute mock knowledge retrieval node."""
@ -351,7 +351,7 @@ class MockHttpRequestNode(MockNodeMixin, HttpRequestNode):
    @classmethod
    def version(cls) -> str:
        """Return the version of this mock node."""
        return "1"
        return "mock-1"

    def _run(self) -> Generator:
        """Execute mock HTTP request node."""
@ -404,7 +404,7 @@ class MockQuestionClassifierNode(MockNodeMixin, QuestionClassifierNode):
    @classmethod
    def version(cls) -> str:
        """Return the version of this mock node."""
        return "1"
        return "mock-1"

    def _run(self) -> Generator:
        """Execute mock question classifier node."""
@ -452,7 +452,7 @@ class MockParameterExtractorNode(MockNodeMixin, ParameterExtractorNode):
    @classmethod
    def version(cls) -> str:
        """Return the version of this mock node."""
        return "1"
        return "mock-1"

    def _run(self) -> Generator:
        """Execute mock parameter extractor node."""
@ -502,7 +502,7 @@ class MockDocumentExtractorNode(MockNodeMixin, DocumentExtractorNode):
    @classmethod
    def version(cls) -> str:
        """Return the version of this mock node."""
        return "1"
        return "mock-1"

    def _run(self) -> Generator:
        """Execute mock document extractor node."""
@ -557,7 +557,7 @@ class MockIterationNode(MockNodeMixin, IterationNode):
    @classmethod
    def version(cls) -> str:
        """Return the version of this mock node."""
        return "1"
        return "mock-1"

    def _create_graph_engine(self, index: int, item: Any):
        """Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
@ -632,7 +632,7 @@ class MockLoopNode(MockNodeMixin, LoopNode):
    @classmethod
    def version(cls) -> str:
        """Return the version of this mock node."""
        return "1"
        return "mock-1"

    def _create_graph_engine(self, start_at, root_node_id: str):
        """Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
@ -694,7 +694,7 @@ class MockTemplateTransformNode(MockNodeMixin, TemplateTransformNode):
    @classmethod
    def version(cls) -> str:
        """Return the version of this mock node."""
        return "1"
        return "mock-1"

    def _run(self) -> NodeRunResult:
        """Execute mock template transform node."""
@ -780,7 +780,7 @@ class MockCodeNode(MockNodeMixin, CodeNode):
    @classmethod
    def version(cls) -> str:
        """Return the version of this mock node."""
        return "1"
        return "mock-1"

    def _run(self) -> NodeRunResult:
        """Execute mock code node."""
@ -33,10 +33,6 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined
    type_version_set: set[tuple[NodeType, str]] = set()

    for cls in classes:
        # Only validate production node classes; skip test-defined subclasses and external helpers
        module_name = getattr(cls, "__module__", "")
        if not module_name.startswith("core."):
            continue
        # Validate that 'version' is directly defined in the class (not inherited) by checking the class's __dict__
        assert "version" in cls.__dict__, f"class {cls} should have version method defined (NOT INHERITED.)"
        node_type = cls.node_type
@ -1,84 +0,0 @@
import types
from collections.abc import Mapping

from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.base.node import Node

# Import concrete nodes we will assert on (numeric version path)
from core.workflow.nodes.variable_assigner.v1.node import (
    VariableAssignerNode as VariableAssignerV1,
)
from core.workflow.nodes.variable_assigner.v2.node import (
    VariableAssignerNode as VariableAssignerV2,
)


def test_variable_assigner_latest_prefers_highest_numeric_version():
    # Act
    mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()

    # Assert basic presence
    assert NodeType.VARIABLE_ASSIGNER in mapping
    va_versions = mapping[NodeType.VARIABLE_ASSIGNER]

    # Both concrete versions must be present
    assert va_versions.get("1") is VariableAssignerV1
    assert va_versions.get("2") is VariableAssignerV2

    # And latest should point to numerically-highest version ("2")
    assert va_versions.get("latest") is VariableAssignerV2


def test_latest_prefers_highest_numeric_version():
    # Arrange: define two ephemeral subclasses with numeric versions under a NodeType
    # that has no concrete implementations in production to avoid interference.
    class _Version1(Node[BaseNodeData]):  # type: ignore[misc]
        node_type = NodeType.LEGACY_VARIABLE_AGGREGATOR

        def init_node_data(self, data):
            pass

        def _run(self):
            raise NotImplementedError

        @classmethod
        def version(cls) -> str:
            return "1"

        def _get_error_strategy(self):
            return None

        def _get_retry_config(self):
            return types.SimpleNamespace()  # not used

        def _get_title(self) -> str:
            return "version1"

        def _get_description(self):
            return None

        def _get_default_value_dict(self):
            return {}

        def get_base_node_data(self):
            return types.SimpleNamespace(title="version1")

    class _Version2(_Version1):  # type: ignore[misc]
        @classmethod
        def version(cls) -> str:
            return "2"

        def _get_title(self) -> str:
            return "version2"

    # Act: build a fresh mapping (it should now see our ephemeral subclasses)
    mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()

    # Assert: both numeric versions exist for this NodeType; 'latest' points to the higher numeric version
    assert NodeType.LEGACY_VARIABLE_AGGREGATOR in mapping
    legacy_versions = mapping[NodeType.LEGACY_VARIABLE_AGGREGATOR]

    assert legacy_versions.get("1") is _Version1
    assert legacy_versions.get("2") is _Version2
    assert legacy_versions.get("latest") is _Version2
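The deleted tests above assert that "latest" resolves to the numerically highest version key. A sketch of that resolution rule in isolation (the helper is illustrative, not the Node implementation):

def pick_latest(versions: dict[str, type]) -> type:
    # Consider only numeric keys ("1", "2", ...), ignoring aliases like "latest"
    numeric = {int(key): cls for key, cls in versions.items() if key.isdigit()}
    return numeric[max(numeric)]


assert pick_latest({"1": int, "2": str}) is str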
@ -471,8 +471,8 @@ class TestCodeNodeInitialization:

        assert node._get_description() is None

    def test_node_data_property(self):
        """Test node_data property returns node data."""
    def test_get_base_node_data(self):
        """Test get_base_node_data returns node data."""
        node = CodeNode.__new__(CodeNode)
        node._node_data = CodeNodeData(
            title="Base Test",
@ -482,7 +482,7 @@ class TestCodeNodeInitialization:
            outputs={},
        )

        result = node.node_data
        result = node.get_base_node_data()

        assert result == node._node_data
        assert result.title == "Base Test"

@ -240,8 +240,8 @@ class TestIterationNodeInitialization:

        assert node._get_description() == "This is a description"

    def test_node_data_property(self):
        """Test node_data property returns node data."""
    def test_get_base_node_data(self):
        """Test get_base_node_data returns node data."""
        node = IterationNode.__new__(IterationNode)
        node._node_data = IterationNodeData(
            title="Base Test",
@ -249,7 +249,7 @@ class TestIterationNodeInitialization:
            output_selector=["y"],
        )

        result = node.node_data
        result = node.get_base_node_data()

        assert result == node._node_data


@ -19,7 +19,7 @@ class _SampleNode(Node[_SampleNodeData]):

    @classmethod
    def version(cls) -> str:
        return "1"
        return "sample-test"

    def _run(self):
        raise NotImplementedError
@ -263,62 +263,3 @@ class TestResponseUnmodified:
        )
        assert response.text == _RESPONSE_NEEDLE
        assert response.status_code == 200


class TestRequestFinishedInfoAccessLine:
    def test_info_access_log_includes_method_path_status_duration_trace_id(self, monkeypatch, caplog):
        """Ensure INFO access line contains expected fields with computed duration and trace id."""
        app = _get_test_app()
        # Push a real request context so flask.request and g are available
        with app.test_request_context("/foo", method="GET"):
            # Seed start timestamp via the extension's own start hook and control perf_counter deterministically
            seq = iter([100.0, 100.123456])
            monkeypatch.setattr(ext_request_logging.time, "perf_counter", lambda: next(seq))
            # Provide a deterministic trace id
            monkeypatch.setattr(
                ext_request_logging,
                "get_trace_id_from_otel_context",
                lambda: "trace-xyz",
            )
            # Simulate request_started to record start timestamp on g
            ext_request_logging._log_request_started(app)

            # Capture logs from the real logger at INFO level only (skip DEBUG branch)
            caplog.set_level(logging.INFO, logger=ext_request_logging.__name__)
            response = Response(json.dumps({"ok": True}), mimetype="application/json", status=200)
            _log_request_finished(app, response)

            # Verify a single INFO record with the five fields in order
            info_records = [rec for rec in caplog.records if rec.levelno == logging.INFO]
            assert len(info_records) == 1
            msg = info_records[0].getMessage()
            # Expected format: METHOD PATH STATUS DURATION_MS TRACE_ID
            assert "GET" in msg
            assert "/foo" in msg
            assert "200" in msg
            assert "123.456" in msg  # rounded to 3 decimals
            assert "trace-xyz" in msg

    def test_info_access_log_uses_dash_without_start_timestamp(self, monkeypatch, caplog):
        app = _get_test_app()
        with app.test_request_context("/bar", method="POST"):
            # No g.__request_started_ts set -> duration should be '-'
            monkeypatch.setattr(
                ext_request_logging,
                "get_trace_id_from_otel_context",
                lambda: "tid-no-start",
            )
            caplog.set_level(logging.INFO, logger=ext_request_logging.__name__)
            response = Response("OK", mimetype="text/plain", status=204)
            _log_request_finished(app, response)

            info_records = [rec for rec in caplog.records if rec.levelno == logging.INFO]
            assert len(info_records) == 1
            msg = info_records[0].getMessage()
            assert "POST" in msg
            assert "/bar" in msg
            assert "204" in msg
            # Duration placeholder
            # The fields are space separated; ensure a standalone '-' appears
            assert " - " in msg or msg.endswith(" -")
            assert "tid-no-start" in msg
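The removed tests above lean on pytest's caplog fixture. A minimal sketch of that capture pattern with a generic logger (standard pytest API, nothing Dify-specific):

import logging


def test_caplog_captures_info_access_line(caplog):
    logger = logging.getLogger("access")
    caplog.set_level(logging.INFO, logger="access")
    logger.info("%s %s %s", "GET", "/foo", 200)

    info_records = [rec for rec in caplog.records if rec.levelno == logging.INFO]
    assert len(info_records) == 1
    assert info_records[0].getMessage() == "GET /foo 200"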
@ -1,718 +0,0 @@
|
||||
"""
|
||||
Comprehensive unit tests for AudioService.
|
||||
|
||||
This test suite provides complete coverage of audio processing operations in Dify,
|
||||
following TDD principles with the Arrange-Act-Assert pattern.
|
||||
|
||||
## Test Coverage
|
||||
|
||||
### 1. Speech-to-Text (ASR) Operations (TestAudioServiceASR)
|
||||
Tests audio transcription functionality:
|
||||
- Successful transcription for different app modes
|
||||
- File validation (size, type, presence)
|
||||
- Feature flag validation (speech-to-text enabled)
|
||||
- Error handling for various failure scenarios
|
||||
- Model instance availability checks
|
||||
|
||||
### 2. Text-to-Speech (TTS) Operations (TestAudioServiceTTS)
|
||||
Tests text-to-audio conversion:
|
||||
- TTS with text input
|
||||
- TTS with message ID
|
||||
- Voice selection (explicit and default)
|
||||
- Feature flag validation (text-to-speech enabled)
|
||||
- Draft workflow handling
|
||||
- Streaming response handling
|
||||
- Error handling for missing/invalid inputs
|
||||
|
||||
### 3. TTS Voice Listing (TestAudioServiceTTSVoices)
|
||||
Tests available voice retrieval:
|
||||
- Get available voices for a tenant
|
||||
- Language filtering
|
||||
- Error handling for missing provider
|
||||
|
||||
## Testing Approach
|
||||
|
||||
- **Mocking Strategy**: All external dependencies (ModelManager, db, FileStorage) are mocked
|
||||
for fast, isolated unit tests
|
||||
- **Factory Pattern**: AudioServiceTestDataFactory provides consistent test data
|
||||
- **Fixtures**: Mock objects are configured per test method
|
||||
- **Assertions**: Each test verifies return values, side effects, and error conditions
|
||||
|
||||
## Key Concepts
|
||||
|
||||
**Audio Formats:**
|
||||
- Supported: mp3, wav, m4a, flac, ogg, opus, webm
|
||||
- File size limit: 30 MB
|
||||
|
||||
**App Modes:**
|
||||
- ADVANCED_CHAT/WORKFLOW: Use workflow features
|
||||
- CHAT/COMPLETION: Use app_model_config
|
||||
|
||||
**Feature Flags:**
|
||||
- speech_to_text: Enables ASR functionality
|
||||
- text_to_speech: Enables TTS functionality
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from models.enums import MessageStatus
|
||||
from models.model import App, AppMode, AppModelConfig, Message
|
||||
from models.workflow import Workflow
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
ProviderNotSupportTextToSpeechServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
|
||||


class AudioServiceTestDataFactory:
    """
    Factory for creating test data and mock objects.

    Provides reusable methods to create consistent mock objects for testing
    audio-related operations.
    """

    @staticmethod
    def create_app_mock(
        app_id: str = "app-123",
        mode: AppMode = AppMode.CHAT,
        tenant_id: str = "tenant-123",
        **kwargs,
    ) -> Mock:
        """
        Create a mock App object.

        Args:
            app_id: Unique identifier for the app
            mode: App mode (CHAT, ADVANCED_CHAT, WORKFLOW, etc.)
            tenant_id: Tenant identifier
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock App object with specified attributes
        """
        app = create_autospec(App, instance=True)
        app.id = app_id
        app.mode = mode
        app.tenant_id = tenant_id
        app.workflow = kwargs.get("workflow")
        app.app_model_config = kwargs.get("app_model_config")
        for key, value in kwargs.items():
            setattr(app, key, value)
        return app

    @staticmethod
    def create_workflow_mock(features_dict: dict | None = None, **kwargs) -> Mock:
        """
        Create a mock Workflow object.

        Args:
            features_dict: Dictionary of workflow features
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock Workflow object with specified attributes
        """
        workflow = create_autospec(Workflow, instance=True)
        workflow.features_dict = features_dict or {}
        for key, value in kwargs.items():
            setattr(workflow, key, value)
        return workflow

    @staticmethod
    def create_app_model_config_mock(
        speech_to_text_dict: dict | None = None,
        text_to_speech_dict: dict | None = None,
        **kwargs,
    ) -> Mock:
        """
        Create a mock AppModelConfig object.

        Args:
            speech_to_text_dict: Speech-to-text configuration
            text_to_speech_dict: Text-to-speech configuration
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock AppModelConfig object with specified attributes
        """
        config = create_autospec(AppModelConfig, instance=True)
        config.speech_to_text_dict = speech_to_text_dict or {"enabled": False}
        config.text_to_speech_dict = text_to_speech_dict or {"enabled": False}
        for key, value in kwargs.items():
            setattr(config, key, value)
        return config

    @staticmethod
    def create_file_storage_mock(
        filename: str = "test.mp3",
        mimetype: str = "audio/mp3",
        content: bytes = b"fake audio content",
        **kwargs,
    ) -> Mock:
        """
        Create a mock FileStorage object.

        Args:
            filename: Name of the file
            mimetype: MIME type of the file
            content: File content as bytes
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock FileStorage object with specified attributes
        """
        file = Mock(spec=FileStorage)
        file.filename = filename
        file.mimetype = mimetype
        file.read = Mock(return_value=content)
        for key, value in kwargs.items():
            setattr(file, key, value)
        return file

    @staticmethod
    def create_message_mock(
        message_id: str = "msg-123",
        answer: str = "Test answer",
        status: MessageStatus = MessageStatus.NORMAL,
        **kwargs,
    ) -> Mock:
        """
        Create a mock Message object.

        Args:
            message_id: Unique identifier for the message
            answer: Message answer text
            status: Message status
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock Message object with specified attributes
        """
        message = create_autospec(Message, instance=True)
        message.id = message_id
        message.answer = answer
        message.status = status
        for key, value in kwargs.items():
            setattr(message, key, value)
        return message
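

# Editorial note: create_autospec(Model, instance=True) is used above so attribute
# access is validated against the real model classes, while Mock(spec=FileStorage)
# only constrains attribute names; either way, no database or storage is touched.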


@pytest.fixture
def factory():
    """Provide the test data factory to all tests."""
    return AudioServiceTestDataFactory


class TestAudioServiceASR:
    """Test speech-to-text (ASR) operations."""

    @patch("services.audio_service.ModelManager")
    def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory):
        """Test successful ASR transcription in CHAT mode."""
        # Arrange
        app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
        app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=app_model_config,
        )
        file = factory.create_file_storage_mock()

        # Mock ModelManager
        mock_model_manager = MagicMock()
        mock_model_manager_class.return_value = mock_model_manager

        mock_model_instance = MagicMock()
        mock_model_instance.invoke_speech2text.return_value = "Transcribed text"
        mock_model_manager.get_default_model_instance.return_value = mock_model_instance

        # Act
        result = AudioService.transcript_asr(app_model=app, file=file, end_user="user-123")

        # Assert
        assert result == {"text": "Transcribed text"}
        mock_model_instance.invoke_speech2text.assert_called_once()
        call_args = mock_model_instance.invoke_speech2text.call_args
        assert call_args.kwargs["user"] == "user-123"

    @patch("services.audio_service.ModelManager")
    def test_transcript_asr_success_advanced_chat_mode(self, mock_model_manager_class, factory):
        """Test successful ASR transcription in ADVANCED_CHAT mode."""
        # Arrange
        workflow = factory.create_workflow_mock(features_dict={"speech_to_text": {"enabled": True}})
        app = factory.create_app_mock(
            mode=AppMode.ADVANCED_CHAT,
            workflow=workflow,
        )
        file = factory.create_file_storage_mock()

        # Mock ModelManager
        mock_model_manager = MagicMock()
        mock_model_manager_class.return_value = mock_model_manager

        mock_model_instance = MagicMock()
        mock_model_instance.invoke_speech2text.return_value = "Workflow transcribed text"
        mock_model_manager.get_default_model_instance.return_value = mock_model_instance

        # Act
        result = AudioService.transcript_asr(app_model=app, file=file)

        # Assert
        assert result == {"text": "Workflow transcribed text"}

    def test_transcript_asr_raises_error_when_feature_disabled_chat_mode(self, factory):
        """Test that ASR raises error when speech-to-text is disabled in CHAT mode."""
        # Arrange
        app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": False})
        app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=app_model_config,
        )
        file = factory.create_file_storage_mock()

        # Act & Assert
        with pytest.raises(ValueError, match="Speech to text is not enabled"):
            AudioService.transcript_asr(app_model=app, file=file)

    def test_transcript_asr_raises_error_when_feature_disabled_workflow_mode(self, factory):
        """Test that ASR raises error when speech-to-text is disabled in WORKFLOW mode."""
        # Arrange
        workflow = factory.create_workflow_mock(features_dict={"speech_to_text": {"enabled": False}})
        app = factory.create_app_mock(
            mode=AppMode.WORKFLOW,
            workflow=workflow,
        )
        file = factory.create_file_storage_mock()

        # Act & Assert
        with pytest.raises(ValueError, match="Speech to text is not enabled"):
            AudioService.transcript_asr(app_model=app, file=file)

    def test_transcript_asr_raises_error_when_workflow_missing(self, factory):
        """Test that ASR raises error when workflow is missing in WORKFLOW mode."""
        # Arrange
        app = factory.create_app_mock(
            mode=AppMode.WORKFLOW,
            workflow=None,
        )
        file = factory.create_file_storage_mock()

        # Act & Assert
        with pytest.raises(ValueError, match="Speech to text is not enabled"):
            AudioService.transcript_asr(app_model=app, file=file)

    def test_transcript_asr_raises_error_when_no_file_uploaded(self, factory):
        """Test that ASR raises error when no file is uploaded."""
        # Arrange
        app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
        app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=app_model_config,
        )

        # Act & Assert
        with pytest.raises(NoAudioUploadedServiceError):
            AudioService.transcript_asr(app_model=app, file=None)

    def test_transcript_asr_raises_error_for_unsupported_audio_type(self, factory):
        """Test that ASR raises error for unsupported audio file types."""
        # Arrange
        app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
        app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=app_model_config,
        )
        file = factory.create_file_storage_mock(mimetype="video/mp4")

        # Act & Assert
        with pytest.raises(UnsupportedAudioTypeServiceError):
            AudioService.transcript_asr(app_model=app, file=file)

    def test_transcript_asr_raises_error_for_large_file(self, factory):
        """Test that ASR raises error when file exceeds size limit (30MB)."""
        # Arrange
        app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
        app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=app_model_config,
        )
        # Create file larger than 30MB
        large_content = b"x" * (31 * 1024 * 1024)
        file = factory.create_file_storage_mock(content=large_content)

        # Act & Assert
        with pytest.raises(AudioTooLargeServiceError, match="Audio size larger than 30 mb"):
            AudioService.transcript_asr(app_model=app, file=file)

    @patch("services.audio_service.ModelManager")
    def test_transcript_asr_raises_error_when_no_model_instance(self, mock_model_manager_class, factory):
        """Test that ASR raises error when no model instance is available."""
        # Arrange
        app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
        app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=app_model_config,
        )
        file = factory.create_file_storage_mock()

        # Mock ModelManager to return None
        mock_model_manager = MagicMock()
        mock_model_manager_class.return_value = mock_model_manager
        mock_model_manager.get_default_model_instance.return_value = None

        # Act & Assert
        with pytest.raises(ProviderNotSupportSpeechToTextServiceError):
            AudioService.transcript_asr(app_model=app, file=file)
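

# Editorial note: with the stacked @patch decorators used in the TTS tests below,
# mocks are injected bottom-up, so the decorator nearest the function supplies the
# first mock argument after `self`. For example (illustrative only):
#
#     @patch("services.audio_service.db.session")    # -> second mock argument
#     @patch("services.audio_service.ModelManager")  # -> first mock argument
#     def test_x(self, mock_model_manager_class, mock_db_session, factory): ...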


class TestAudioServiceTTS:
    """Test text-to-speech (TTS) operations."""

    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory):
        """Test successful TTS with text input."""
        # Arrange
        app_model_config = factory.create_app_model_config_mock(
            text_to_speech_dict={"enabled": True, "voice": "en-US-Neural"}
        )
        app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=app_model_config,
        )

        # Mock ModelManager
        mock_model_manager = MagicMock()
        mock_model_manager_class.return_value = mock_model_manager

        mock_model_instance = MagicMock()
        mock_model_instance.invoke_tts.return_value = b"audio data"
        mock_model_manager.get_default_model_instance.return_value = mock_model_instance

        # Act
        result = AudioService.transcript_tts(
            app_model=app,
            text="Hello world",
            voice="en-US-Neural",
            end_user="user-123",
        )

        # Assert
        assert result == b"audio data"
        mock_model_instance.invoke_tts.assert_called_once_with(
            content_text="Hello world",
            user="user-123",
            tenant_id=app.tenant_id,
            voice="en-US-Neural",
        )

    @patch("services.audio_service.db.session")
    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_with_message_id_success(self, mock_model_manager_class, mock_db_session, factory):
        """Test successful TTS with message ID."""
        # Arrange
        app_model_config = factory.create_app_model_config_mock(
            text_to_speech_dict={"enabled": True, "voice": "en-US-Neural"}
        )
        app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=app_model_config,
        )

        message = factory.create_message_mock(
            message_id="550e8400-e29b-41d4-a716-446655440000",
            answer="Message answer text",
        )

        # Mock database query
        mock_query = MagicMock()
        mock_db_session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.first.return_value = message

        # Mock ModelManager
        mock_model_manager = MagicMock()
        mock_model_manager_class.return_value = mock_model_manager

        mock_model_instance = MagicMock()
        mock_model_instance.invoke_tts.return_value = b"audio from message"
        mock_model_manager.get_default_model_instance.return_value = mock_model_instance

        # Act
        result = AudioService.transcript_tts(
            app_model=app,
            message_id="550e8400-e29b-41d4-a716-446655440000",
        )

        # Assert
        assert result == b"audio from message"
        mock_model_instance.invoke_tts.assert_called_once()

    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory):
        """Test TTS uses default voice when none specified."""
        # Arrange
        app_model_config = factory.create_app_model_config_mock(
            text_to_speech_dict={"enabled": True, "voice": "default-voice"}
        )
        app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=app_model_config,
        )

        # Mock ModelManager
        mock_model_manager = MagicMock()
        mock_model_manager_class.return_value = mock_model_manager

        mock_model_instance = MagicMock()
        mock_model_instance.invoke_tts.return_value = b"audio data"
        mock_model_manager.get_default_model_instance.return_value = mock_model_instance

        # Act
        result = AudioService.transcript_tts(
            app_model=app,
            text="Test",
        )

        # Assert
        assert result == b"audio data"
        # Verify default voice was used
        call_args = mock_model_instance.invoke_tts.call_args
        assert call_args.kwargs["voice"] == "default-voice"

    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_gets_first_available_voice_when_none_configured(self, mock_model_manager_class, factory):
        """Test TTS gets first available voice when none is configured."""
        # Arrange
        app_model_config = factory.create_app_model_config_mock(
            text_to_speech_dict={"enabled": True}  # No voice specified
        )
        app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=app_model_config,
        )

        # Mock ModelManager
        mock_model_manager = MagicMock()
        mock_model_manager_class.return_value = mock_model_manager

        mock_model_instance = MagicMock()
        mock_model_instance.get_tts_voices.return_value = [{"value": "auto-voice"}]
        mock_model_instance.invoke_tts.return_value = b"audio data"
        mock_model_manager.get_default_model_instance.return_value = mock_model_instance

        # Act
        result = AudioService.transcript_tts(
            app_model=app,
            text="Test",
        )

        # Assert
        assert result == b"audio data"
        call_args = mock_model_instance.invoke_tts.call_args
        assert call_args.kwargs["voice"] == "auto-voice"

    @patch("services.audio_service.WorkflowService")
    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_workflow_mode_with_draft(
        self, mock_model_manager_class, mock_workflow_service_class, factory
    ):
        """Test TTS in WORKFLOW mode with draft workflow."""
        # Arrange
        draft_workflow = factory.create_workflow_mock(
            features_dict={"text_to_speech": {"enabled": True, "voice": "draft-voice"}}
        )
        app = factory.create_app_mock(
            mode=AppMode.WORKFLOW,
        )

        # Mock WorkflowService
        mock_workflow_service = MagicMock()
        mock_workflow_service_class.return_value = mock_workflow_service
        mock_workflow_service.get_draft_workflow.return_value = draft_workflow

        # Mock ModelManager
        mock_model_manager = MagicMock()
        mock_model_manager_class.return_value = mock_model_manager

        mock_model_instance = MagicMock()
        mock_model_instance.invoke_tts.return_value = b"draft audio"
        mock_model_manager.get_default_model_instance.return_value = mock_model_instance

        # Act
        result = AudioService.transcript_tts(
            app_model=app,
            text="Draft test",
            is_draft=True,
        )

        # Assert
        assert result == b"draft audio"
        mock_workflow_service.get_draft_workflow.assert_called_once_with(app_model=app)

    def test_transcript_tts_raises_error_when_text_missing(self, factory):
        """Test that TTS raises error when text is missing."""
        # Arrange
        app = factory.create_app_mock()

        # Act & Assert
        with pytest.raises(ValueError, match="Text is required"):
            AudioService.transcript_tts(app_model=app, text=None)

    @patch("services.audio_service.db.session")
    def test_transcript_tts_returns_none_for_invalid_message_id(self, mock_db_session, factory):
        """Test that TTS returns None for invalid message ID format."""
        # Arrange
        app = factory.create_app_mock()

        # Act
        result = AudioService.transcript_tts(
            app_model=app,
            message_id="invalid-uuid",
        )

        # Assert
        assert result is None

    @patch("services.audio_service.db.session")
    def test_transcript_tts_returns_none_for_nonexistent_message(self, mock_db_session, factory):
        """Test that TTS returns None when message doesn't exist."""
        # Arrange
        app = factory.create_app_mock()

        # Mock database query returning None
        mock_query = MagicMock()
        mock_db_session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.first.return_value = None

        # Act
        result = AudioService.transcript_tts(
            app_model=app,
            message_id="550e8400-e29b-41d4-a716-446655440000",
        )

        # Assert
        assert result is None

    @patch("services.audio_service.db.session")
    def test_transcript_tts_returns_none_for_empty_message_answer(self, mock_db_session, factory):
        """Test that TTS returns None when message answer is empty."""
        # Arrange
        app = factory.create_app_mock()

        message = factory.create_message_mock(
            answer="",
            status=MessageStatus.NORMAL,
        )

        # Mock database query
        mock_query = MagicMock()
        mock_db_session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.first.return_value = message

        # Act
        result = AudioService.transcript_tts(
            app_model=app,
            message_id="550e8400-e29b-41d4-a716-446655440000",
        )

        # Assert
        assert result is None

    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_raises_error_when_no_voices_available(self, mock_model_manager_class, factory):
        """Test that TTS raises error when no voices are available."""
        # Arrange
        app_model_config = factory.create_app_model_config_mock(
            text_to_speech_dict={"enabled": True}  # No voice specified
        )
        app = factory.create_app_mock(
            mode=AppMode.CHAT,
            app_model_config=app_model_config,
        )

        # Mock ModelManager
        mock_model_manager = MagicMock()
        mock_model_manager_class.return_value = mock_model_manager

        mock_model_instance = MagicMock()
        mock_model_instance.get_tts_voices.return_value = []  # No voices available
        mock_model_manager.get_default_model_instance.return_value = mock_model_instance

        # Act & Assert
        with pytest.raises(ValueError, match="Sorry, no voice available"):
            AudioService.transcript_tts(app_model=app, text="Test")


class TestAudioServiceTTSVoices:
    """Test TTS voice listing operations."""

    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_voices_success(self, mock_model_manager_class, factory):
        """Test successful retrieval of TTS voices."""
        # Arrange
        tenant_id = "tenant-123"
        language = "en-US"

        expected_voices = [
            {"name": "Voice 1", "value": "voice-1"},
            {"name": "Voice 2", "value": "voice-2"},
        ]

        # Mock ModelManager
        mock_model_manager = MagicMock()
        mock_model_manager_class.return_value = mock_model_manager

        mock_model_instance = MagicMock()
        mock_model_instance.get_tts_voices.return_value = expected_voices
        mock_model_manager.get_default_model_instance.return_value = mock_model_instance

        # Act
        result = AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language)

        # Assert
        assert result == expected_voices
        mock_model_instance.get_tts_voices.assert_called_once_with(language)

    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_voices_raises_error_when_no_model_instance(self, mock_model_manager_class, factory):
        """Test that listing TTS voices raises an error when no model instance is available."""
        # Arrange
        tenant_id = "tenant-123"
        language = "en-US"

        # Mock ModelManager to return None
        mock_model_manager = MagicMock()
        mock_model_manager_class.return_value = mock_model_manager
        mock_model_manager.get_default_model_instance.return_value = None

        # Act & Assert
        with pytest.raises(ProviderNotSupportTextToSpeechServiceError):
            AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language)

    @patch("services.audio_service.ModelManager")
    def test_transcript_tts_voices_propagates_exceptions(self, mock_model_manager_class, factory):
        """Test that listing TTS voices propagates exceptions from the model instance."""
        # Arrange
        tenant_id = "tenant-123"
        language = "en-US"

        # Mock ModelManager
        mock_model_manager = MagicMock()
        mock_model_manager_class.return_value = mock_model_manager

        mock_model_instance = MagicMock()
        mock_model_instance.get_tts_voices.side_effect = RuntimeError("Model error")
        mock_model_manager.get_default_model_instance.return_value = mock_model_instance

        # Act & Assert
        with pytest.raises(RuntimeError, match="Model error"):
            AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language)
@ -1,177 +0,0 @@
import types
from unittest.mock import Mock, create_autospec

import pytest
from redis.exceptions import LockNotOwnedError

from models.account import Account
from models.dataset import Dataset, Document
from services.dataset_service import DocumentService, SegmentService


class FakeLock:
    """Lock that always fails on enter with LockNotOwnedError."""

    def __enter__(self):
        raise LockNotOwnedError("simulated")

    def __exit__(self, exc_type, exc, tb):
        # Normal contextmanager signature; return False so exceptions propagate
        return False
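

# Minimal usage sketch (illustrative only): entering FakeLock raises immediately,
# which is the failure mode the tests below expect the services to swallow.
#
#     with FakeLock():
#         ...  # never reached; __enter__ raises LockNotOwnedError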


@pytest.fixture
def fake_current_user(monkeypatch):
    user = create_autospec(Account, instance=True)
    user.id = "user-1"
    user.current_tenant_id = "tenant-1"
    monkeypatch.setattr("services.dataset_service.current_user", user)
    return user


@pytest.fixture
def fake_features(monkeypatch):
    """Features.billing.enabled == False to skip quota logic."""
    features = types.SimpleNamespace(
        billing=types.SimpleNamespace(enabled=False, subscription=types.SimpleNamespace(plan="ENTERPRISE")),
        documents_upload_quota=types.SimpleNamespace(limit=10_000, size=0),
    )
    monkeypatch.setattr(
        "services.dataset_service.FeatureService.get_features",
        lambda tenant_id: features,
    )
    return features
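

# Editorial note: types.SimpleNamespace serves as a lightweight stand-in for the
# FeatureService result object, giving attribute access (features.billing.enabled)
# without constructing the real feature model classes.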


@pytest.fixture
def fake_lock(monkeypatch):
    """Patch redis_client.lock to always raise LockNotOwnedError on enter."""

    def _fake_lock(name, timeout=None, *args, **kwargs):
        return FakeLock()

    # DatasetService imports redis_client directly from extensions.ext_redis
    monkeypatch.setattr("services.dataset_service.redis_client.lock", _fake_lock)


# ---------------------------------------------------------------------------
# 1. Knowledge Pipeline document creation (save_document_with_dataset_id)
# ---------------------------------------------------------------------------


def test_save_document_with_dataset_id_ignores_lock_not_owned(
    monkeypatch,
    fake_current_user,
    fake_features,
    fake_lock,
):
    # Arrange
    dataset = create_autospec(Dataset, instance=True)
    dataset.id = "ds-1"
    dataset.tenant_id = fake_current_user.current_tenant_id
    dataset.data_source_type = "upload_file"
    dataset.indexing_technique = "high_quality"  # so we skip re-initialization branch

    # Minimal knowledge_config stub that satisfies pre-lock code
    info_list = types.SimpleNamespace(data_source_type="upload_file")
    data_source = types.SimpleNamespace(info_list=info_list)
    knowledge_config = types.SimpleNamespace(
        doc_form="qa_model",
        original_document_id=None,  # go into "new document" branch
        data_source=data_source,
        indexing_technique="high_quality",
        embedding_model=None,
        embedding_model_provider=None,
        retrieval_model=None,
        process_rule=None,
        duplicate=False,
        doc_language="en",
    )

    account = fake_current_user

    # Avoid touching real doc_form logic
    monkeypatch.setattr("services.dataset_service.DatasetService.check_doc_form", lambda *a, **k: None)
    # Avoid real DB interactions
    monkeypatch.setattr("services.dataset_service.db", Mock())

    # Act: this would hit the redis lock, whose __enter__ raises LockNotOwnedError.
    # Our implementation should catch it and still return (documents, batch).
    documents, batch = DocumentService.save_document_with_dataset_id(
        dataset=dataset,
        knowledge_config=knowledge_config,
        account=account,
    )

    # Assert
    # We mainly care that:
    # - No exception is raised
    # - The function returns a sensible tuple
    assert isinstance(documents, list)
    assert isinstance(batch, str)


# ---------------------------------------------------------------------------
# 2. Single-segment creation (create_segment)
# ---------------------------------------------------------------------------


def test_add_segment_ignores_lock_not_owned(
    monkeypatch,
    fake_current_user,
    fake_lock,
):
    # Arrange
    dataset = create_autospec(Dataset, instance=True)
    dataset.id = "ds-1"
    dataset.tenant_id = fake_current_user.current_tenant_id
    dataset.indexing_technique = "economy"  # skip embedding/token calculation branch

    document = create_autospec(Document, instance=True)
    document.id = "doc-1"
    document.dataset_id = dataset.id
    document.word_count = 0
    document.doc_form = "qa_model"

    # Minimal args required by create_segment
    args = {
        "content": "question text",
        "answer": "answer text",
        "keywords": ["k1", "k2"],
    }

    # Avoid real DB operations
    db_mock = Mock()
    db_mock.session = Mock()
    monkeypatch.setattr("services.dataset_service.db", db_mock)
    monkeypatch.setattr("services.dataset_service.VectorService", Mock())

    # Act
    result = SegmentService.create_segment(args=args, document=document, dataset=dataset)

    # Assert
    # When the lock raises LockNotOwnedError, create_segment should swallow the
    # error and return None.
    assert result is None


# ---------------------------------------------------------------------------
# 3. Multi-segment creation (multi_create_segment)
# ---------------------------------------------------------------------------


def test_multi_create_segment_ignores_lock_not_owned(
    monkeypatch,
    fake_current_user,
    fake_lock,
):
    # Arrange
    dataset = create_autospec(Dataset, instance=True)
    dataset.id = "ds-1"
    dataset.tenant_id = fake_current_user.current_tenant_id
    dataset.indexing_technique = "economy"  # again, skip high_quality path

    document = create_autospec(Document, instance=True)
    document.id = "doc-1"
    document.dataset_id = dataset.id
    document.word_count = 0
    document.doc_form = "qa_model"
File diff suppressed because it is too large
@ -1,440 +0,0 @@
"""
Comprehensive unit tests for RecommendedAppService.

This test suite provides complete coverage of recommended app operations in Dify,
following TDD principles with the Arrange-Act-Assert pattern.

## Test Coverage

### 1. Get Recommended Apps and Categories (TestRecommendedAppServiceGetApps)
Tests fetching recommended apps with categories:
- Successful retrieval with recommended apps
- Fallback to builtin when no recommended apps
- Different language support
- Factory mode selection (remote, builtin, db)
- Empty result handling

### 2. Get Recommend App Detail (TestRecommendedAppServiceGetDetail)
Tests fetching individual app details:
- Successful app detail retrieval
- Different factory modes
- App not found scenarios
- Language-specific details

## Testing Approach

- **Mocking Strategy**: All external dependencies (dify_config, RecommendAppRetrievalFactory)
  are mocked for fast, isolated unit tests
- **Factory Pattern**: Tests verify correct factory selection based on mode
- **Fixtures**: Mock objects are configured per test method
- **Assertions**: Each test verifies return values and factory method calls

## Key Concepts

**Factory Modes:**
- remote: Fetch from remote API
- builtin: Use built-in templates
- db: Fetch from database

**Fallback Logic:**
- If remote/db returns no apps, fall back to the builtin en-US templates
- Ensures users always see some recommended apps
"""

from unittest.mock import MagicMock, patch

import pytest

from services.recommended_app_service import RecommendedAppService
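

# A minimal sketch, assuming the retrieval interface described in the docstring, of
# the fallback behavior under test: if the configured mode yields no apps, fall back
# to the builtin en-US templates. Names here are illustrative, not the service's code.
def _fetch_with_fallback(retrieval, builtin, language: str) -> dict:
    result = retrieval.get_recommended_apps_and_categories(language)
    if not result.get("recommended_apps"):
        result = builtin.fetch_recommended_apps_from_builtin("en-US")
    return result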


class RecommendedAppServiceTestDataFactory:
    """
    Factory for creating test data and mock objects.

    Provides reusable methods to create consistent mock objects for testing
    recommended app operations.
    """

    @staticmethod
    def create_recommended_apps_response(
        recommended_apps: list[dict] | None = None,
        categories: list[str] | None = None,
    ) -> dict:
        """
        Create a mock response for recommended apps.

        Args:
            recommended_apps: List of recommended app dictionaries
            categories: List of category names

        Returns:
            Dictionary with recommended_apps and categories
        """
        if recommended_apps is None:
            recommended_apps = [
                {
                    "id": "app-1",
                    "name": "Test App 1",
                    "description": "Test description 1",
                    "category": "productivity",
                },
                {
                    "id": "app-2",
                    "name": "Test App 2",
                    "description": "Test description 2",
                    "category": "communication",
                },
            ]
        if categories is None:
            categories = ["productivity", "communication", "utilities"]

        return {
            "recommended_apps": recommended_apps,
            "categories": categories,
        }

    @staticmethod
    def create_app_detail_response(
        app_id: str = "app-123",
        name: str = "Test App",
        description: str = "Test description",
        **kwargs,
    ) -> dict:
        """
        Create a mock response for app detail.

        Args:
            app_id: App identifier
            name: App name
            description: App description
            **kwargs: Additional fields

        Returns:
            Dictionary with app details
        """
        detail = {
            "id": app_id,
            "name": name,
            "description": description,
            "category": kwargs.get("category", "productivity"),
            "icon": kwargs.get("icon", "🚀"),
            "model_config": kwargs.get("model_config", {}),
        }
        detail.update(kwargs)
        return detail


@pytest.fixture
def factory():
    """Provide the test data factory to all tests."""
    return RecommendedAppServiceTestDataFactory


class TestRecommendedAppServiceGetApps:
    """Test get_recommended_apps_and_categories operations."""

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommended_apps_success_with_apps(self, mock_config, mock_factory_class, factory):
        """Test successful retrieval of recommended apps when apps are returned."""
        # Arrange
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"

        expected_response = factory.create_recommended_apps_response()

        # Mock factory and retrieval instance
        mock_retrieval_instance = MagicMock()
        mock_retrieval_instance.get_recommended_apps_and_categories.return_value = expected_response

        mock_factory = MagicMock()
        mock_factory.return_value = mock_retrieval_instance
        mock_factory_class.get_recommend_app_factory.return_value = mock_factory

        # Act
        result = RecommendedAppService.get_recommended_apps_and_categories("en-US")

        # Assert
        assert result == expected_response
        assert len(result["recommended_apps"]) == 2
        assert len(result["categories"]) == 3
        mock_factory_class.get_recommend_app_factory.assert_called_once_with("remote")
        mock_retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommended_apps_fallback_to_builtin_when_empty(self, mock_config, mock_factory_class, factory):
        """Test fallback to builtin when no recommended apps are returned."""
        # Arrange
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"

        # Remote returns empty recommended_apps
        empty_response = {"recommended_apps": [], "categories": []}

        # Builtin fallback response
        builtin_response = factory.create_recommended_apps_response(
            recommended_apps=[{"id": "builtin-1", "name": "Builtin App", "category": "default"}]
        )

        # Mock remote retrieval instance (returns empty)
        mock_remote_instance = MagicMock()
        mock_remote_instance.get_recommended_apps_and_categories.return_value = empty_response

        mock_remote_factory = MagicMock()
        mock_remote_factory.return_value = mock_remote_instance
        mock_factory_class.get_recommend_app_factory.return_value = mock_remote_factory

        # Mock builtin retrieval instance
        mock_builtin_instance = MagicMock()
        mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response
        mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance

        # Act
        result = RecommendedAppService.get_recommended_apps_and_categories("zh-CN")

        # Assert
        assert result == builtin_response
        assert len(result["recommended_apps"]) == 1
        assert result["recommended_apps"][0]["id"] == "builtin-1"
        # Verify fallback was called with en-US (hardcoded)
        mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommended_apps_fallback_when_none_recommended_apps(self, mock_config, mock_factory_class, factory):
        """Test fallback when recommended_apps key is None."""
        # Arrange
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "db"

        # Response with None recommended_apps
        none_response = {"recommended_apps": None, "categories": ["test"]}

        # Builtin fallback response
        builtin_response = factory.create_recommended_apps_response()

        # Mock db retrieval instance (returns None)
        mock_db_instance = MagicMock()
        mock_db_instance.get_recommended_apps_and_categories.return_value = none_response

        mock_db_factory = MagicMock()
        mock_db_factory.return_value = mock_db_instance
        mock_factory_class.get_recommend_app_factory.return_value = mock_db_factory

        # Mock builtin retrieval instance
        mock_builtin_instance = MagicMock()
        mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response
        mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance

        # Act
        result = RecommendedAppService.get_recommended_apps_and_categories("en-US")

        # Assert
        assert result == builtin_response
        mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once()

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommended_apps_with_different_languages(self, mock_config, mock_factory_class, factory):
        """Test retrieval with different language codes."""
        # Arrange
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin"

        languages = ["en-US", "zh-CN", "ja-JP", "fr-FR"]

        for language in languages:
            # Create language-specific response
            lang_response = factory.create_recommended_apps_response(
                recommended_apps=[{"id": f"app-{language}", "name": f"App {language}", "category": "test"}]
            )

            # Mock retrieval instance
            mock_instance = MagicMock()
            mock_instance.get_recommended_apps_and_categories.return_value = lang_response

            mock_factory = MagicMock()
            mock_factory.return_value = mock_instance
            mock_factory_class.get_recommend_app_factory.return_value = mock_factory

            # Act
            result = RecommendedAppService.get_recommended_apps_and_categories(language)

            # Assert
            assert result["recommended_apps"][0]["id"] == f"app-{language}"
            mock_instance.get_recommended_apps_and_categories.assert_called_with(language)

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommended_apps_uses_correct_factory_mode(self, mock_config, mock_factory_class, factory):
        """Test that correct factory is selected based on mode."""
        # Arrange
        modes = ["remote", "builtin", "db"]

        for mode in modes:
            mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode

            response = factory.create_recommended_apps_response()

            # Mock retrieval instance
            mock_instance = MagicMock()
            mock_instance.get_recommended_apps_and_categories.return_value = response

            mock_factory = MagicMock()
            mock_factory.return_value = mock_instance
            mock_factory_class.get_recommend_app_factory.return_value = mock_factory

            # Act
            RecommendedAppService.get_recommended_apps_and_categories("en-US")

            # Assert
            mock_factory_class.get_recommend_app_factory.assert_called_with(mode)


class TestRecommendedAppServiceGetDetail:
    """Test get_recommend_app_detail operations."""

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommend_app_detail_success(self, mock_config, mock_factory_class, factory):
        """Test successful retrieval of app detail."""
        # Arrange
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
        app_id = "app-123"

        expected_detail = factory.create_app_detail_response(
            app_id=app_id,
            name="Productivity App",
            description="A great productivity app",
            category="productivity",
        )

        # Mock retrieval instance
        mock_instance = MagicMock()
        mock_instance.get_recommend_app_detail.return_value = expected_detail

        mock_factory = MagicMock()
        mock_factory.return_value = mock_instance
        mock_factory_class.get_recommend_app_factory.return_value = mock_factory

        # Act
        result = RecommendedAppService.get_recommend_app_detail(app_id)

        # Assert
        assert result == expected_detail
        assert result["id"] == app_id
        assert result["name"] == "Productivity App"
        mock_instance.get_recommend_app_detail.assert_called_once_with(app_id)

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommend_app_detail_with_different_modes(self, mock_config, mock_factory_class, factory):
        """Test app detail retrieval with different factory modes."""
        # Arrange
        modes = ["remote", "builtin", "db"]
        app_id = "test-app"

        for mode in modes:
            mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode

            detail = factory.create_app_detail_response(app_id=app_id, name=f"App from {mode}")

            # Mock retrieval instance
            mock_instance = MagicMock()
            mock_instance.get_recommend_app_detail.return_value = detail

            mock_factory = MagicMock()
            mock_factory.return_value = mock_instance
            mock_factory_class.get_recommend_app_factory.return_value = mock_factory

            # Act
            result = RecommendedAppService.get_recommend_app_detail(app_id)

            # Assert
            assert result["name"] == f"App from {mode}"
            mock_factory_class.get_recommend_app_factory.assert_called_with(mode)

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommend_app_detail_returns_none_when_not_found(self, mock_config, mock_factory_class, factory):
        """Test that None is returned when app is not found."""
        # Arrange
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
        app_id = "nonexistent-app"

        # Mock retrieval instance returning None
        mock_instance = MagicMock()
        mock_instance.get_recommend_app_detail.return_value = None

        mock_factory = MagicMock()
        mock_factory.return_value = mock_instance
        mock_factory_class.get_recommend_app_factory.return_value = mock_factory

        # Act
        result = RecommendedAppService.get_recommend_app_detail(app_id)

        # Assert
        assert result is None
        mock_instance.get_recommend_app_detail.assert_called_once_with(app_id)

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommend_app_detail_returns_empty_dict(self, mock_config, mock_factory_class, factory):
        """Test handling of empty dict response."""
        # Arrange
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin"
        app_id = "app-empty"

        # Mock retrieval instance returning empty dict
        mock_instance = MagicMock()
        mock_instance.get_recommend_app_detail.return_value = {}

        mock_factory = MagicMock()
        mock_factory.return_value = mock_instance
        mock_factory_class.get_recommend_app_factory.return_value = mock_factory

        # Act
        result = RecommendedAppService.get_recommend_app_detail(app_id)

        # Assert
        assert result == {}

    @patch("services.recommended_app_service.RecommendAppRetrievalFactory")
    @patch("services.recommended_app_service.dify_config")
    def test_get_recommend_app_detail_with_complex_model_config(self, mock_config, mock_factory_class, factory):
        """Test app detail with complex model configuration."""
        # Arrange
        mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
        app_id = "complex-app"

        complex_model_config = {
            "provider": "openai",
            "model": "gpt-4",
            "parameters": {
                "temperature": 0.7,
                "max_tokens": 2000,
                "top_p": 1.0,
            },
        }

        expected_detail = factory.create_app_detail_response(
            app_id=app_id,
            name="Complex App",
            model_config=complex_model_config,
            workflows=["workflow-1", "workflow-2"],
            tools=["tool-1", "tool-2", "tool-3"],
        )

        # Mock retrieval instance
        mock_instance = MagicMock()
        mock_instance.get_recommend_app_detail.return_value = expected_detail

        mock_factory = MagicMock()
        mock_factory.return_value = mock_instance
        mock_factory_class.get_recommend_app_factory.return_value = mock_factory

        # Act
        result = RecommendedAppService.get_recommend_app_detail(app_id)

        # Assert
        assert result["model_config"] == complex_model_config
        assert len(result["workflows"]) == 2
        assert len(result["tools"]) == 3
@ -1,626 +0,0 @@
"""
Comprehensive unit tests for SavedMessageService.

This test suite provides complete coverage of saved message operations in Dify,
following TDD principles with the Arrange-Act-Assert pattern.

## Test Coverage

### 1. Pagination (TestSavedMessageServicePagination)
Tests saved message listing and pagination:
- Pagination with valid user (Account and EndUser)
- Pagination without user raises ValueError
- Pagination with last_id parameter
- Empty results when no saved messages exist
- Integration with MessageService pagination

### 2. Save Operations (TestSavedMessageServiceSave)
Tests saving messages:
- Save message for Account user
- Save message for EndUser
- Save without user (no-op)
- Prevent duplicate saves (idempotent)
- Message validation through MessageService

### 3. Delete Operations (TestSavedMessageServiceDelete)
Tests deleting saved messages:
- Delete saved message for Account user
- Delete saved message for EndUser
- Delete without user (no-op)
- Delete non-existent saved message (no-op)
- Proper database cleanup

## Testing Approach

- **Mocking Strategy**: All external dependencies (database, MessageService) are mocked
  for fast, isolated unit tests
- **Factory Pattern**: SavedMessageServiceTestDataFactory provides consistent test data
- **Fixtures**: Mock objects are configured per test method
- **Assertions**: Each test verifies return values and side effects
  (database operations, method calls)

## Key Concepts

**User Types:**
- Account: Workspace members (console users)
- EndUser: API users (end users)

**Saved Messages:**
- Users can save messages for later reference
- Each user has their own saved message list
- Saving is idempotent (duplicate saves ignored)
- Deletion is safe (non-existent deletes ignored)
"""

from datetime import UTC, datetime
from unittest.mock import MagicMock, Mock, create_autospec, patch

import pytest

from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
from models.model import App, EndUser, Message
from models.web import SavedMessage
from services.saved_message_service import SavedMessageService
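

# A minimal sketch, assuming a dict-backed store, of the idempotent save semantics
# described in the docstring; the real service persists SavedMessage rows instead.
def _idempotent_save(store: dict[str, set[str]], user_id: str, message_id: str) -> None:
    store.setdefault(user_id, set()).add(message_id)  # a duplicate save is a no-op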


class SavedMessageServiceTestDataFactory:
    """
    Factory for creating test data and mock objects.

    Provides reusable methods to create consistent mock objects for testing
    saved message operations.
    """

    @staticmethod
    def create_account_mock(account_id: str = "account-123", **kwargs) -> Mock:
        """
        Create a mock Account object.

        Args:
            account_id: Unique identifier for the account
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock Account object with specified attributes
        """
        account = create_autospec(Account, instance=True)
        account.id = account_id
        for key, value in kwargs.items():
            setattr(account, key, value)
        return account

    @staticmethod
    def create_end_user_mock(user_id: str = "user-123", **kwargs) -> Mock:
        """
        Create a mock EndUser object.

        Args:
            user_id: Unique identifier for the end user
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock EndUser object with specified attributes
        """
        user = create_autospec(EndUser, instance=True)
        user.id = user_id
        for key, value in kwargs.items():
            setattr(user, key, value)
        return user

    @staticmethod
    def create_app_mock(app_id: str = "app-123", tenant_id: str = "tenant-123", **kwargs) -> Mock:
        """
        Create a mock App object.

        Args:
            app_id: Unique identifier for the app
            tenant_id: Tenant/workspace identifier
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock App object with specified attributes
        """
        app = create_autospec(App, instance=True)
        app.id = app_id
        app.tenant_id = tenant_id
        app.name = kwargs.get("name", "Test App")
        app.mode = kwargs.get("mode", "chat")
        for key, value in kwargs.items():
            setattr(app, key, value)
        return app

    @staticmethod
    def create_message_mock(
        message_id: str = "msg-123",
        app_id: str = "app-123",
        **kwargs,
    ) -> Mock:
        """
        Create a mock Message object.

        Args:
            message_id: Unique identifier for the message
            app_id: Associated app identifier
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock Message object with specified attributes
        """
        message = create_autospec(Message, instance=True)
        message.id = message_id
        message.app_id = app_id
        message.query = kwargs.get("query", "Test query")
        message.answer = kwargs.get("answer", "Test answer")
        message.created_at = kwargs.get("created_at", datetime.now(UTC))
        for key, value in kwargs.items():
            setattr(message, key, value)
        return message

    @staticmethod
    def create_saved_message_mock(
        saved_message_id: str = "saved-123",
        app_id: str = "app-123",
        message_id: str = "msg-123",
        created_by: str = "user-123",
        created_by_role: str = "account",
        **kwargs,
    ) -> Mock:
        """
        Create a mock SavedMessage object.

        Args:
            saved_message_id: Unique identifier for the saved message
            app_id: Associated app identifier
            message_id: Associated message identifier
            created_by: User who saved the message
            created_by_role: Role of the user ('account' or 'end_user')
            **kwargs: Additional attributes to set on the mock

        Returns:
            Mock SavedMessage object with specified attributes
        """
        saved_message = create_autospec(SavedMessage, instance=True)
        saved_message.id = saved_message_id
        saved_message.app_id = app_id
        saved_message.message_id = message_id
        saved_message.created_by = created_by
        saved_message.created_by_role = created_by_role
        saved_message.created_at = kwargs.get("created_at", datetime.now(UTC))
        for key, value in kwargs.items():
            setattr(saved_message, key, value)
        return saved_message


@pytest.fixture
def factory():
    """Provide the test data factory to all tests."""
    return SavedMessageServiceTestDataFactory


class TestSavedMessageServicePagination:
    """Test saved message pagination operations."""

    @patch("services.saved_message_service.MessageService.pagination_by_last_id")
    @patch("services.saved_message_service.db.session")
    def test_pagination_with_account_user(self, mock_db_session, mock_message_pagination, factory):
        """Test pagination with an Account user."""
        # Arrange
        app = factory.create_app_mock()
        user = factory.create_account_mock()

        # Create saved messages for this user
        saved_messages = [
            factory.create_saved_message_mock(
                saved_message_id=f"saved-{i}",
                app_id=app.id,
                message_id=f"msg-{i}",
                created_by=user.id,
                created_by_role="account",
            )
            for i in range(3)
        ]

        # Mock database query
        mock_query = MagicMock()
        mock_db_session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.all.return_value = saved_messages

        # Mock MessageService pagination response
        expected_pagination = InfiniteScrollPagination(data=[], limit=20, has_more=False)
        mock_message_pagination.return_value = expected_pagination

        # Act
        result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=20)

        # Assert
        assert result == expected_pagination
        mock_db_session.query.assert_called_once_with(SavedMessage)
        # Verify MessageService was called with correct message IDs
        mock_message_pagination.assert_called_once_with(
            app_model=app,
            user=user,
            last_id=None,
            limit=20,
            include_ids=["msg-0", "msg-1", "msg-2"],
        )
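
    # Editorial note: assigning mock_query.where.return_value = mock_query (as above)
    # makes the mock chainable, so query(...).where(...).order_by(...).all() resolves
    # to the stubbed list without a real database session.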

    @patch("services.saved_message_service.MessageService.pagination_by_last_id")
    @patch("services.saved_message_service.db.session")
    def test_pagination_with_end_user(self, mock_db_session, mock_message_pagination, factory):
        """Test pagination with an EndUser."""
        # Arrange
        app = factory.create_app_mock()
        user = factory.create_end_user_mock()

        # Create saved messages for this end user
        saved_messages = [
            factory.create_saved_message_mock(
                saved_message_id=f"saved-{i}",
                app_id=app.id,
                message_id=f"msg-{i}",
                created_by=user.id,
                created_by_role="end_user",
            )
            for i in range(2)
        ]

        # Mock database query
        mock_query = MagicMock()
        mock_db_session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.all.return_value = saved_messages

        # Mock MessageService pagination response
        expected_pagination = InfiniteScrollPagination(data=[], limit=10, has_more=False)
        mock_message_pagination.return_value = expected_pagination

        # Act
        result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=10)

        # Assert
        assert result == expected_pagination
        # Verify correct role was used in query
        mock_message_pagination.assert_called_once_with(
            app_model=app,
            user=user,
            last_id=None,
            limit=10,
            include_ids=["msg-0", "msg-1"],
        )

    def test_pagination_without_user_raises_error(self, factory):
        """Test that pagination without user raises ValueError."""
        # Arrange
        app = factory.create_app_mock()

        # Act & Assert
        with pytest.raises(ValueError, match="User is required"):
            SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=20)

    @patch("services.saved_message_service.MessageService.pagination_by_last_id")
    @patch("services.saved_message_service.db.session")
    def test_pagination_with_last_id(self, mock_db_session, mock_message_pagination, factory):
        """Test pagination with last_id parameter."""
        # Arrange
        app = factory.create_app_mock()
        user = factory.create_account_mock()
        last_id = "msg-last"

        saved_messages = [
            factory.create_saved_message_mock(
                message_id=f"msg-{i}",
                app_id=app.id,
                created_by=user.id,
            )
            for i in range(5)
        ]

        # Mock database query
        mock_query = MagicMock()
        mock_db_session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.all.return_value = saved_messages

        # Mock MessageService pagination response
        expected_pagination = InfiniteScrollPagination(data=[], limit=10, has_more=True)
        mock_message_pagination.return_value = expected_pagination

        # Act
        result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=last_id, limit=10)

        # Assert
        assert result == expected_pagination
        # Verify last_id was passed to MessageService
        mock_message_pagination.assert_called_once()
        call_args = mock_message_pagination.call_args
        assert call_args.kwargs["last_id"] == last_id

    @patch("services.saved_message_service.MessageService.pagination_by_last_id")
    @patch("services.saved_message_service.db.session")
    def test_pagination_with_empty_saved_messages(self, mock_db_session, mock_message_pagination, factory):
        """Test pagination when user has no saved messages."""
        # Arrange
        app = factory.create_app_mock()
        user = factory.create_account_mock()

        # Mock database query returning empty list
        mock_query = MagicMock()
        mock_db_session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.all.return_value = []

        # Mock MessageService pagination response
        expected_pagination = InfiniteScrollPagination(data=[], limit=20, has_more=False)
        mock_message_pagination.return_value = expected_pagination

        # Act
        result = SavedMessageService.pagination_by_last_id(app_model=app, user=user, last_id=None, limit=20)

        # Assert
        assert result == expected_pagination
        # Verify MessageService was called with empty include_ids
        mock_message_pagination.assert_called_once_with(
            app_model=app,
            user=user,
            last_id=None,
            limit=20,
            include_ids=[],
        )


class TestSavedMessageServiceSave:
    """Test save message operations."""

    @patch("services.saved_message_service.MessageService.get_message")
    @patch("services.saved_message_service.db.session")
|
||||
def test_save_message_for_account(self, mock_db_session, mock_get_message, factory):
|
||||
"""Test saving a message for an Account user."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_account_mock()
|
||||
message = factory.create_message_mock(message_id="msg-123", app_id=app.id)
|
||||
|
||||
# Mock database query - no existing saved message
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Mock MessageService.get_message
|
||||
mock_get_message.return_value = message
|
||||
|
||||
# Act
|
||||
SavedMessageService.save(app_model=app, user=user, message_id=message.id)
|
||||
|
||||
# Assert
|
||||
mock_db_session.add.assert_called_once()
|
||||
saved_message = mock_db_session.add.call_args[0][0]
|
||||
assert saved_message.app_id == app.id
|
||||
assert saved_message.message_id == message.id
|
||||
assert saved_message.created_by == user.id
|
||||
assert saved_message.created_by_role == "account"
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.saved_message_service.MessageService.get_message")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_save_message_for_end_user(self, mock_db_session, mock_get_message, factory):
|
||||
"""Test saving a message for an EndUser."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
message = factory.create_message_mock(message_id="msg-456", app_id=app.id)
|
||||
|
||||
# Mock database query - no existing saved message
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Mock MessageService.get_message
|
||||
mock_get_message.return_value = message
|
||||
|
||||
# Act
|
||||
SavedMessageService.save(app_model=app, user=user, message_id=message.id)
|
||||
|
||||
# Assert
|
||||
mock_db_session.add.assert_called_once()
|
||||
saved_message = mock_db_session.add.call_args[0][0]
|
||||
assert saved_message.app_id == app.id
|
||||
assert saved_message.message_id == message.id
|
||||
assert saved_message.created_by == user.id
|
||||
assert saved_message.created_by_role == "end_user"
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_save_without_user_does_nothing(self, mock_db_session, factory):
|
||||
"""Test that saving without user is a no-op."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
|
||||
# Act
|
||||
SavedMessageService.save(app_model=app, user=None, message_id="msg-123")
|
||||
|
||||
# Assert
|
||||
mock_db_session.query.assert_not_called()
|
||||
mock_db_session.add.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
@patch("services.saved_message_service.MessageService.get_message")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_save_duplicate_message_is_idempotent(self, mock_db_session, mock_get_message, factory):
|
||||
"""Test that saving an already saved message is idempotent."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_account_mock()
|
||||
message_id = "msg-789"
|
||||
|
||||
# Mock database query - existing saved message found
|
||||
existing_saved = factory.create_saved_message_mock(
|
||||
app_id=app.id,
|
||||
message_id=message_id,
|
||||
created_by=user.id,
|
||||
created_by_role="account",
|
||||
)
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = existing_saved
|
||||
|
||||
# Act
|
||||
SavedMessageService.save(app_model=app, user=user, message_id=message_id)
|
||||
|
||||
# Assert - no new saved message created
|
||||
mock_db_session.add.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
mock_get_message.assert_not_called()
|
||||
|
||||
@patch("services.saved_message_service.MessageService.get_message")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_save_validates_message_exists(self, mock_db_session, mock_get_message, factory):
|
||||
"""Test that save validates message exists through MessageService."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_account_mock()
|
||||
message = factory.create_message_mock()
|
||||
|
||||
# Mock database query - no existing saved message
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Mock MessageService.get_message
|
||||
mock_get_message.return_value = message
|
||||
|
||||
# Act
|
||||
SavedMessageService.save(app_model=app, user=user, message_id=message.id)
|
||||
|
||||
# Assert - MessageService.get_message was called for validation
|
||||
mock_get_message.assert_called_once_with(app_model=app, user=user, message_id=message.id)
|
||||
|
||||
|
||||
class TestSavedMessageServiceDelete:
|
||||
"""Test delete saved message operations."""
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_delete_saved_message_for_account(self, mock_db_session, factory):
|
||||
"""Test deleting a saved message for an Account user."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_account_mock()
|
||||
message_id = "msg-123"
|
||||
|
||||
# Mock database query - existing saved message found
|
||||
saved_message = factory.create_saved_message_mock(
|
||||
app_id=app.id,
|
||||
message_id=message_id,
|
||||
created_by=user.id,
|
||||
created_by_role="account",
|
||||
)
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = saved_message
|
||||
|
||||
# Act
|
||||
SavedMessageService.delete(app_model=app, user=user, message_id=message_id)
|
||||
|
||||
# Assert
|
||||
mock_db_session.delete.assert_called_once_with(saved_message)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_delete_saved_message_for_end_user(self, mock_db_session, factory):
|
||||
"""Test deleting a saved message for an EndUser."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_end_user_mock()
|
||||
message_id = "msg-456"
|
||||
|
||||
# Mock database query - existing saved message found
|
||||
saved_message = factory.create_saved_message_mock(
|
||||
app_id=app.id,
|
||||
message_id=message_id,
|
||||
created_by=user.id,
|
||||
created_by_role="end_user",
|
||||
)
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = saved_message
|
||||
|
||||
# Act
|
||||
SavedMessageService.delete(app_model=app, user=user, message_id=message_id)
|
||||
|
||||
# Assert
|
||||
mock_db_session.delete.assert_called_once_with(saved_message)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_delete_without_user_does_nothing(self, mock_db_session, factory):
|
||||
"""Test that deleting without user is a no-op."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
|
||||
# Act
|
||||
SavedMessageService.delete(app_model=app, user=None, message_id="msg-123")
|
||||
|
||||
# Assert
|
||||
mock_db_session.query.assert_not_called()
|
||||
mock_db_session.delete.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_delete_non_existent_saved_message_does_nothing(self, mock_db_session, factory):
|
||||
"""Test that deleting a non-existent saved message is a no-op."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user = factory.create_account_mock()
|
||||
message_id = "msg-nonexistent"
|
||||
|
||||
# Mock database query - no saved message found
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Act
|
||||
SavedMessageService.delete(app_model=app, user=user, message_id=message_id)
|
||||
|
||||
# Assert - no deletion occurred
|
||||
mock_db_session.delete.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
def test_delete_only_affects_user_own_saved_messages(self, mock_db_session, factory):
|
||||
"""Test that delete only removes the user's own saved message."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
user1 = factory.create_account_mock(account_id="user-1")
|
||||
message_id = "msg-shared"
|
||||
|
||||
# Mock database query - finds user1's saved message
|
||||
saved_message = factory.create_saved_message_mock(
|
||||
app_id=app.id,
|
||||
message_id=message_id,
|
||||
created_by=user1.id,
|
||||
created_by_role="account",
|
||||
)
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = saved_message
|
||||
|
||||
# Act
|
||||
SavedMessageService.delete(app_model=app, user=user1, message_id=message_id)
|
||||
|
||||
# Assert - only user1's saved message is deleted
|
||||
mock_db_session.delete.assert_called_once_with(saved_message)
|
||||
# Verify the query filters by user
|
||||
assert mock_query.where.called
|
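Taken together, the tests above pin down the SavedMessageService contract: pagination requires a user and forwards that user's saved message IDs as include_ids, save is a silent no-op without a user and idempotent for duplicates, and delete only touches the caller's own row. Below is a self-contained, in-memory sketch of that contract. It is illustrative only: the class name, method shapes, and storage layout are stand-ins inferred from the mocks, not the actual Dify implementation.

# Illustrative, self-contained sketch of the contract the tests above assert.
# Everything here is a stand-in; it is NOT the real SavedMessageService.
from dataclasses import dataclass, field


@dataclass
class InMemorySavedMessages:
    # (app_id, created_by_role, created_by) -> saved message ids, oldest first
    _rows: dict = field(default_factory=dict)

    def saved_message_ids(self, app_id: str, role: str, user_id: str | None) -> list:
        """Pagination path: requires a user, returns the ids to pass as include_ids."""
        if user_id is None:
            raise ValueError("User is required")
        return list(self._rows.get((app_id, role, user_id), []))

    def save(self, app_id: str, role: str, user_id: str | None, message_id: str) -> None:
        """No-op without a user; idempotent when the message is already saved."""
        if user_id is None:
            return
        bucket = self._rows.setdefault((app_id, role, user_id), [])
        if message_id not in bucket:  # duplicate saves change nothing
            bucket.append(message_id)

    def delete(self, app_id: str, role: str, user_id: str | None, message_id: str) -> None:
        """No-op without a user or when nothing matches; rows are keyed per owner."""
        if user_id is None:
            return
        bucket = self._rows.get((app_id, role, user_id), [])
        if message_id in bucket:
            bucket.remove(message_id)


svc = InMemorySavedMessages()
svc.save("app-1", "account", "u1", "msg-0")
svc.save("app-1", "account", "u1", "msg-0")  # idempotent
svc.save("app-1", "account", None, "msg-1")  # silently ignored
assert svc.saved_message_ids("app-1", "account", "u1") == ["msg-0"]
svc.delete("app-1", "account", "u2", "msg-0")  # another user: no effect
assert svc.saved_message_ids("app-1", "account", "u1") == ["msg-0"]

Run as a script, the asserts pass and exercise the same idempotency and ownership rules the mocked tests check.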
File diff suppressed because it is too large
@ -1,9 +1,9 @@
from unittest.mock import Mock

from core.tools.__base.tool import Tool
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.api_entities import ToolApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
from core.tools.entities.tool_entities import ToolParameter
from services.tools.tools_transform_service import ToolTransformService


@ -299,154 +299,3 @@ class TestToolTransformService:
        param2 = result.parameters[1]
        assert param2.name == "param2"
        assert param2.label == "Runtime Param 2"


class TestWorkflowProviderToUserProvider:
    """Test cases for ToolTransformService.workflow_provider_to_user_provider method"""

    def test_workflow_provider_to_user_provider_with_workflow_app_id(self):
        """Test that workflow_provider_to_user_provider correctly sets workflow_app_id."""
        from core.tools.workflow_as_tool.provider import WorkflowToolProviderController

        # Create mock workflow tool provider controller
        workflow_app_id = "app_123"
        provider_id = "provider_123"
        mock_controller = Mock(spec=WorkflowToolProviderController)
        mock_controller.provider_id = provider_id
        mock_controller.entity = Mock()
        mock_controller.entity.identity = Mock()
        mock_controller.entity.identity.author = "test_author"
        mock_controller.entity.identity.name = "test_workflow_tool"
        mock_controller.entity.identity.description = I18nObject(en_US="Test description")
        mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
        mock_controller.entity.identity.icon_dark = None
        mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")

        # Call the method
        result = ToolTransformService.workflow_provider_to_user_provider(
            provider_controller=mock_controller,
            labels=["label1", "label2"],
            workflow_app_id=workflow_app_id,
        )

        # Verify the result
        assert isinstance(result, ToolProviderApiEntity)
        assert result.id == provider_id
        assert result.author == "test_author"
        assert result.name == "test_workflow_tool"
        assert result.type == ToolProviderType.WORKFLOW
        assert result.workflow_app_id == workflow_app_id
        assert result.labels == ["label1", "label2"]
        assert result.is_team_authorization is True
        assert result.plugin_id is None
        assert result.plugin_unique_identifier is None
        assert result.tools == []

    def test_workflow_provider_to_user_provider_without_workflow_app_id(self):
        """Test that workflow_provider_to_user_provider works when workflow_app_id is not provided."""
        from core.tools.workflow_as_tool.provider import WorkflowToolProviderController

        # Create mock workflow tool provider controller
        provider_id = "provider_123"
        mock_controller = Mock(spec=WorkflowToolProviderController)
        mock_controller.provider_id = provider_id
        mock_controller.entity = Mock()
        mock_controller.entity.identity = Mock()
        mock_controller.entity.identity.author = "test_author"
        mock_controller.entity.identity.name = "test_workflow_tool"
        mock_controller.entity.identity.description = I18nObject(en_US="Test description")
        mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
        mock_controller.entity.identity.icon_dark = None
        mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")

        # Call the method without workflow_app_id
        result = ToolTransformService.workflow_provider_to_user_provider(
            provider_controller=mock_controller,
            labels=["label1"],
        )

        # Verify the result
        assert isinstance(result, ToolProviderApiEntity)
        assert result.id == provider_id
        assert result.workflow_app_id is None
        assert result.labels == ["label1"]

    def test_workflow_provider_to_user_provider_workflow_app_id_none(self):
        """Test that workflow_provider_to_user_provider handles None workflow_app_id explicitly."""
        from core.tools.workflow_as_tool.provider import WorkflowToolProviderController

        # Create mock workflow tool provider controller
        provider_id = "provider_123"
        mock_controller = Mock(spec=WorkflowToolProviderController)
        mock_controller.provider_id = provider_id
        mock_controller.entity = Mock()
        mock_controller.entity.identity = Mock()
        mock_controller.entity.identity.author = "test_author"
        mock_controller.entity.identity.name = "test_workflow_tool"
        mock_controller.entity.identity.description = I18nObject(en_US="Test description")
        mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
        mock_controller.entity.identity.icon_dark = None
        mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")

        # Call the method with explicit None values
        result = ToolTransformService.workflow_provider_to_user_provider(
            provider_controller=mock_controller,
            labels=None,
            workflow_app_id=None,
        )

        # Verify the result
        assert isinstance(result, ToolProviderApiEntity)
        assert result.id == provider_id
        assert result.workflow_app_id is None
        assert result.labels == []

    def test_workflow_provider_to_user_provider_preserves_other_fields(self):
        """Test that workflow_provider_to_user_provider preserves all other entity fields."""
        from core.tools.workflow_as_tool.provider import WorkflowToolProviderController

        # Create mock workflow tool provider controller with various fields
        workflow_app_id = "app_456"
        provider_id = "provider_456"
        mock_controller = Mock(spec=WorkflowToolProviderController)
        mock_controller.provider_id = provider_id
        mock_controller.entity = Mock()
        mock_controller.entity.identity = Mock()
        mock_controller.entity.identity.author = "another_author"
        mock_controller.entity.identity.name = "another_workflow_tool"
        mock_controller.entity.identity.description = I18nObject(
            en_US="Another description", zh_Hans="Another description"
        )
        mock_controller.entity.identity.icon = {"type": "emoji", "content": "⚙️"}
        mock_controller.entity.identity.icon_dark = {"type": "emoji", "content": "🔧"}
        mock_controller.entity.identity.label = I18nObject(
            en_US="Another Workflow Tool", zh_Hans="Another Workflow Tool"
        )

        # Call the method
        result = ToolTransformService.workflow_provider_to_user_provider(
            provider_controller=mock_controller,
            labels=["automation", "workflow"],
            workflow_app_id=workflow_app_id,
        )

        # Verify all fields are preserved correctly
        assert isinstance(result, ToolProviderApiEntity)
        assert result.id == provider_id
        assert result.author == "another_author"
        assert result.name == "another_workflow_tool"
        assert result.description.en_US == "Another description"
        assert result.description.zh_Hans == "Another description"
        assert result.icon == {"type": "emoji", "content": "⚙️"}
        assert result.icon_dark == {"type": "emoji", "content": "🔧"}
        assert result.label.en_US == "Another Workflow Tool"
        assert result.label.zh_Hans == "Another Workflow Tool"
        assert result.type == ToolProviderType.WORKFLOW
        assert result.workflow_app_id == workflow_app_id
        assert result.labels == ["automation", "workflow"]
        assert result.masked_credentials == {}
        assert result.is_team_authorization is True
        assert result.allow_delete is True
        assert result.plugin_id is None
        assert result.plugin_unique_identifier is None
        assert result.tools == []
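The removed tests fully describe the mapping they used to cover. As a rough sketch of that transform, under the assumption that only the fields the assertions mention matter, the dataclass below stands in for ToolProviderApiEntity; none of this is the real ToolTransformService code.

# Stand-in sketch of the mapping the removed tests assert; not the real
# ToolTransformService or ToolProviderApiEntity.
from dataclasses import dataclass, field


@dataclass
class ToolProviderApiEntitySketch:
    id: str
    author: str
    name: str
    type: str = "workflow"
    workflow_app_id: str | None = None
    labels: list = field(default_factory=list)
    is_team_authorization: bool = True
    allow_delete: bool = True
    masked_credentials: dict = field(default_factory=dict)
    plugin_id: str | None = None
    plugin_unique_identifier: str | None = None
    tools: list = field(default_factory=list)


def workflow_provider_to_user_provider_sketch(provider_controller, labels=None, workflow_app_id=None):
    """Identity fields copy over verbatim; labels default to []; workflow_app_id
    passes through unchanged (None stays None)."""
    identity = provider_controller.entity.identity
    return ToolProviderApiEntitySketch(
        id=provider_controller.provider_id,
        author=identity.author,
        name=identity.name,
        workflow_app_id=workflow_app_id,
        labels=list(labels) if labels else [],
    )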
File diff suppressed because it is too large
File diff suppressed because it is too large

api/uv.lock (generated, 4630 lines): file diff suppressed because it is too large
@ -233,7 +233,7 @@ NEXT_PUBLIC_ENABLE_SINGLE_DOLLAR_LATEX=false

# Database type, supported values are `postgresql` and `mysql`
DB_TYPE=postgresql
# For MySQL, only `root` user is supported for now

DB_USERNAME=postgres
DB_PASSWORD=difyai123456
DB_HOST=db_postgres
@ -1076,10 +1076,24 @@ MAX_TREE_DEPTH=50
# ------------------------------
# Environment Variables for database Service
# ------------------------------

# The name of the default postgres user.
POSTGRES_USER=${DB_USERNAME}
# The password for the default postgres user.
POSTGRES_PASSWORD=${DB_PASSWORD}
# The name of the default postgres database.
POSTGRES_DB=${DB_DATABASE}
# Postgres data directory
PGDATA=/var/lib/postgresql/data/pgdata

# MySQL Default Configuration
# The name of the default mysql user.
MYSQL_USERNAME=${DB_USERNAME}
# The password for the default mysql user.
MYSQL_PASSWORD=${DB_PASSWORD}
# The name of the default mysql database.
MYSQL_DATABASE=${DB_DATABASE}
# MySQL data directory
MYSQL_HOST_VOLUME=./volumes/mysql/data

# ------------------------------
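These additions are thin aliases: each POSTGRES_* and MYSQL_* value defaults to the matching DB_* value, so deployments that only set DB_* keep working while the database containers become independently configurable. Docker Compose resolves ${VAR:-default} at parse time, and the x-shared-env hunk further down nests the two substitution forms, as in ${POSTGRES_USER:-${DB_USERNAME}}. A small, purely illustrative Python model of that resolution order follows (it covers only the two forms used in this diff, not the full Compose spec).

import re

# Sketch of docker-compose's variable substitution for the two forms used in
# this diff: ${VAR} and ${VAR:-default}, with nesting as in
# ${POSTGRES_USER:-${DB_USERNAME}}. Illustrative only.
_INNER = re.compile(r"\$\{(\w+)(?::-([^${}]*))?\}")

def resolve(expr: str, env: dict) -> str:
    # Repeatedly substitute the innermost placeholder until none remain.
    while True:
        m = _INNER.search(expr)
        if m is None:
            return expr
        name, default = m.group(1), m.group(2) or ""
        expr = expr[: m.start()] + env.get(name, default) + expr[m.end():]

# POSTGRES_USER unset: the alias falls back to DB_USERNAME.
print(resolve("${POSTGRES_USER:-${DB_USERNAME}}", {"DB_USERNAME": "postgres"}))  # postgres
# An explicit override wins.
print(resolve("${POSTGRES_USER:-${DB_USERNAME}}", {"POSTGRES_USER": "admin", "DB_USERNAME": "postgres"}))  # admin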
@ -139,9 +139,9 @@ services:
      - postgresql
    restart: always
    environment:
      POSTGRES_USER: ${DB_USERNAME:-postgres}
      POSTGRES_PASSWORD: ${DB_PASSWORD:-difyai123456}
      POSTGRES_DB: ${DB_DATABASE:-dify}
      POSTGRES_USER: ${POSTGRES_USER:-postgres}
      POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456}
      POSTGRES_DB: ${POSTGRES_DB:-dify}
      PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata}
    command: >
      postgres -c 'max_connections=${POSTGRES_MAX_CONNECTIONS:-100}'
@ -161,7 +161,7 @@ services:
        "-h",
        "db_postgres",
        "-U",
        "${DB_USERNAME:-postgres}",
        "${PGUSER:-postgres}",
        "-d",
        "${DB_DATABASE:-dify}",
      ]
@ -176,8 +176,8 @@ services:
      - mysql
    restart: always
    environment:
      MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-difyai123456}
      MYSQL_DATABASE: ${DB_DATABASE:-dify}
      MYSQL_ROOT_PASSWORD: ${MYSQL_PASSWORD:-difyai123456}
      MYSQL_DATABASE: ${MYSQL_DATABASE:-dify}
    command: >
      --max_connections=1000
      --innodb_buffer_pool_size=${MYSQL_INNODB_BUFFER_POOL_SIZE:-512M}
@ -193,7 +193,7 @@ services:
        "ping",
        "-u",
        "root",
        "-p${DB_PASSWORD:-difyai123456}",
        "-p${MYSQL_PASSWORD:-difyai123456}",
      ]
    interval: 1s
    timeout: 3s
@ -9,8 +9,8 @@ services:
    env_file:
      - ./middleware.env
    environment:
      POSTGRES_PASSWORD: ${DB_PASSWORD:-difyai123456}
      POSTGRES_DB: ${DB_DATABASE:-dify}
      POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456}
      POSTGRES_DB: ${POSTGRES_DB:-dify}
      PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata}
    command: >
      postgres -c 'max_connections=${POSTGRES_MAX_CONNECTIONS:-100}'
@ -32,9 +32,9 @@ services:
        "-h",
        "db_postgres",
        "-U",
        "${DB_USERNAME:-postgres}",
        "${PGUSER:-postgres}",
        "-d",
        "${DB_DATABASE:-dify}",
        "${POSTGRES_DB:-dify}",
      ]
    interval: 1s
    timeout: 3s
@ -48,8 +48,8 @@ services:
    env_file:
      - ./middleware.env
    environment:
      MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-difyai123456}
      MYSQL_DATABASE: ${DB_DATABASE:-dify}
      MYSQL_ROOT_PASSWORD: ${MYSQL_PASSWORD:-difyai123456}
      MYSQL_DATABASE: ${MYSQL_DATABASE:-dify}
    command: >
      --max_connections=1000
      --innodb_buffer_pool_size=${MYSQL_INNODB_BUFFER_POOL_SIZE:-512M}
@ -67,7 +67,7 @@ services:
        "ping",
        "-u",
        "root",
        "-p${DB_PASSWORD:-difyai123456}",
        "-p${MYSQL_PASSWORD:-difyai123456}",
      ]
    interval: 1s
    timeout: 3s
@ -455,7 +455,13 @@ x-shared-env: &shared-api-worker-env
  TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000}
  ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false}
  MAX_TREE_DEPTH: ${MAX_TREE_DEPTH:-50}
  POSTGRES_USER: ${POSTGRES_USER:-${DB_USERNAME}}
  POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-${DB_PASSWORD}}
  POSTGRES_DB: ${POSTGRES_DB:-${DB_DATABASE}}
  PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata}
  MYSQL_USERNAME: ${MYSQL_USERNAME:-${DB_USERNAME}}
  MYSQL_PASSWORD: ${MYSQL_PASSWORD:-${DB_PASSWORD}}
  MYSQL_DATABASE: ${MYSQL_DATABASE:-${DB_DATABASE}}
  MYSQL_HOST_VOLUME: ${MYSQL_HOST_VOLUME:-./volumes/mysql/data}
  SANDBOX_API_KEY: ${SANDBOX_API_KEY:-dify-sandbox}
  SANDBOX_GIN_MODE: ${SANDBOX_GIN_MODE:-release}
@ -768,9 +774,9 @@ services:
      - postgresql
    restart: always
    environment:
      POSTGRES_USER: ${DB_USERNAME:-postgres}
      POSTGRES_PASSWORD: ${DB_PASSWORD:-difyai123456}
      POSTGRES_DB: ${DB_DATABASE:-dify}
      POSTGRES_USER: ${POSTGRES_USER:-postgres}
      POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456}
      POSTGRES_DB: ${POSTGRES_DB:-dify}
      PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata}
    command: >
      postgres -c 'max_connections=${POSTGRES_MAX_CONNECTIONS:-100}'
@ -790,7 +796,7 @@ services:
        "-h",
        "db_postgres",
        "-U",
        "${DB_USERNAME:-postgres}",
        "${PGUSER:-postgres}",
        "-d",
        "${DB_DATABASE:-dify}",
      ]
@ -805,8 +811,8 @@ services:
      - mysql
    restart: always
    environment:
      MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-difyai123456}
      MYSQL_DATABASE: ${DB_DATABASE:-dify}
      MYSQL_ROOT_PASSWORD: ${MYSQL_PASSWORD:-difyai123456}
      MYSQL_DATABASE: ${MYSQL_DATABASE:-dify}
    command: >
      --max_connections=1000
      --innodb_buffer_pool_size=${MYSQL_INNODB_BUFFER_POOL_SIZE:-512M}
@ -822,7 +828,7 @@ services:
        "ping",
        "-u",
        "root",
        "-p${DB_PASSWORD:-difyai123456}",
        "-p${MYSQL_PASSWORD:-difyai123456}",
      ]
    interval: 1s
    timeout: 3s
@ -4,7 +4,6 @@
# Database Configuration
# Database type, supported values are `postgresql` and `mysql`
DB_TYPE=postgresql
# For MySQL, only `root` user is supported for now
DB_USERNAME=postgres
DB_PASSWORD=difyai123456
DB_HOST=db_postgres
@ -12,6 +11,11 @@ DB_PORT=5432
DB_DATABASE=dify

# PostgreSQL Configuration
POSTGRES_USER=${DB_USERNAME}
# The password for the default postgres user.
POSTGRES_PASSWORD=${DB_PASSWORD}
# The name of the default postgres database.
POSTGRES_DB=${DB_DATABASE}
# postgres data directory
PGDATA=/var/lib/postgresql/data/pgdata
PGDATA_HOST_VOLUME=./volumes/db/data
@ -61,6 +65,11 @@ POSTGRES_STATEMENT_TIMEOUT=0
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=0

# MySQL Configuration
MYSQL_USERNAME=${DB_USERNAME}
# MySQL password
MYSQL_PASSWORD=${DB_PASSWORD}
# MySQL database name
MYSQL_DATABASE=${DB_DATABASE}
# MySQL data directory host volume
MYSQL_HOST_VOLUME=./volumes/mysql/data

@ -3,7 +3,6 @@ import type { ReactNode } from 'react'
import SwrInitializer from '@/app/components/swr-initializer'
import { AppContextProvider } from '@/context/app-context'
import GA, { GaType } from '@/app/components/base/ga'
import AmplitudeProvider from '@/app/components/base/amplitude'
import HeaderWrapper from '@/app/components/header/header-wrapper'
import Header from '@/app/components/header'
import { EventEmitterContextProvider } from '@/context/event-emitter'
@ -19,7 +18,6 @@ const Layout = ({ children }: { children: ReactNode }) => {
  return (
    <>
      <GA gaType={GaType.admin} />
      <AmplitudeProvider />
      <SwrInitializer>
        <AppContextProvider>
          <EventEmitterContextProvider>

@ -8,7 +8,7 @@ const PluginList = async () => {
  return (
    <PluginPage
      plugins={<PluginsPanel />}
      marketplace={<Marketplace locale={locale} pluginTypeSwitchClassName='top-[60px]' showSearchParams={false} />}
      marketplace={<Marketplace locale={locale} pluginTypeSwitchClassName='top-[60px]' searchBoxAutoAnimate={false} showSearchParams={false} />}
    />
  )
}

@ -1,5 +1,6 @@
'use client'
import { useState } from 'react'
import useSWR from 'swr'
import { useTranslation } from 'react-i18next'
import {
  RiGraduationCapFill,
@ -22,9 +23,8 @@ import PremiumBadge from '@/app/components/base/premium-badge'
import { useGlobalPublicStore } from '@/context/global-public-context'
import EmailChangeModal from './email-change-modal'
import { validPassword } from '@/config'

import { fetchAppList } from '@/service/apps'
import type { App } from '@/types/app'
import { useAppList } from '@/service/use-apps'

const titleClassName = `
  system-sm-semibold text-text-secondary
@ -36,7 +36,7 @@ const descriptionClassName = `
export default function AccountPage() {
  const { t } = useTranslation()
  const { systemFeatures } = useGlobalPublicStore()
  const { data: appList } = useAppList({ page: 1, limit: 100, name: '' })
  const { data: appList } = useSWR({ url: '/apps', params: { page: 1, limit: 100, name: '' } }, fetchAppList)
  const apps = appList?.data || []
  const { mutateUserProfile, userProfile } = useAppContext()
  const { isEducationAccount } = useProviderContext()

@ -12,7 +12,6 @@ import { useProviderContext } from '@/context/provider-context'
import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general'
import PremiumBadge from '@/app/components/base/premium-badge'
import { useLogout } from '@/service/use-common'
import { resetUser } from '@/app/components/base/amplitude/utils'

export type IAppSelector = {
  isMobile: boolean
@ -29,7 +28,6 @@ export default function AppSelector() {
    await logout()

    localStorage.removeItem('setup_status')
    resetUser()
    // Tokens are now stored in cookies and cleared by backend

    router.push('/signin')

@ -4,7 +4,6 @@ import Header from './header'
import SwrInitor from '@/app/components/swr-initializer'
import { AppContextProvider } from '@/context/app-context'
import GA, { GaType } from '@/app/components/base/ga'
import AmplitudeProvider from '@/app/components/base/amplitude'
import HeaderWrapper from '@/app/components/header/header-wrapper'
import { EventEmitterContextProvider } from '@/context/event-emitter'
import { ProviderContextProvider } from '@/context/provider-context'
@ -14,7 +13,6 @@ const Layout = ({ children }: { children: ReactNode }) => {
  return (
    <>
      <GA gaType={GaType.admin} />
      <AmplitudeProvider />
      <SwrInitor>
        <AppContextProvider>
          <EventEmitterContextProvider>

@ -1,6 +1,6 @@
'use client'
import type { FC, PropsWithChildren } from 'react'
import useAccessControlStore from '@/context/access-control-store'
import useAccessControlStore from '../../../../context/access-control-store'
import type { AccessMode } from '@/models/access-control'

type AccessControlItemProps = PropsWithChildren<{
@ -8,8 +8,7 @@ type AccessControlItemProps = PropsWithChildren<{
}>

const AccessControlItem: FC<AccessControlItemProps> = ({ type, children }) => {
  const currentMenu = useAccessControlStore(s => s.currentMenu)
  const setCurrentMenu = useAccessControlStore(s => s.setCurrentMenu)
  const { currentMenu, setCurrentMenu } = useAccessControlStore(s => ({ currentMenu: s.currentMenu, setCurrentMenu: s.setCurrentMenu }))
  if (currentMenu !== type) {
    return <div
      className="cursor-pointer rounded-[10px] border-[1px]

@ -251,7 +251,6 @@ const AgentTools: FC = () => {
          {!item.notAuthor && (
            <Tooltip
              popupContent={t('tools.setBuiltInTools.infoAndSetting')}
              needsDelay={false}
            >
              <div className='cursor-pointer rounded-md p-1 hover:bg-black/5' onClick={() => {
                setCurrentTool(item)

@ -28,7 +28,6 @@ import Input from '@/app/components/base/input'
import { AppModeEnum } from '@/types/app'
import { DSLImportMode } from '@/models/app'
import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks'
import { trackEvent } from '@/app/components/base/amplitude'

type AppsProps = {
  onSuccess?: () => void
@ -142,15 +141,6 @@ const Apps = ({
      icon_background,
      description,
    })

    // Track app creation from template
    trackEvent('create_app_with_template', {
      app_mode: mode,
      template_id: currApp?.app.id,
      template_name: currApp?.app.name,
      description,
    })

    setIsShowCreateModal(false)
    Toast.notify({
      type: 'success',

@ -30,7 +30,6 @@ import { getRedirection } from '@/utils/app-redirection'
import FullScreenModal from '@/app/components/base/fullscreen-modal'
import useTheme from '@/hooks/use-theme'
import { useDocLink } from '@/context/i18n'
import { trackEvent } from '@/app/components/base/amplitude'

type CreateAppProps = {
  onSuccess: () => void
@ -83,13 +82,6 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }:
      icon_background: appIcon.type === 'emoji' ? appIcon.background : undefined,
      mode: appMode,
    })

    // Track app creation success
    trackEvent('create_app', {
      app_mode: appMode,
      description,
    })

    notify({ type: 'success', message: t('app.newApp.appCreated') })
    onSuccess()
    onClose()

@ -28,7 +28,6 @@ import { getRedirection } from '@/utils/app-redirection'
import cn from '@/utils/classnames'
import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks'
import { noop } from 'lodash-es'
import { trackEvent } from '@/app/components/base/amplitude'

type CreateFromDSLModalProps = {
  show: boolean
@ -113,13 +112,6 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
      return
    const { id, status, app_id, app_mode, imported_dsl_version, current_dsl_version } = response
    if (status === DSLImportStatus.COMPLETED || status === DSLImportStatus.COMPLETED_WITH_WARNINGS) {
      // Track app creation from DSL import
      trackEvent('create_app_with_dsl', {
        app_mode,
        creation_method: currentTab === CreateFromDSLModalTab.FROM_FILE ? 'dsl_file' : 'dsl_url',
        has_warnings: status === DSLImportStatus.COMPLETED_WITH_WARNINGS,
      })

      if (onSuccess)
        onSuccess()
      if (onClose)

@ -3,6 +3,7 @@ import type { FC } from 'react'
import React from 'react'
import ReactECharts from 'echarts-for-react'
import type { EChartsOption } from 'echarts'
import useSWR from 'swr'
import type { Dayjs } from 'dayjs'
import dayjs from 'dayjs'
import { get } from 'lodash-es'
@ -12,20 +13,7 @@ import { formatNumber } from '@/utils/format'
import Basic from '@/app/components/app-sidebar/basic'
import Loading from '@/app/components/base/loading'
import type { AppDailyConversationsResponse, AppDailyEndUsersResponse, AppDailyMessagesResponse, AppTokenCostsResponse } from '@/models/app'
import {
  useAppAverageResponseTime,
  useAppAverageSessionInteractions,
  useAppDailyConversations,
  useAppDailyEndUsers,
  useAppDailyMessages,
  useAppSatisfactionRate,
  useAppTokenCosts,
  useAppTokensPerSecond,
  useWorkflowAverageInteractions,
  useWorkflowDailyConversations,
  useWorkflowDailyTerminals,
  useWorkflowTokenCosts,
} from '@/service/use-apps'
import { getAppDailyConversations, getAppDailyEndUsers, getAppDailyMessages, getAppStatistics, getAppTokenCosts, getWorkflowDailyConversations } from '@/service/apps'
const valueFormatter = (v: string | number) => v

const COLOR_TYPE_MAP = {
@ -284,8 +272,8 @@ const getDefaultChartData = ({ start, end, key = 'count' }: { start: string; end

export const MessagesChart: FC<IBizChartProps> = ({ id, period }) => {
  const { t } = useTranslation()
  const { data: response, isLoading } = useAppDailyMessages(id, period.query)
  if (isLoading || !response)
  const { data: response } = useSWR({ url: `/apps/${id}/statistics/daily-messages`, params: period.query }, getAppDailyMessages)
  if (!response)
    return <Loading />
  const noDataFlag = !response.data || response.data.length === 0
  return <Chart
@ -298,8 +286,8 @@ export const MessagesChart: FC<IBizChartProps> = ({ id, period }) => {

export const ConversationsChart: FC<IBizChartProps> = ({ id, period }) => {
  const { t } = useTranslation()
  const { data: response, isLoading } = useAppDailyConversations(id, period.query)
  if (isLoading || !response)
  const { data: response } = useSWR({ url: `/apps/${id}/statistics/daily-conversations`, params: period.query }, getAppDailyConversations)
  if (!response)
    return <Loading />
  const noDataFlag = !response.data || response.data.length === 0
  return <Chart
@ -313,8 +301,8 @@ export const ConversationsChart: FC<IBizChartProps> = ({ id, period }) => {
export const EndUsersChart: FC<IBizChartProps> = ({ id, period }) => {
  const { t } = useTranslation()

  const { data: response, isLoading } = useAppDailyEndUsers(id, period.query)
  if (isLoading || !response)
  const { data: response } = useSWR({ url: `/apps/${id}/statistics/daily-end-users`, id, params: period.query }, getAppDailyEndUsers)
  if (!response)
    return <Loading />
  const noDataFlag = !response.data || response.data.length === 0
  return <Chart
@ -327,8 +315,8 @@ export const EndUsersChart: FC<IBizChartProps> = ({ id, period }) => {

export const AvgSessionInteractions: FC<IBizChartProps> = ({ id, period }) => {
  const { t } = useTranslation()
  const { data: response, isLoading } = useAppAverageSessionInteractions(id, period.query)
  if (isLoading || !response)
  const { data: response } = useSWR({ url: `/apps/${id}/statistics/average-session-interactions`, params: period.query }, getAppStatistics)
  if (!response)
    return <Loading />
  const noDataFlag = !response.data || response.data.length === 0
  return <Chart
@ -343,8 +331,8 @@ export const AvgSessionInteractions: FC<IBizChartProps> = ({ id, period }) => {

export const AvgResponseTime: FC<IBizChartProps> = ({ id, period }) => {
  const { t } = useTranslation()
  const { data: response, isLoading } = useAppAverageResponseTime(id, period.query)
  if (isLoading || !response)
  const { data: response } = useSWR({ url: `/apps/${id}/statistics/average-response-time`, params: period.query }, getAppStatistics)
  if (!response)
    return <Loading />
  const noDataFlag = !response.data || response.data.length === 0
  return <Chart
@ -360,8 +348,8 @@ export const AvgResponseTime: FC<IBizChartProps> = ({ id, period }) => {

export const TokenPerSecond: FC<IBizChartProps> = ({ id, period }) => {
  const { t } = useTranslation()
  const { data: response, isLoading } = useAppTokensPerSecond(id, period.query)
  if (isLoading || !response)
  const { data: response } = useSWR({ url: `/apps/${id}/statistics/tokens-per-second`, params: period.query }, getAppStatistics)
  if (!response)
    return <Loading />
  const noDataFlag = !response.data || response.data.length === 0
  return <Chart
@ -378,8 +366,8 @@ export const TokenPerSecond: FC<IBizChartProps> = ({ id, period }) => {

export const UserSatisfactionRate: FC<IBizChartProps> = ({ id, period }) => {
  const { t } = useTranslation()
  const { data: response, isLoading } = useAppSatisfactionRate(id, period.query)
  if (isLoading || !response)
  const { data: response } = useSWR({ url: `/apps/${id}/statistics/user-satisfaction-rate`, params: period.query }, getAppStatistics)
  if (!response)
    return <Loading />
  const noDataFlag = !response.data || response.data.length === 0
  return <Chart
@ -396,8 +384,8 @@ export const UserSatisfactionRate: FC<IBizChartProps> = ({ id, period }) => {
export const CostChart: FC<IBizChartProps> = ({ id, period }) => {
  const { t } = useTranslation()

  const { data: response, isLoading } = useAppTokenCosts(id, period.query)
  if (isLoading || !response)
  const { data: response } = useSWR({ url: `/apps/${id}/statistics/token-costs`, params: period.query }, getAppTokenCosts)
  if (!response)
    return <Loading />
  const noDataFlag = !response.data || response.data.length === 0
  return <Chart
@ -410,8 +398,8 @@ export const CostChart: FC<IBizChartProps> = ({ id, period }) => {

export const WorkflowMessagesChart: FC<IBizChartProps> = ({ id, period }) => {
  const { t } = useTranslation()
  const { data: response, isLoading } = useWorkflowDailyConversations(id, period.query)
  if (isLoading || !response)
  const { data: response } = useSWR({ url: `/apps/${id}/workflow/statistics/daily-conversations`, params: period.query }, getWorkflowDailyConversations)
  if (!response)
    return <Loading />
  const noDataFlag = !response.data || response.data.length === 0
  return <Chart
@ -426,8 +414,8 @@ export const WorkflowMessagesChart: FC<IBizChartProps> = ({ id, period }) => {
export const WorkflowDailyTerminalsChart: FC<IBizChartProps> = ({ id, period }) => {
  const { t } = useTranslation()

  const { data: response, isLoading } = useWorkflowDailyTerminals(id, period.query)
  if (isLoading || !response)
  const { data: response } = useSWR({ url: `/apps/${id}/workflow/statistics/daily-terminals`, id, params: period.query }, getAppDailyEndUsers)
  if (!response)
    return <Loading />
  const noDataFlag = !response.data || response.data.length === 0
  return <Chart
@ -441,8 +429,8 @@ export const WorkflowDailyTerminalsChart: FC<IBizChartProps> = ({ id, period })
export const WorkflowCostChart: FC<IBizChartProps> = ({ id, period }) => {
  const { t } = useTranslation()

  const { data: response, isLoading } = useWorkflowTokenCosts(id, period.query)
  if (isLoading || !response)
  const { data: response } = useSWR({ url: `/apps/${id}/workflow/statistics/token-costs`, params: period.query }, getAppTokenCosts)
  if (!response)
    return <Loading />
  const noDataFlag = !response.data || response.data.length === 0
  return <Chart
@ -455,8 +443,8 @@ export const WorkflowCostChart: FC<IBizChartProps> = ({ id, period }) => {

export const AvgUserInteractions: FC<IBizChartProps> = ({ id, period }) => {
  const { t } = useTranslation()
  const { data: response, isLoading } = useWorkflowAverageInteractions(id, period.query)
  if (isLoading || !response)
  const { data: response } = useSWR({ url: `/apps/${id}/workflow/statistics/average-app-interactions`, params: period.query }, getAppStatistics)
  if (!response)
    return <Loading />
  const noDataFlag = !response.data || response.data.length === 0
  return <Chart

@ -8,7 +8,6 @@ import quarterOfYear from 'dayjs/plugin/quarterOfYear'
import type { QueryParam } from './index'
import Chip from '@/app/components/base/chip'
import Input from '@/app/components/base/input'
import { trackEvent } from '@/app/components/base/amplitude/utils'
dayjs.extend(quarterOfYear)

const today = dayjs()
@ -38,9 +37,6 @@ const Filter: FC<IFilterProps> = ({ queryParams, setQueryParams }: IFilterProps)
      value={queryParams.status || 'all'}
      onSelect={(item) => {
        setQueryParams({ ...queryParams, status: item.value as string })
        trackEvent('workflow_log_filter_status_selected', {
          workflow_log_filter_status: item.value as string,
        })
      }}
      onClear={() => setQueryParams({ ...queryParams, status: 'all' })}
      items={[{ value: 'all', name: 'All' },

@ -23,7 +23,7 @@ const Empty = () => {
  return (
    <>
      <DefaultCards />
      <div className='pointer-events-none absolute inset-0 z-20 flex items-center justify-center bg-gradient-to-t from-background-body to-transparent'>
      <div className='absolute bottom-0 left-0 right-0 top-0 flex items-center justify-center bg-gradient-to-t from-background-body to-transparent'>
        <span className='system-md-medium text-text-tertiary'>
          {t('app.newApp.noAppsFound')}
        </span>

@ -4,6 +4,7 @@ import { useCallback, useEffect, useRef, useState } from 'react'
import {
  useRouter,
} from 'next/navigation'
import useSWRInfinite from 'swr/infinite'
import { useTranslation } from 'react-i18next'
import { useDebounceFn } from 'ahooks'
import {
@ -18,6 +19,8 @@ import AppCard from './app-card'
import NewAppCard from './new-app-card'
import useAppsQueryState from './hooks/use-apps-query-state'
import { useDSLDragDrop } from './hooks/use-dsl-drag-drop'
import type { AppListResponse } from '@/models/app'
import { fetchAppList } from '@/service/apps'
import { useAppContext } from '@/context/app-context'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { CheckModal } from '@/hooks/use-pay'
@ -32,7 +35,6 @@ import Empty from './empty'
import Footer from './footer'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { AppModeEnum } from '@/types/app'
import { useInfiniteAppList } from '@/service/use-apps'

const TagManagementModal = dynamic(() => import('@/app/components/base/tag-management'), {
  ssr: false,
@ -41,6 +43,30 @@ const CreateFromDSLModal = dynamic(() => import('@/app/components/app/create-fro
  ssr: false,
})

const getKey = (
  pageIndex: number,
  previousPageData: AppListResponse,
  activeTab: string,
  isCreatedByMe: boolean,
  tags: string[],
  keywords: string,
) => {
  if (!pageIndex || previousPageData.has_more) {
    const params: any = { url: 'apps', params: { page: pageIndex + 1, limit: 30, name: keywords, is_created_by_me: isCreatedByMe } }

    if (activeTab !== 'all')
      params.params.mode = activeTab
    else
      delete params.params.mode

    if (tags.length)
      params.params.tag_ids = tags

    return params
  }
  return null
}

const List = () => {
  const { t } = useTranslation()
  const { systemFeatures } = useGlobalPublicStore()
@ -76,24 +102,16 @@ const List = () => {
    enabled: isCurrentWorkspaceEditor,
  })

  const appListQueryParams = {
    page: 1,
    limit: 30,
    name: searchKeywords,
    tag_ids: tagIDs,
    is_created_by_me: isCreatedByMe,
    ...(activeTab !== 'all' ? { mode: activeTab as AppModeEnum } : {}),
  }

  const {
    data,
    isLoading,
    isFetchingNextPage,
    fetchNextPage,
    hasNextPage,
    error,
    refetch,
  } = useInfiniteAppList(appListQueryParams, { enabled: !isCurrentWorkspaceDatasetOperator })
  const { data, isLoading, error, setSize, mutate } = useSWRInfinite(
    (pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, activeTab, isCreatedByMe, tagIDs, searchKeywords),
    fetchAppList,
    {
      revalidateFirstPage: true,
      shouldRetryOnError: false,
      dedupingInterval: 500,
      errorRetryCount: 3,
    },
  )

  const anchorRef = useRef<HTMLDivElement>(null)
  const options = [
@ -108,9 +126,9 @@ const List = () => {
  useEffect(() => {
    if (localStorage.getItem(NEED_REFRESH_APP_LIST_KEY) === '1') {
      localStorage.removeItem(NEED_REFRESH_APP_LIST_KEY)
      refetch()
      mutate()
    }
  }, [refetch])
  }, [mutate, t])

  useEffect(() => {
    if (isCurrentWorkspaceDatasetOperator)
@ -118,9 +136,7 @@ const List = () => {
  }, [router, isCurrentWorkspaceDatasetOperator])

  useEffect(() => {
    if (isCurrentWorkspaceDatasetOperator)
      return
    const hasMore = hasNextPage ?? true
    const hasMore = data?.at(-1)?.has_more ?? true
    let observer: IntersectionObserver | undefined

    if (error) {
@ -135,8 +151,8 @@ const List = () => {
    const dynamicMargin = Math.max(100, Math.min(containerHeight * 0.2, 200)) // Clamps to 100-200px range, using 20% of container height as the base value

    observer = new IntersectionObserver((entries) => {
      if (entries[0].isIntersecting && !isLoading && !isFetchingNextPage && !error && hasMore)
        fetchNextPage()
      if (entries[0].isIntersecting && !isLoading && !error && hasMore)
        setSize((size: number) => size + 1)
    }, {
      root: containerRef.current,
      rootMargin: `${dynamicMargin}px`,
@ -145,7 +161,7 @@ const List = () => {
      observer.observe(anchorRef.current)
    }
    return () => observer?.disconnect()
  }, [isLoading, isFetchingNextPage, fetchNextPage, error, hasNextPage, isCurrentWorkspaceDatasetOperator])
  }, [isLoading, setSize, data, error])

  const { run: handleSearch } = useDebounceFn(() => {
    setSearchKeywords(keywords)
@ -169,9 +185,6 @@ const List = () => {
    setQuery(prev => ({ ...prev, isCreatedByMe: newValue }))
  }, [isCreatedByMe, setQuery])

  const pages = data?.pages ?? []
  const hasAnyApp = (pages[0]?.total ?? 0) > 0

  return (
    <>
      <div ref={containerRef} className='relative flex h-0 shrink-0 grow flex-col overflow-y-auto bg-background-body'>
@ -204,17 +217,17 @@ const List = () => {
          />
        </div>
      </div>
      {hasAnyApp
      {(data && data[0].total > 0)
        ? <div className='relative grid grow grid-cols-1 content-start gap-4 px-12 pt-2 sm:grid-cols-1 md:grid-cols-2 xl:grid-cols-4 2xl:grid-cols-5 2k:grid-cols-6'>
          {isCurrentWorkspaceEditor
            && <NewAppCard ref={newAppCardRef} onSuccess={refetch} selectedAppType={activeTab} />}
          {pages.map(({ data: apps }) => apps.map(app => (
            <AppCard key={app.id} app={app} onRefresh={refetch} />
            && <NewAppCard ref={newAppCardRef} onSuccess={mutate} selectedAppType={activeTab} />}
          {data.map(({ data: apps }) => apps.map(app => (
            <AppCard key={app.id} app={app} onRefresh={mutate} />
          )))}
        </div>
        : <div className='relative grid grow grid-cols-1 content-start gap-4 overflow-hidden px-12 pt-2 sm:grid-cols-1 md:grid-cols-2 xl:grid-cols-4 2xl:grid-cols-5 2k:grid-cols-6'>
          {isCurrentWorkspaceEditor
            && <NewAppCard ref={newAppCardRef} className='z-10' onSuccess={refetch} selectedAppType={activeTab} />}
            && <NewAppCard ref={newAppCardRef} className='z-10' onSuccess={mutate} selectedAppType={activeTab} />}
          <Empty />
        </div>}

@ -248,7 +261,7 @@ const List = () => {
        onSuccess={() => {
          setShowCreateFromDSLModal(false)
          setDroppedDSLFile(undefined)
          refetch()
          mutate()
        }}
        droppedFile={droppedDSLFile}
      />
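The getKey helper added in this hunk encodes the stop rule for useSWRInfinite: page 1 is always requested, and each further page is requested only while the previous page reported has_more. A purely illustrative Python rendering of that rule follows; the hypothetical get_page_params below mirrors getKey's logic, and its parameter names are made up.

# Illustrative Python rendering of the stop rule encoded by the getKey helper
# above; the real code is the TypeScript getKey passed to useSWRInfinite.
def get_page_params(page_index, previous_page, *, keywords="", mode=None, tags=()):
    """Return request params for page `page_index`, or None to stop fetching."""
    # Page 0 is always fetched; later pages only while the previous one had more.
    if page_index and previous_page is not None and not previous_page.get("has_more"):
        return None
    params = {"page": page_index + 1, "limit": 30, "name": keywords}
    if mode is not None:  # the 'all' tab sends no mode filter
        params["mode"] = mode
    if tags:
        params["tag_ids"] = list(tags)
    return params


assert get_page_params(0, None) == {"page": 1, "limit": 30, "name": ""}
assert get_page_params(1, {"has_more": True}) == {"page": 2, "limit": 30, "name": ""}
assert get_page_params(1, {"has_more": False}) is None  # SWR stops requesting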
@ -1,46 +0,0 @@
'use client'

import type { FC } from 'react'
import React, { useEffect } from 'react'
import * as amplitude from '@amplitude/analytics-browser'
import { sessionReplayPlugin } from '@amplitude/plugin-session-replay-browser'
import { IS_CLOUD_EDITION } from '@/config'

export type IAmplitudeProps = {
  apiKey?: string
  sessionReplaySampleRate?: number
}

const AmplitudeProvider: FC<IAmplitudeProps> = ({
  apiKey = process.env.NEXT_PUBLIC_AMPLITUDE_API_KEY ?? '',
  sessionReplaySampleRate = 1,
}) => {
  useEffect(() => {
    // Only enable in Saas edition
    if (!IS_CLOUD_EDITION)
      return

    // Initialize Amplitude
    amplitude.init(apiKey, {
      defaultTracking: {
        sessions: true,
        pageViews: true,
        formInteractions: true,
        fileDownloads: true,
      },
      // Enable debug logs in development environment
      logLevel: amplitude.Types.LogLevel.Warn,
    })

    // Add Session Replay plugin
    const sessionReplay = sessionReplayPlugin({
      sampleRate: sessionReplaySampleRate,
    })
    amplitude.add(sessionReplay)
  }, [])

  // This is a client component that renders nothing
  return null
}

export default React.memo(AmplitudeProvider)

@ -1,2 +0,0 @@
export { default } from './AmplitudeProvider'
export { resetUser, setUserId, setUserProperties, trackEvent } from './utils'

@ -1,37 +0,0 @@
import * as amplitude from '@amplitude/analytics-browser'

/**
 * Track custom event
 * @param eventName Event name
 * @param eventProperties Event properties (optional)
 */
export const trackEvent = (eventName: string, eventProperties?: Record<string, any>) => {
  amplitude.track(eventName, eventProperties)
}

/**
 * Set user ID
 * @param userId User ID
 */
export const setUserId = (userId: string) => {
  amplitude.setUserId(userId)
}

/**
 * Set user properties
 * @param properties User properties
 */
export const setUserProperties = (properties: Record<string, any>) => {
  const identifyEvent = new amplitude.Identify()
  Object.entries(properties).forEach(([key, value]) => {
    identifyEvent.set(key, value)
  })
  amplitude.identify(identifyEvent)
}

/**
 * Reset user (e.g., when user logs out)
 */
export const resetUser = () => {
  amplitude.reset()
}
Some files were not shown because too many files have changed in this diff.