refactor: migrate some ns.model to BaseModel (#30388)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato
2026-01-01 01:38:12 +09:00
committed by GitHub
parent e3ef33366d
commit 5b02e5dcb6
19 changed files with 168 additions and 99 deletions

View File

@ -1,62 +1,59 @@
from flask_restx import Api, Namespace, fields from __future__ import annotations
from libs.helper import AppIconUrlField from typing import Any, TypeAlias
parameters__system_parameters = { from pydantic import BaseModel, ConfigDict, computed_field
"image_file_size_limit": fields.Integer,
"video_file_size_limit": fields.Integer, from core.file import helpers as file_helpers
"audio_file_size_limit": fields.Integer, from models.model import IconType
"file_size_limit": fields.Integer,
"workflow_file_upload_limit": fields.Integer, JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
} JSONObject: TypeAlias = dict[str, Any]
def build_system_parameters_model(api_or_ns: Api | Namespace): class SystemParameters(BaseModel):
"""Build the system parameters model for the API or Namespace.""" image_file_size_limit: int
return api_or_ns.model("SystemParameters", parameters__system_parameters) video_file_size_limit: int
audio_file_size_limit: int
file_size_limit: int
workflow_file_upload_limit: int
parameters_fields = { class Parameters(BaseModel):
"opening_statement": fields.String, opening_statement: str | None = None
"suggested_questions": fields.Raw, suggested_questions: list[str]
"suggested_questions_after_answer": fields.Raw, suggested_questions_after_answer: JSONObject
"speech_to_text": fields.Raw, speech_to_text: JSONObject
"text_to_speech": fields.Raw, text_to_speech: JSONObject
"retriever_resource": fields.Raw, retriever_resource: JSONObject
"annotation_reply": fields.Raw, annotation_reply: JSONObject
"more_like_this": fields.Raw, more_like_this: JSONObject
"user_input_form": fields.Raw, user_input_form: list[JSONObject]
"sensitive_word_avoidance": fields.Raw, sensitive_word_avoidance: JSONObject
"file_upload": fields.Raw, file_upload: JSONObject
"system_parameters": fields.Nested(parameters__system_parameters), system_parameters: SystemParameters
}
def build_parameters_model(api_or_ns: Api | Namespace): class Site(BaseModel):
"""Build the parameters model for the API or Namespace.""" model_config = ConfigDict(from_attributes=True)
copied_fields = parameters_fields.copy()
copied_fields["system_parameters"] = fields.Nested(build_system_parameters_model(api_or_ns))
return api_or_ns.model("Parameters", copied_fields)
title: str
chat_color_theme: str | None = None
chat_color_theme_inverted: bool
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
description: str | None = None
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
default_language: str
show_workflow_steps: bool
use_icon_as_answer_icon: bool
site_fields = { @computed_field(return_type=str | None) # type: ignore
"title": fields.String, @property
"chat_color_theme": fields.String, def icon_url(self) -> str | None:
"chat_color_theme_inverted": fields.Boolean, if self.icon and self.icon_type == IconType.IMAGE:
"icon_type": fields.String, return file_helpers.get_signed_file_url(self.icon)
"icon": fields.String, return None
"icon_background": fields.String,
"icon_url": AppIconUrlField,
"description": fields.String,
"copyright": fields.String,
"privacy_policy": fields.String,
"custom_disclaimer": fields.String,
"default_language": fields.String,
"show_workflow_steps": fields.Boolean,
"use_icon_as_answer_icon": fields.Boolean,
}
def build_site_model(api_or_ns: Api | Namespace):
"""Build the site model for the API or Namespace."""
return api_or_ns.model("Site", site_fields)

View File

@ -1,5 +1,3 @@
from flask_restx import marshal_with
from controllers.common import fields from controllers.common import fields
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.error import AppUnavailableError from controllers.console.app.error import AppUnavailableError
@ -13,7 +11,6 @@ from services.app_service import AppService
class AppParameterApi(InstalledAppResource): class AppParameterApi(InstalledAppResource):
"""Resource for app variables.""" """Resource for app variables."""
@marshal_with(fields.parameters_fields)
def get(self, installed_app: InstalledApp): def get(self, installed_app: InstalledApp):
"""Retrieve app parameters.""" """Retrieve app parameters."""
app_model = installed_app.app app_model = installed_app.app
@ -37,7 +34,8 @@ class AppParameterApi(InstalledAppResource):
user_input_form = features_dict.get("user_input_form", []) user_input_form = features_dict.get("user_input_form", [])
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
return fields.Parameters.model_validate(parameters).model_dump(mode="json")
@console_ns.route("/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta") @console_ns.route("/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")

View File

@ -1,7 +1,7 @@
from typing import Literal from typing import Literal
from flask import request from flask import request
from flask_restx import Api, Namespace, Resource, fields from flask_restx import Namespace, Resource, fields
from flask_restx.api import HTTPStatus from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -92,7 +92,7 @@ annotation_list_fields = {
} }
def build_annotation_list_model(api_or_ns: Api | Namespace): def build_annotation_list_model(api_or_ns: Namespace):
"""Build the annotation list model for the API or Namespace.""" """Build the annotation list model for the API or Namespace."""
copied_annotation_list_fields = annotation_list_fields.copy() copied_annotation_list_fields = annotation_list_fields.copy()
copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns))) copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns)))

