mirror of
https://github.com/langgenius/dify.git
synced 2026-05-02 16:38:04 +08:00
Merge branch 'feat/model-total-credits' into deploy/dev
This commit is contained in:
@ -1,62 +1,59 @@
|
||||
from flask_restx import Api, Namespace, fields
|
||||
from __future__ import annotations
|
||||
|
||||
from libs.helper import AppIconUrlField
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
parameters__system_parameters = {
|
||||
"image_file_size_limit": fields.Integer,
|
||||
"video_file_size_limit": fields.Integer,
|
||||
"audio_file_size_limit": fields.Integer,
|
||||
"file_size_limit": fields.Integer,
|
||||
"workflow_file_upload_limit": fields.Integer,
|
||||
}
|
||||
from pydantic import BaseModel, ConfigDict, computed_field
|
||||
|
||||
from core.file import helpers as file_helpers
|
||||
from models.model import IconType
|
||||
|
||||
JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
|
||||
JSONObject: TypeAlias = dict[str, Any]
|
||||
|
||||
|
||||
def build_system_parameters_model(api_or_ns: Api | Namespace):
|
||||
"""Build the system parameters model for the API or Namespace."""
|
||||
return api_or_ns.model("SystemParameters", parameters__system_parameters)
|
||||
class SystemParameters(BaseModel):
|
||||
image_file_size_limit: int
|
||||
video_file_size_limit: int
|
||||
audio_file_size_limit: int
|
||||
file_size_limit: int
|
||||
workflow_file_upload_limit: int
|
||||
|
||||
|
||||
parameters_fields = {
|
||||
"opening_statement": fields.String,
|
||||
"suggested_questions": fields.Raw,
|
||||
"suggested_questions_after_answer": fields.Raw,
|
||||
"speech_to_text": fields.Raw,
|
||||
"text_to_speech": fields.Raw,
|
||||
"retriever_resource": fields.Raw,
|
||||
"annotation_reply": fields.Raw,
|
||||
"more_like_this": fields.Raw,
|
||||
"user_input_form": fields.Raw,
|
||||
"sensitive_word_avoidance": fields.Raw,
|
||||
"file_upload": fields.Raw,
|
||||
"system_parameters": fields.Nested(parameters__system_parameters),
|
||||
}
|
||||
class Parameters(BaseModel):
|
||||
opening_statement: str | None = None
|
||||
suggested_questions: list[str]
|
||||
suggested_questions_after_answer: JSONObject
|
||||
speech_to_text: JSONObject
|
||||
text_to_speech: JSONObject
|
||||
retriever_resource: JSONObject
|
||||
annotation_reply: JSONObject
|
||||
more_like_this: JSONObject
|
||||
user_input_form: list[JSONObject]
|
||||
sensitive_word_avoidance: JSONObject
|
||||
file_upload: JSONObject
|
||||
system_parameters: SystemParameters
|
||||
|
||||
|
||||
def build_parameters_model(api_or_ns: Api | Namespace):
|
||||
"""Build the parameters model for the API or Namespace."""
|
||||
copied_fields = parameters_fields.copy()
|
||||
copied_fields["system_parameters"] = fields.Nested(build_system_parameters_model(api_or_ns))
|
||||
return api_or_ns.model("Parameters", copied_fields)
|
||||
class Site(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
title: str
|
||||
chat_color_theme: str | None = None
|
||||
chat_color_theme_inverted: bool
|
||||
icon_type: str | None = None
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
description: str | None = None
|
||||
copyright: str | None = None
|
||||
privacy_policy: str | None = None
|
||||
custom_disclaimer: str | None = None
|
||||
default_language: str
|
||||
show_workflow_steps: bool
|
||||
use_icon_as_answer_icon: bool
|
||||
|
||||
site_fields = {
|
||||
"title": fields.String,
|
||||
"chat_color_theme": fields.String,
|
||||
"chat_color_theme_inverted": fields.Boolean,
|
||||
"icon_type": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"description": fields.String,
|
||||
"copyright": fields.String,
|
||||
"privacy_policy": fields.String,
|
||||
"custom_disclaimer": fields.String,
|
||||
"default_language": fields.String,
|
||||
"show_workflow_steps": fields.Boolean,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
}
|
||||
|
||||
|
||||
def build_site_model(api_or_ns: Api | Namespace):
|
||||
"""Build the site model for the API or Namespace."""
|
||||
return api_or_ns.model("Site", site_fields)
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
def icon_url(self) -> str | None:
|
||||
if self.icon and self.icon_type == IconType.IMAGE:
|
||||
return file_helpers.get_signed_file_url(self.icon)
|
||||
return None
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import re
|
||||
import uuid
|
||||
from typing import Literal
|
||||
|
||||
@ -73,6 +74,48 @@ class AppListQuery(BaseModel):
|
||||
raise ValueError("Invalid UUID format in tag_ids.") from exc
|
||||
|
||||
|
||||
# XSS prevention: patterns that could lead to XSS attacks
|
||||
# Includes: script tags, iframe tags, javascript: protocol, SVG with onload, etc.
|
||||
_XSS_PATTERNS = [
|
||||
r"<script[^>]*>.*?</script>", # Script tags
|
||||
r"<iframe\b[^>]*?(?:/>|>.*?</iframe>)", # Iframe tags (including self-closing)
|
||||
r"javascript:", # JavaScript protocol
|
||||
r"<svg[^>]*?\s+onload\s*=[^>]*>", # SVG with onload handler (attribute-aware, flexible whitespace)
|
||||
r"<.*?on\s*\w+\s*=", # Event handlers like onclick, onerror, etc.
|
||||
r"<object\b[^>]*(?:\s*/>|>.*?</object\s*>)", # Object tags (opening tag)
|
||||
r"<embed[^>]*>", # Embed tags (self-closing)
|
||||
r"<link[^>]*>", # Link tags with javascript
|
||||
]
|
||||
|
||||
|
||||
def _validate_xss_safe(value: str | None, field_name: str = "Field") -> str | None:
|
||||
"""
|
||||
Validate that a string value doesn't contain potential XSS payloads.
|
||||
|
||||
Args:
|
||||
value: The string value to validate
|
||||
field_name: Name of the field for error messages
|
||||
|
||||
Returns:
|
||||
The original value if safe
|
||||
|
||||
Raises:
|
||||
ValueError: If the value contains XSS patterns
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
value_lower = value.lower()
|
||||
for pattern in _XSS_PATTERNS:
|
||||
if re.search(pattern, value_lower, re.DOTALL | re.IGNORECASE):
|
||||
raise ValueError(
|
||||
f"{field_name} contains invalid characters or patterns. "
|
||||
"HTML tags, JavaScript, and other potentially dangerous content are not allowed."
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
@ -81,6 +124,11 @@ class CreateAppPayload(BaseModel):
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@field_validator("name", "description", mode="before")
|
||||
@classmethod
|
||||
def validate_xss_safe(cls, value: str | None, info) -> str | None:
|
||||
return _validate_xss_safe(value, info.field_name)
|
||||
|
||||
|
||||
class UpdateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
@ -91,6 +139,11 @@ class UpdateAppPayload(BaseModel):
|
||||
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
|
||||
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
|
||||
|
||||
@field_validator("name", "description", mode="before")
|
||||
@classmethod
|
||||
def validate_xss_safe(cls, value: str | None, info) -> str | None:
|
||||
return _validate_xss_safe(value, info.field_name)
|
||||
|
||||
|
||||
class CopyAppPayload(BaseModel):
|
||||
name: str | None = Field(default=None, description="Name for the copied app")
|
||||
@ -99,6 +152,11 @@ class CopyAppPayload(BaseModel):
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@field_validator("name", "description", mode="before")
|
||||
@classmethod
|
||||
def validate_xss_safe(cls, value: str | None, info) -> str | None:
|
||||
return _validate_xss_safe(value, info.field_name)
|
||||
|
||||
|
||||
class AppExportQuery(BaseModel):
|
||||
include_secret: bool = Field(default=False, description="Include secrets in export")
|
||||
|
||||
@ -3,10 +3,12 @@ import uuid
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import String, cast, func, or_, select
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
@ -143,7 +145,29 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
query = query.where(DocumentSegment.hit_count >= hit_count_gte)
|
||||
|
||||
if keyword:
|
||||
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
|
||||
# Search in both content and keywords fields
|
||||
# Use database-specific methods for JSON array search
|
||||
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
|
||||
# PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text
|
||||
keywords_condition = func.array_to_string(
|
||||
func.array(
|
||||
select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB)))
|
||||
.correlate(DocumentSegment)
|
||||
.scalar_subquery()
|
||||
),
|
||||
",",
|
||||
).ilike(f"%{keyword}%")
|
||||
else:
|
||||
# MySQL: Cast JSON to string for pattern matching
|
||||
# MySQL stores Chinese text directly in JSON without Unicode escaping
|
||||
keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{keyword}%")
|
||||
|
||||
query = query.where(
|
||||
or_(
|
||||
DocumentSegment.content.ilike(f"%{keyword}%"),
|
||||
keywords_condition,
|
||||
)
|
||||
)
|
||||
|
||||
if args.enabled.lower() != "all":
|
||||
if args.enabled.lower() == "true":
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from flask_restx import marshal_with
|
||||
|
||||
from controllers.common import fields
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import AppUnavailableError
|
||||
@ -13,7 +11,6 @@ from services.app_service import AppService
|
||||
class AppParameterApi(InstalledAppResource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
@marshal_with(fields.parameters_fields)
|
||||
def get(self, installed_app: InstalledApp):
|
||||
"""Retrieve app parameters."""
|
||||
app_model = installed_app.app
|
||||
@ -37,7 +34,8 @@ class AppParameterApi(InstalledAppResource):
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
return fields.Parameters.model_validate(parameters).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Api, Namespace, Resource, fields
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
from flask_restx.api import HTTPStatus
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -92,7 +92,7 @@ annotation_list_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_annotation_list_model(api_or_ns: Api | Namespace):
|
||||
def build_annotation_list_model(api_or_ns: Namespace):
|
||||
"""Build the annotation list model for the API or Namespace."""
|
||||
copied_annotation_list_fields = annotation_list_fields.copy()
|
||||
copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns)))
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.common.fields import build_parameters_model
|
||||
from controllers.common.fields import Parameters
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
@ -23,7 +23,6 @@ class AppParameterApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
@service_api_ns.marshal_with(build_parameters_model(service_api_ns))
|
||||
def get(self, app_model: App):
|
||||
"""Retrieve app parameters.
|
||||
|
||||
@ -45,7 +44,8 @@ class AppParameterApi(Resource):
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
return Parameters.model_validate(parameters).model_dump(mode="json")
|
||||
|
||||
|
||||
@service_api_ns.route("/meta")
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.fields import build_site_model
|
||||
from controllers.common.fields import Site as SiteResponse
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from extensions.ext_database import db
|
||||
@ -23,7 +23,6 @@ class AppSiteApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
@service_api_ns.marshal_with(build_site_model(service_api_ns))
|
||||
def get(self, app_model: App):
|
||||
"""Retrieve app site info.
|
||||
|
||||
@ -38,4 +37,4 @@ class AppSiteApi(Resource):
|
||||
if app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden()
|
||||
|
||||
return site
|
||||
return SiteResponse.model_validate(site).model_dump(mode="json")
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import Any, Literal
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
from flask_restx import Api, Namespace, Resource, fields
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
@ -78,7 +78,7 @@ workflow_run_fields = {
|
||||
}
|
||||
|
||||
|
||||
def build_workflow_run_model(api_or_ns: Api | Namespace):
|
||||
def build_workflow_run_model(api_or_ns: Namespace):
|
||||
"""Build the workflow run model for the API or Namespace."""
|
||||
return api_or_ns.model("WorkflowRun", workflow_run_fields)
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
@ -50,7 +50,6 @@ class AppParameterApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(fields.parameters_fields)
|
||||
def get(self, app_model: App, end_user):
|
||||
"""Retrieve app parameters."""
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
@ -69,7 +68,8 @@ class AppParameterApi(WebApiResource):
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
return fields.Parameters.model_validate(parameters).model_dump(mode="json")
|
||||
|
||||
|
||||
@web_ns.route("/meta")
|
||||
|
||||
Reference in New Issue
Block a user