mirror of
https://github.com/langgenius/dify.git
synced 2026-01-19 11:45:05 +08:00
refactor: migrate some ns.model to BaseModel (#30388)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@ -1,62 +1,59 @@
|
|||||||
from flask_restx import Api, Namespace, fields
|
from __future__ import annotations
|
||||||
|
|
||||||
from libs.helper import AppIconUrlField
|
from typing import Any, TypeAlias
|
||||||
|
|
||||||
parameters__system_parameters = {
|
from pydantic import BaseModel, ConfigDict, computed_field
|
||||||
"image_file_size_limit": fields.Integer,
|
|
||||||
"video_file_size_limit": fields.Integer,
|
from core.file import helpers as file_helpers
|
||||||
"audio_file_size_limit": fields.Integer,
|
from models.model import IconType
|
||||||
"file_size_limit": fields.Integer,
|
|
||||||
"workflow_file_upload_limit": fields.Integer,
|
JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
|
||||||
}
|
JSONObject: TypeAlias = dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
def build_system_parameters_model(api_or_ns: Api | Namespace):
|
class SystemParameters(BaseModel):
|
||||||
"""Build the system parameters model for the API or Namespace."""
|
image_file_size_limit: int
|
||||||
return api_or_ns.model("SystemParameters", parameters__system_parameters)
|
video_file_size_limit: int
|
||||||
|
audio_file_size_limit: int
|
||||||
|
file_size_limit: int
|
||||||
|
workflow_file_upload_limit: int
|
||||||
|
|
||||||
|
|
||||||
parameters_fields = {
|
class Parameters(BaseModel):
|
||||||
"opening_statement": fields.String,
|
opening_statement: str | None = None
|
||||||
"suggested_questions": fields.Raw,
|
suggested_questions: list[str]
|
||||||
"suggested_questions_after_answer": fields.Raw,
|
suggested_questions_after_answer: JSONObject
|
||||||
"speech_to_text": fields.Raw,
|
speech_to_text: JSONObject
|
||||||
"text_to_speech": fields.Raw,
|
text_to_speech: JSONObject
|
||||||
"retriever_resource": fields.Raw,
|
retriever_resource: JSONObject
|
||||||
"annotation_reply": fields.Raw,
|
annotation_reply: JSONObject
|
||||||
"more_like_this": fields.Raw,
|
more_like_this: JSONObject
|
||||||
"user_input_form": fields.Raw,
|
user_input_form: list[JSONObject]
|
||||||
"sensitive_word_avoidance": fields.Raw,
|
sensitive_word_avoidance: JSONObject
|
||||||
"file_upload": fields.Raw,
|
file_upload: JSONObject
|
||||||
"system_parameters": fields.Nested(parameters__system_parameters),
|
system_parameters: SystemParameters
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def build_parameters_model(api_or_ns: Api | Namespace):
|
class Site(BaseModel):
|
||||||
"""Build the parameters model for the API or Namespace."""
|
model_config = ConfigDict(from_attributes=True)
|
||||||
copied_fields = parameters_fields.copy()
|
|
||||||
copied_fields["system_parameters"] = fields.Nested(build_system_parameters_model(api_or_ns))
|
|
||||||
return api_or_ns.model("Parameters", copied_fields)
|
|
||||||
|
|
||||||
|
title: str
|
||||||
|
chat_color_theme: str | None = None
|
||||||
|
chat_color_theme_inverted: bool
|
||||||
|
icon_type: str | None = None
|
||||||
|
icon: str | None = None
|
||||||
|
icon_background: str | None = None
|
||||||
|
description: str | None = None
|
||||||
|
copyright: str | None = None
|
||||||
|
privacy_policy: str | None = None
|
||||||
|
custom_disclaimer: str | None = None
|
||||||
|
default_language: str
|
||||||
|
show_workflow_steps: bool
|
||||||
|
use_icon_as_answer_icon: bool
|
||||||
|
|
||||||
site_fields = {
|
@computed_field(return_type=str | None) # type: ignore
|
||||||
"title": fields.String,
|
@property
|
||||||
"chat_color_theme": fields.String,
|
def icon_url(self) -> str | None:
|
||||||
"chat_color_theme_inverted": fields.Boolean,
|
if self.icon and self.icon_type == IconType.IMAGE:
|
||||||
"icon_type": fields.String,
|
return file_helpers.get_signed_file_url(self.icon)
|
||||||
"icon": fields.String,
|
return None
|
||||||
"icon_background": fields.String,
|
|
||||||
"icon_url": AppIconUrlField,
|
|
||||||
"description": fields.String,
|
|
||||||
"copyright": fields.String,
|
|
||||||
"privacy_policy": fields.String,
|
|
||||||
"custom_disclaimer": fields.String,
|
|
||||||
"default_language": fields.String,
|
|
||||||
"show_workflow_steps": fields.Boolean,
|
|
||||||
"use_icon_as_answer_icon": fields.Boolean,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def build_site_model(api_or_ns: Api | Namespace):
|
|
||||||
"""Build the site model for the API or Namespace."""
|
|
||||||
return api_or_ns.model("Site", site_fields)
|
|
||||||
|
|||||||
@ -1,5 +1,3 @@
|
|||||||
from flask_restx import marshal_with
|
|
||||||
|
|
||||||
from controllers.common import fields
|
from controllers.common import fields
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.app.error import AppUnavailableError
|
from controllers.console.app.error import AppUnavailableError
|
||||||
@ -13,7 +11,6 @@ from services.app_service import AppService
|
|||||||
class AppParameterApi(InstalledAppResource):
|
class AppParameterApi(InstalledAppResource):
|
||||||
"""Resource for app variables."""
|
"""Resource for app variables."""
|
||||||
|
|
||||||
@marshal_with(fields.parameters_fields)
|
|
||||||
def get(self, installed_app: InstalledApp):
|
def get(self, installed_app: InstalledApp):
|
||||||
"""Retrieve app parameters."""
|
"""Retrieve app parameters."""
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
@ -37,7 +34,8 @@ class AppParameterApi(InstalledAppResource):
|
|||||||
|
|
||||||
user_input_form = features_dict.get("user_input_form", [])
|
user_input_form = features_dict.get("user_input_form", [])
|
||||||
|
|
||||||
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||||
|
return fields.Parameters.model_validate(parameters).model_dump(mode="json")
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
|
@console_ns.route("/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Api, Namespace, Resource, fields
|
from flask_restx import Namespace, Resource, fields
|
||||||
from flask_restx.api import HTTPStatus
|
from flask_restx.api import HTTPStatus
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@ -92,7 +92,7 @@ annotation_list_fields = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_annotation_list_model(api_or_ns: Api | Namespace):
|
def build_annotation_list_model(api_or_ns: Namespace):
|
||||||
"""Build the annotation list model for the API or Namespace."""
|
"""Build the annotation list model for the API or Namespace."""
|
||||||
copied_annotation_list_fields = annotation_list_fields.copy()
|
copied_annotation_list_fields = annotation_list_fields.copy()
|
||||||
copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns)))
|
copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns)))
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
|
|
||||||
from controllers.common.fields import build_parameters_model
|
from controllers.common.fields import Parameters
|
||||||
from controllers.service_api import service_api_ns
|
from controllers.service_api import service_api_ns
|
||||||
from controllers.service_api.app.error import AppUnavailableError
|
from controllers.service_api.app.error import AppUnavailableError
|
||||||
from controllers.service_api.wraps import validate_app_token
|
from controllers.service_api.wraps import validate_app_token
|
||||||
@ -23,7 +23,6 @@ class AppParameterApi(Resource):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
@validate_app_token
|
@validate_app_token
|
||||||
@service_api_ns.marshal_with(build_parameters_model(service_api_ns))
|
|
||||||
def get(self, app_model: App):
|
def get(self, app_model: App):
|
||||||
"""Retrieve app parameters.
|
"""Retrieve app parameters.
|
||||||
|
|
||||||
@ -45,7 +44,8 @@ class AppParameterApi(Resource):
|
|||||||
|
|
||||||
user_input_form = features_dict.get("user_input_form", [])
|
user_input_form = features_dict.get("user_input_form", [])
|
||||||
|
|
||||||
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||||
|
return Parameters.model_validate(parameters).model_dump(mode="json")
|
||||||
|
|
||||||
|
|
||||||
@service_api_ns.route("/meta")
|
@service_api_ns.route("/meta")
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from controllers.common.fields import build_site_model
|
from controllers.common.fields import Site as SiteResponse
|
||||||
from controllers.service_api import service_api_ns
|
from controllers.service_api import service_api_ns
|
||||||
from controllers.service_api.wraps import validate_app_token
|
from controllers.service_api.wraps import validate_app_token
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -23,7 +23,6 @@ class AppSiteApi(Resource):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
@validate_app_token
|
@validate_app_token
|
||||||
@service_api_ns.marshal_with(build_site_model(service_api_ns))
|
|
||||||
def get(self, app_model: App):
|
def get(self, app_model: App):
|
||||||
"""Retrieve app site info.
|
"""Retrieve app site info.
|
||||||
|
|
||||||
@ -38,4 +37,4 @@ class AppSiteApi(Resource):
|
|||||||
if app_model.tenant.status == TenantStatus.ARCHIVE:
|
if app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
return site
|
return SiteResponse.model_validate(site).model_dump(mode="json")
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from typing import Any, Literal
|
|||||||
|
|
||||||
from dateutil.parser import isoparse
|
from dateutil.parser import isoparse
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Api, Namespace, Resource, fields
|
from flask_restx import Namespace, Resource, fields
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||||
@ -78,7 +78,7 @@ workflow_run_fields = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_workflow_run_model(api_or_ns: Api | Namespace):
|
def build_workflow_run_model(api_or_ns: Namespace):
|
||||||
"""Build the workflow run model for the API or Namespace."""
|
"""Build the workflow run model for the API or Namespace."""
|
||||||
return api_or_ns.model("WorkflowRun", workflow_run_fields)
|
return api_or_ns.model("WorkflowRun", workflow_run_fields)
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, marshal_with
|
from flask_restx import Resource
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from werkzeug.exceptions import Unauthorized
|
from werkzeug.exceptions import Unauthorized
|
||||||
|
|
||||||
@ -50,7 +50,6 @@ class AppParameterApi(WebApiResource):
|
|||||||
500: "Internal Server Error",
|
500: "Internal Server Error",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@marshal_with(fields.parameters_fields)
|
|
||||||
def get(self, app_model: App, end_user):
|
def get(self, app_model: App, end_user):
|
||||||
"""Retrieve app parameters."""
|
"""Retrieve app parameters."""
|
||||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||||
@ -69,7 +68,8 @@ class AppParameterApi(WebApiResource):
|
|||||||
|
|
||||||
user_input_form = features_dict.get("user_input_form", [])
|
user_input_form = features_dict.get("user_input_form", [])
|
||||||
|
|
||||||
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||||
|
return fields.Parameters.model_validate(parameters).model_dump(mode="json")
|
||||||
|
|
||||||
|
|
||||||
@web_ns.route("/meta")
|
@web_ns.route("/meta")
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from flask_restx import Api, Namespace, fields
|
from flask_restx import Namespace, fields
|
||||||
|
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
|
|
||||||
@ -12,7 +12,7 @@ annotation_fields = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_annotation_model(api_or_ns: Api | Namespace):
|
def build_annotation_model(api_or_ns: Namespace):
|
||||||
"""Build the annotation model for the API or Namespace."""
|
"""Build the annotation model for the API or Namespace."""
|
||||||
return api_or_ns.model("Annotation", annotation_fields)
|
return api_or_ns.model("Annotation", annotation_fields)
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from flask_restx import Api, Namespace, fields
|
from flask_restx import Namespace, fields
|
||||||
|
|
||||||
from fields.member_fields import simple_account_fields
|
from fields.member_fields import simple_account_fields
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
@ -46,7 +46,7 @@ message_file_fields = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_message_file_model(api_or_ns: Api | Namespace):
|
def build_message_file_model(api_or_ns: Namespace):
|
||||||
"""Build the message file fields for the API or Namespace."""
|
"""Build the message file fields for the API or Namespace."""
|
||||||
return api_or_ns.model("MessageFile", message_file_fields)
|
return api_or_ns.model("MessageFile", message_file_fields)
|
||||||
|
|
||||||
@ -217,7 +217,7 @@ conversation_infinite_scroll_pagination_fields = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
|
def build_conversation_infinite_scroll_pagination_model(api_or_ns: Namespace):
|
||||||
"""Build the conversation infinite scroll pagination model for the API or Namespace."""
|
"""Build the conversation infinite scroll pagination model for the API or Namespace."""
|
||||||
simple_conversation_model = build_simple_conversation_model(api_or_ns)
|
simple_conversation_model = build_simple_conversation_model(api_or_ns)
|
||||||
|
|
||||||
@ -226,11 +226,11 @@ def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespa
|
|||||||
return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields)
|
return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields)
|
||||||
|
|
||||||
|
|
||||||
def build_conversation_delete_model(api_or_ns: Api | Namespace):
|
def build_conversation_delete_model(api_or_ns: Namespace):
|
||||||
"""Build the conversation delete model for the API or Namespace."""
|
"""Build the conversation delete model for the API or Namespace."""
|
||||||
return api_or_ns.model("ConversationDelete", conversation_delete_fields)
|
return api_or_ns.model("ConversationDelete", conversation_delete_fields)
|
||||||
|
|
||||||
|
|
||||||
def build_simple_conversation_model(api_or_ns: Api | Namespace):
|
def build_simple_conversation_model(api_or_ns: Namespace):
|
||||||
"""Build the simple conversation model for the API or Namespace."""
|
"""Build the simple conversation model for the API or Namespace."""
|
||||||
return api_or_ns.model("SimpleConversation", simple_conversation_fields)
|
return api_or_ns.model("SimpleConversation", simple_conversation_fields)
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from flask_restx import Api, Namespace, fields
|
from flask_restx import Namespace, fields
|
||||||
|
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
|
|
||||||
@ -29,12 +29,12 @@ conversation_variable_infinite_scroll_pagination_fields = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_conversation_variable_model(api_or_ns: Api | Namespace):
|
def build_conversation_variable_model(api_or_ns: Namespace):
|
||||||
"""Build the conversation variable model for the API or Namespace."""
|
"""Build the conversation variable model for the API or Namespace."""
|
||||||
return api_or_ns.model("ConversationVariable", conversation_variable_fields)
|
return api_or_ns.model("ConversationVariable", conversation_variable_fields)
|
||||||
|
|
||||||
|
|
||||||
def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
|
def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Namespace):
|
||||||
"""Build the conversation variable infinite scroll pagination model for the API or Namespace."""
|
"""Build the conversation variable infinite scroll pagination model for the API or Namespace."""
|
||||||
# Build the nested variable model first
|
# Build the nested variable model first
|
||||||
conversation_variable_model = build_conversation_variable_model(api_or_ns)
|
conversation_variable_model = build_conversation_variable_model(api_or_ns)
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from flask_restx import Api, Namespace, fields
|
from flask_restx import Namespace, fields
|
||||||
|
|
||||||
simple_end_user_fields = {
|
simple_end_user_fields = {
|
||||||
"id": fields.String,
|
"id": fields.String,
|
||||||
@ -8,5 +8,5 @@ simple_end_user_fields = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_simple_end_user_model(api_or_ns: Api | Namespace):
|
def build_simple_end_user_model(api_or_ns: Namespace):
|
||||||
return api_or_ns.model("SimpleEndUser", simple_end_user_fields)
|
return api_or_ns.model("SimpleEndUser", simple_end_user_fields)
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from flask_restx import Api, Namespace, fields
|
from flask_restx import Namespace, fields
|
||||||
|
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
|
|
||||||
@ -14,7 +14,7 @@ upload_config_fields = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_upload_config_model(api_or_ns: Api | Namespace):
|
def build_upload_config_model(api_or_ns: Namespace):
|
||||||
"""Build the upload config model for the API or Namespace.
|
"""Build the upload config model for the API or Namespace.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -39,7 +39,7 @@ file_fields = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_file_model(api_or_ns: Api | Namespace):
|
def build_file_model(api_or_ns: Namespace):
|
||||||
"""Build the file model for the API or Namespace.
|
"""Build the file model for the API or Namespace.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -57,7 +57,7 @@ remote_file_info_fields = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_remote_file_info_model(api_or_ns: Api | Namespace):
|
def build_remote_file_info_model(api_or_ns: Namespace):
|
||||||
"""Build the remote file info model for the API or Namespace.
|
"""Build the remote file info model for the API or Namespace.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -81,7 +81,7 @@ file_fields_with_signed_url = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_file_with_signed_url_model(api_or_ns: Api | Namespace):
|
def build_file_with_signed_url_model(api_or_ns: Namespace):
|
||||||
"""Build the file with signed URL model for the API or Namespace.
|
"""Build the file with signed URL model for the API or Namespace.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from flask_restx import Api, Namespace, fields
|
from flask_restx import Namespace, fields
|
||||||
|
|
||||||
from libs.helper import AvatarUrlField, TimestampField
|
from libs.helper import AvatarUrlField, TimestampField
|
||||||
|
|
||||||
@ -9,7 +9,7 @@ simple_account_fields = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_simple_account_model(api_or_ns: Api | Namespace):
|
def build_simple_account_model(api_or_ns: Namespace):
|
||||||
return api_or_ns.model("SimpleAccount", simple_account_fields)
|
return api_or_ns.model("SimpleAccount", simple_account_fields)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from flask_restx import Api, Namespace, fields
|
from flask_restx import Namespace, fields
|
||||||
|
|
||||||
from fields.conversation_fields import message_file_fields
|
from fields.conversation_fields import message_file_fields
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
@ -10,7 +10,7 @@ feedback_fields = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_feedback_model(api_or_ns: Api | Namespace):
|
def build_feedback_model(api_or_ns: Namespace):
|
||||||
"""Build the feedback model for the API or Namespace."""
|
"""Build the feedback model for the API or Namespace."""
|
||||||
return api_or_ns.model("Feedback", feedback_fields)
|
return api_or_ns.model("Feedback", feedback_fields)
|
||||||
|
|
||||||
@ -30,7 +30,7 @@ agent_thought_fields = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_agent_thought_model(api_or_ns: Api | Namespace):
|
def build_agent_thought_model(api_or_ns: Namespace):
|
||||||
"""Build the agent thought model for the API or Namespace."""
|
"""Build the agent thought model for the API or Namespace."""
|
||||||
return api_or_ns.model("AgentThought", agent_thought_fields)
|
return api_or_ns.model("AgentThought", agent_thought_fields)
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from flask_restx import Api, Namespace, fields
|
from flask_restx import Namespace, fields
|
||||||
|
|
||||||
dataset_tag_fields = {
|
dataset_tag_fields = {
|
||||||
"id": fields.String,
|
"id": fields.String,
|
||||||
@ -8,5 +8,5 @@ dataset_tag_fields = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_dataset_tag_fields(api_or_ns: Api | Namespace):
|
def build_dataset_tag_fields(api_or_ns: Namespace):
|
||||||
return api_or_ns.model("DataSetTag", dataset_tag_fields)
|
return api_or_ns.model("DataSetTag", dataset_tag_fields)
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from flask_restx import Api, Namespace, fields
|
from flask_restx import Namespace, fields
|
||||||
|
|
||||||
from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields
|
from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields
|
||||||
from fields.member_fields import build_simple_account_model, simple_account_fields
|
from fields.member_fields import build_simple_account_model, simple_account_fields
|
||||||
@ -17,7 +17,7 @@ workflow_app_log_partial_fields = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace):
|
def build_workflow_app_log_partial_model(api_or_ns: Namespace):
|
||||||
"""Build the workflow app log partial model for the API or Namespace."""
|
"""Build the workflow app log partial model for the API or Namespace."""
|
||||||
workflow_run_model = build_workflow_run_for_log_model(api_or_ns)
|
workflow_run_model = build_workflow_run_for_log_model(api_or_ns)
|
||||||
simple_account_model = build_simple_account_model(api_or_ns)
|
simple_account_model = build_simple_account_model(api_or_ns)
|
||||||
@ -43,7 +43,7 @@ workflow_app_log_pagination_fields = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace):
|
def build_workflow_app_log_pagination_model(api_or_ns: Namespace):
|
||||||
"""Build the workflow app log pagination model for the API or Namespace."""
|
"""Build the workflow app log pagination model for the API or Namespace."""
|
||||||
# Build the nested partial model first
|
# Build the nested partial model first
|
||||||
workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns)
|
workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns)
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from flask_restx import Api, Namespace, fields
|
from flask_restx import Namespace, fields
|
||||||
|
|
||||||
from fields.end_user_fields import simple_end_user_fields
|
from fields.end_user_fields import simple_end_user_fields
|
||||||
from fields.member_fields import simple_account_fields
|
from fields.member_fields import simple_account_fields
|
||||||
@ -19,7 +19,7 @@ workflow_run_for_log_fields = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_workflow_run_for_log_model(api_or_ns: Api | Namespace):
|
def build_workflow_run_for_log_model(api_or_ns: Namespace):
|
||||||
return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields)
|
return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from uuid import uuid4
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from flask_login import UserMixin
|
from flask_login import UserMixin
|
||||||
from sqlalchemy import DateTime, String, func, select
|
from sqlalchemy import DateTime, String, func, select
|
||||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
from sqlalchemy.orm import Mapped, Session, mapped_column, validates
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
from .base import TypeBase
|
from .base import TypeBase
|
||||||
@ -116,6 +116,12 @@ class Account(UserMixin, TypeBase):
|
|||||||
role: TenantAccountRole | None = field(default=None, init=False)
|
role: TenantAccountRole | None = field(default=None, init=False)
|
||||||
_current_tenant: "Tenant | None" = field(default=None, init=False)
|
_current_tenant: "Tenant | None" = field(default=None, init=False)
|
||||||
|
|
||||||
|
@validates("status")
|
||||||
|
def _normalize_status(self, _key: str, value: str | AccountStatus) -> str:
|
||||||
|
if isinstance(value, AccountStatus):
|
||||||
|
return value.value
|
||||||
|
return value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_password_set(self):
|
def is_password_set(self):
|
||||||
return self.password is not None
|
return self.password is not None
|
||||||
|
|||||||
69
api/tests/unit_tests/controllers/common/test_fields.py
Normal file
69
api/tests/unit_tests/controllers/common/test_fields.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
import builtins
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from flask.views import MethodView as FlaskMethodView
|
||||||
|
|
||||||
|
_NEEDS_METHOD_VIEW_CLEANUP = False
|
||||||
|
if not hasattr(builtins, "MethodView"):
|
||||||
|
builtins.MethodView = FlaskMethodView
|
||||||
|
_NEEDS_METHOD_VIEW_CLEANUP = True
|
||||||
|
from controllers.common.fields import Parameters, Site
|
||||||
|
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||||
|
from models.model import IconType
|
||||||
|
|
||||||
|
|
||||||
|
def test_parameters_model_round_trip():
|
||||||
|
parameters = get_parameters_from_feature_dict(features_dict={}, user_input_form=[])
|
||||||
|
|
||||||
|
model = Parameters.model_validate(parameters)
|
||||||
|
|
||||||
|
assert model.model_dump(mode="json") == parameters
|
||||||
|
|
||||||
|
|
||||||
|
def test_site_icon_url_uses_signed_url_for_image_icon():
|
||||||
|
site = SimpleNamespace(
|
||||||
|
title="Example",
|
||||||
|
chat_color_theme=None,
|
||||||
|
chat_color_theme_inverted=False,
|
||||||
|
icon_type=IconType.IMAGE,
|
||||||
|
icon="file-id",
|
||||||
|
icon_background=None,
|
||||||
|
description=None,
|
||||||
|
copyright=None,
|
||||||
|
privacy_policy=None,
|
||||||
|
custom_disclaimer=None,
|
||||||
|
default_language="en-US",
|
||||||
|
show_workflow_steps=True,
|
||||||
|
use_icon_as_answer_icon=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("controllers.common.fields.file_helpers.get_signed_file_url", return_value="signed") as mock_helper:
|
||||||
|
model = Site.model_validate(site)
|
||||||
|
|
||||||
|
assert model.icon_url == "signed"
|
||||||
|
mock_helper.assert_called_once_with("file-id")
|
||||||
|
|
||||||
|
|
||||||
|
def test_site_icon_url_is_none_for_non_image_icon():
|
||||||
|
site = SimpleNamespace(
|
||||||
|
title="Example",
|
||||||
|
chat_color_theme=None,
|
||||||
|
chat_color_theme_inverted=False,
|
||||||
|
icon_type=IconType.EMOJI,
|
||||||
|
icon="file-id",
|
||||||
|
icon_background=None,
|
||||||
|
description=None,
|
||||||
|
copyright=None,
|
||||||
|
privacy_policy=None,
|
||||||
|
custom_disclaimer=None,
|
||||||
|
default_language="en-US",
|
||||||
|
show_workflow_steps=True,
|
||||||
|
use_icon_as_answer_icon=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("controllers.common.fields.file_helpers.get_signed_file_url") as mock_helper:
|
||||||
|
model = Site.model_validate(site)
|
||||||
|
|
||||||
|
assert model.icon_url is None
|
||||||
|
mock_helper.assert_not_called()
|
||||||
Reference in New Issue
Block a user