View File

@ -1,6 +1,6 @@
from flask_restx import Resource from flask_restx import Resource
from controllers.common.fields import build_parameters_model from controllers.common.fields import Parameters
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.app.error import AppUnavailableError from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token from controllers.service_api.wraps import validate_app_token
@ -23,7 +23,6 @@ class AppParameterApi(Resource):
} }
) )
@validate_app_token @validate_app_token
@service_api_ns.marshal_with(build_parameters_model(service_api_ns))
def get(self, app_model: App): def get(self, app_model: App):
"""Retrieve app parameters. """Retrieve app parameters.
@ -45,7 +44,8 @@ class AppParameterApi(Resource):
user_input_form = features_dict.get("user_input_form", []) user_input_form = features_dict.get("user_input_form", [])
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
return Parameters.model_validate(parameters).model_dump(mode="json")
@service_api_ns.route("/meta") @service_api_ns.route("/meta")

View File

@ -1,7 +1,7 @@
from flask_restx import Resource from flask_restx import Resource
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.common.fields import build_site_model from controllers.common.fields import Site as SiteResponse
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db from extensions.ext_database import db
@ -23,7 +23,6 @@ class AppSiteApi(Resource):
} }
) )
@validate_app_token @validate_app_token
@service_api_ns.marshal_with(build_site_model(service_api_ns))
def get(self, app_model: App): def get(self, app_model: App):
"""Retrieve app site info. """Retrieve app site info.
@ -38,4 +37,4 @@ class AppSiteApi(Resource):
if app_model.tenant.status == TenantStatus.ARCHIVE: if app_model.tenant.status == TenantStatus.ARCHIVE:
raise Forbidden() raise Forbidden()
return site return SiteResponse.model_validate(site).model_dump(mode="json")

View File

@ -3,7 +3,7 @@ from typing import Any, Literal
from dateutil.parser import isoparse from dateutil.parser import isoparse
from flask import request from flask import request
from flask_restx import Api, Namespace, Resource, fields from flask_restx import Namespace, Resource, fields
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
@ -78,7 +78,7 @@ workflow_run_fields = {
} }
def build_workflow_run_model(api_or_ns: Api | Namespace): def build_workflow_run_model(api_or_ns: Namespace):
"""Build the workflow run model for the API or Namespace.""" """Build the workflow run model for the API or Namespace."""
return api_or_ns.model("WorkflowRun", workflow_run_fields) return api_or_ns.model("WorkflowRun", workflow_run_fields)

View File

@ -1,7 +1,7 @@
import logging import logging
from flask import request from flask import request
from flask_restx import Resource, marshal_with from flask_restx import Resource
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from werkzeug.exceptions import Unauthorized from werkzeug.exceptions import Unauthorized
@ -50,7 +50,6 @@ class AppParameterApi(WebApiResource):
500: "Internal Server Error", 500: "Internal Server Error",
} }
) )
@marshal_with(fields.parameters_fields)
def get(self, app_model: App, end_user): def get(self, app_model: App, end_user):
"""Retrieve app parameters.""" """Retrieve app parameters."""
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
@ -69,7 +68,8 @@ class AppParameterApi(WebApiResource):
user_input_form = features_dict.get("user_input_form", []) user_input_form = features_dict.get("user_input_form", [])
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
return fields.Parameters.model_validate(parameters).model_dump(mode="json")
@web_ns.route("/meta") @web_ns.route("/meta")

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields from flask_restx import Namespace, fields
from libs.helper import TimestampField from libs.helper import TimestampField
@ -12,7 +12,7 @@ annotation_fields = {
} }
def build_annotation_model(api_or_ns: Api | Namespace): def build_annotation_model(api_or_ns: Namespace):
"""Build the annotation model for the API or Namespace.""" """Build the annotation model for the API or Namespace."""
return api_or_ns.model("Annotation", annotation_fields) return api_or_ns.model("Annotation", annotation_fields)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields from flask_restx import Namespace, fields
from fields.member_fields import simple_account_fields from fields.member_fields import simple_account_fields
from libs.helper import TimestampField from libs.helper import TimestampField
@ -46,7 +46,7 @@ message_file_fields = {
} }
def build_message_file_model(api_or_ns: Api | Namespace): def build_message_file_model(api_or_ns: Namespace):
"""Build the message file fields for the API or Namespace.""" """Build the message file fields for the API or Namespace."""
return api_or_ns.model("MessageFile", message_file_fields) return api_or_ns.model("MessageFile", message_file_fields)
@ -217,7 +217,7 @@ conversation_infinite_scroll_pagination_fields = {
} }
def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): def build_conversation_infinite_scroll_pagination_model(api_or_ns: Namespace):
"""Build the conversation infinite scroll pagination model for the API or Namespace.""" """Build the conversation infinite scroll pagination model for the API or Namespace."""
simple_conversation_model = build_simple_conversation_model(api_or_ns) simple_conversation_model = build_simple_conversation_model(api_or_ns)
@ -226,11 +226,11 @@ def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespa
return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields) return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields)
def build_conversation_delete_model(api_or_ns: Api | Namespace): def build_conversation_delete_model(api_or_ns: Namespace):
"""Build the conversation delete model for the API or Namespace.""" """Build the conversation delete model for the API or Namespace."""
return api_or_ns.model("ConversationDelete", conversation_delete_fields) return api_or_ns.model("ConversationDelete", conversation_delete_fields)
def build_simple_conversation_model(api_or_ns: Api | Namespace): def build_simple_conversation_model(api_or_ns: Namespace):
"""Build the simple conversation model for the API or Namespace.""" """Build the simple conversation model for the API or Namespace."""
return api_or_ns.model("SimpleConversation", simple_conversation_fields) return api_or_ns.model("SimpleConversation", simple_conversation_fields)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields from flask_restx import Namespace, fields
from libs.helper import TimestampField from libs.helper import TimestampField
@ -29,12 +29,12 @@ conversation_variable_infinite_scroll_pagination_fields = {
} }
def build_conversation_variable_model(api_or_ns: Api | Namespace): def build_conversation_variable_model(api_or_ns: Namespace):
"""Build the conversation variable model for the API or Namespace.""" """Build the conversation variable model for the API or Namespace."""
return api_or_ns.model("ConversationVariable", conversation_variable_fields) return api_or_ns.model("ConversationVariable", conversation_variable_fields)
def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Namespace):
"""Build the conversation variable infinite scroll pagination model for the API or Namespace.""" """Build the conversation variable infinite scroll pagination model for the API or Namespace."""
# Build the nested variable model first # Build the nested variable model first
conversation_variable_model = build_conversation_variable_model(api_or_ns) conversation_variable_model = build_conversation_variable_model(api_or_ns)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields from flask_restx import Namespace, fields
simple_end_user_fields = { simple_end_user_fields = {
"id": fields.String, "id": fields.String,
@ -8,5 +8,5 @@ simple_end_user_fields = {
} }
def build_simple_end_user_model(api_or_ns: Api | Namespace): def build_simple_end_user_model(api_or_ns: Namespace):
return api_or_ns.model("SimpleEndUser", simple_end_user_fields) return api_or_ns.model("SimpleEndUser", simple_end_user_fields)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields from flask_restx import Namespace, fields
from libs.helper import TimestampField from libs.helper import TimestampField
@ -14,7 +14,7 @@ upload_config_fields = {
} }
def build_upload_config_model(api_or_ns: Api | Namespace): def build_upload_config_model(api_or_ns: Namespace):
"""Build the upload config model for the API or Namespace. """Build the upload config model for the API or Namespace.
Args: Args:
@ -39,7 +39,7 @@ file_fields = {
} }
def build_file_model(api_or_ns: Api | Namespace): def build_file_model(api_or_ns: Namespace):
"""Build the file model for the API or Namespace. """Build the file model for the API or Namespace.
Args: Args:
@ -57,7 +57,7 @@ remote_file_info_fields = {
} }
def build_remote_file_info_model(api_or_ns: Api | Namespace): def build_remote_file_info_model(api_or_ns: Namespace):
"""Build the remote file info model for the API or Namespace. """Build the remote file info model for the API or Namespace.
Args: Args:
@ -81,7 +81,7 @@ file_fields_with_signed_url = {
} }
def build_file_with_signed_url_model(api_or_ns: Api | Namespace): def build_file_with_signed_url_model(api_or_ns: Namespace):
"""Build the file with signed URL model for the API or Namespace. """Build the file with signed URL model for the API or Namespace.
Args: Args:

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields from flask_restx import Namespace, fields
from libs.helper import AvatarUrlField, TimestampField from libs.helper import AvatarUrlField, TimestampField
@ -9,7 +9,7 @@ simple_account_fields = {
} }
def build_simple_account_model(api_or_ns: Api | Namespace): def build_simple_account_model(api_or_ns: Namespace):
return api_or_ns.model("SimpleAccount", simple_account_fields) return api_or_ns.model("SimpleAccount", simple_account_fields)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields from flask_restx import Namespace, fields
from fields.conversation_fields import message_file_fields from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField from libs.helper import TimestampField
@ -10,7 +10,7 @@ feedback_fields = {
} }
def build_feedback_model(api_or_ns: Api | Namespace): def build_feedback_model(api_or_ns: Namespace):
"""Build the feedback model for the API or Namespace.""" """Build the feedback model for the API or Namespace."""
return api_or_ns.model("Feedback", feedback_fields) return api_or_ns.model("Feedback", feedback_fields)
@ -30,7 +30,7 @@ agent_thought_fields = {
} }
def build_agent_thought_model(api_or_ns: Api | Namespace): def build_agent_thought_model(api_or_ns: Namespace):
"""Build the agent thought model for the API or Namespace.""" """Build the agent thought model for the API or Namespace."""
return api_or_ns.model("AgentThought", agent_thought_fields) return api_or_ns.model("AgentThought", agent_thought_fields)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields from flask_restx import Namespace, fields
dataset_tag_fields = { dataset_tag_fields = {
"id": fields.String, "id": fields.String,
@ -8,5 +8,5 @@ dataset_tag_fields = {
} }
def build_dataset_tag_fields(api_or_ns: Api | Namespace): def build_dataset_tag_fields(api_or_ns: Namespace):
return api_or_ns.model("DataSetTag", dataset_tag_fields) return api_or_ns.model("DataSetTag", dataset_tag_fields)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields from flask_restx import Namespace, fields
from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields
from fields.member_fields import build_simple_account_model, simple_account_fields from fields.member_fields import build_simple_account_model, simple_account_fields
@ -17,7 +17,7 @@ workflow_app_log_partial_fields = {
} }
def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace): def build_workflow_app_log_partial_model(api_or_ns: Namespace):
"""Build the workflow app log partial model for the API or Namespace.""" """Build the workflow app log partial model for the API or Namespace."""
workflow_run_model = build_workflow_run_for_log_model(api_or_ns) workflow_run_model = build_workflow_run_for_log_model(api_or_ns)
simple_account_model = build_simple_account_model(api_or_ns) simple_account_model = build_simple_account_model(api_or_ns)
@ -43,7 +43,7 @@ workflow_app_log_pagination_fields = {
} }
def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace): def build_workflow_app_log_pagination_model(api_or_ns: Namespace):
"""Build the workflow app log pagination model for the API or Namespace.""" """Build the workflow app log pagination model for the API or Namespace."""
# Build the nested partial model first # Build the nested partial model first
workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns) workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields from flask_restx import Namespace, fields
from fields.end_user_fields import simple_end_user_fields from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields from fields.member_fields import simple_account_fields
@ -19,7 +19,7 @@ workflow_run_for_log_fields = {
} }
def build_workflow_run_for_log_model(api_or_ns: Api | Namespace): def build_workflow_run_for_log_model(api_or_ns: Namespace):
return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields) return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields)

View File

@ -8,7 +8,7 @@ from uuid import uuid4
import sqlalchemy as sa import sqlalchemy as sa
from flask_login import UserMixin from flask_login import UserMixin
from sqlalchemy import DateTime, String, func, select from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column from sqlalchemy.orm import Mapped, Session, mapped_column, validates
from typing_extensions import deprecated from typing_extensions import deprecated
from .base import TypeBase from .base import TypeBase
@ -116,6 +116,12 @@ class Account(UserMixin, TypeBase):
role: TenantAccountRole | None = field(default=None, init=False) role: TenantAccountRole | None = field(default=None, init=False)
_current_tenant: "Tenant | None" = field(default=None, init=False) _current_tenant: "Tenant | None" = field(default=None, init=False)
@validates("status")
def _normalize_status(self, _key: str, value: str | AccountStatus) -> str:
if isinstance(value, AccountStatus):
return value.value
return value
@property @property
def is_password_set(self): def is_password_set(self):
return self.password is not None return self.password is not None

View File

@ -0,0 +1,69 @@
import builtins
from types import SimpleNamespace
from unittest.mock import patch
from flask.views import MethodView as FlaskMethodView
_NEEDS_METHOD_VIEW_CLEANUP = False
if not hasattr(builtins, "MethodView"):
builtins.MethodView = FlaskMethodView
_NEEDS_METHOD_VIEW_CLEANUP = True
from controllers.common.fields import Parameters, Site
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import IconType
def test_parameters_model_round_trip():
parameters = get_parameters_from_feature_dict(features_dict={}, user_input_form=[])
model = Parameters.model_validate(parameters)
assert model.model_dump(mode="json") == parameters
def test_site_icon_url_uses_signed_url_for_image_icon():
site = SimpleNamespace(
title="Example",
chat_color_theme=None,
chat_color_theme_inverted=False,
icon_type=IconType.IMAGE,
icon="file-id",
icon_background=None,
description=None,
copyright=None,
privacy_policy=None,
custom_disclaimer=None,
default_language="en-US",
show_workflow_steps=True,
use_icon_as_answer_icon=False,
)
with patch("controllers.common.fields.file_helpers.get_signed_file_url", return_value="signed") as mock_helper:
model = Site.model_validate(site)
assert model.icon_url == "signed"
mock_helper.assert_called_once_with("file-id")
def test_site_icon_url_is_none_for_non_image_icon():
site = SimpleNamespace(
title="Example",
chat_color_theme=None,
chat_color_theme_inverted=False,
icon_type=IconType.EMOJI,
icon="file-id",
icon_background=None,
description=None,
copyright=None,
privacy_policy=None,
custom_disclaimer=None,
default_language="en-US",
show_workflow_steps=True,
use_icon_as_answer_icon=False,
)
with patch("controllers.common.fields.file_helpers.get_signed_file_url") as mock_helper:
model = Site.model_validate(site)
assert model.icon_url is None
mock_helper.assert_not_called()