mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 06:58:05 +08:00
fix(api): missing site field in Web App Form Definition API
This commit is contained in:
@ -2,23 +2,32 @@
|
||||
Web App Human Input Form APIs.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotFoundError
|
||||
from controllers.web.site import serialize_site
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantStatus
|
||||
from models.model import App, Site
|
||||
from models.workflow import WorkflowRun
|
||||
from models.human_input import RecipientType
|
||||
from services.human_input_service import Form, FormNotFoundError, HumanInputService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form) -> Response:
|
||||
"""Return the Pydantic definition as a JSON response."""
|
||||
return Response(form.get_definition().model_dump_json(), mimetype="application/json")
|
||||
def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
|
||||
"""Return the Pydantic definition (optionally with site) as a JSON response."""
|
||||
payload = form.get_definition().model_dump()
|
||||
if site_payload is not None:
|
||||
payload["site"] = site_payload
|
||||
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
|
||||
|
||||
|
||||
# TODO(QuantumGhost): disable authorization for web app
|
||||
@ -46,7 +55,9 @@ class HumanInputFormApi(Resource):
|
||||
if form is None:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
return _jsonify_form_definition(form)
|
||||
site = _get_site_from_form(form)
|
||||
|
||||
return _jsonify_form_definition(form, site_payload=serialize_site(site))
|
||||
|
||||
# def post(self, _app_model: App, _end_user: EndUser, form_token: str):
|
||||
def post(self, form_token: str):
|
||||
@ -82,3 +93,25 @@ class HumanInputFormApi(Resource):
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
return {}, 200
|
||||
|
||||
|
||||
def _get_site_from_form(form: Form) -> Site:
|
||||
"""Resolve Site for the form's workflow run and validate tenant status."""
|
||||
workflow_run = (
|
||||
db.session.query(WorkflowRun).where(WorkflowRun.id == form.workflow_run_id).first()
|
||||
)
|
||||
if workflow_run is None:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
app_model = db.session.query(App).where(App.id == workflow_run.app_id).first()
|
||||
if app_model is None:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
if site is None:
|
||||
raise Forbidden()
|
||||
|
||||
if app_model.tenant and app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden()
|
||||
|
||||
return site
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from flask_restx import fields, marshal_with
|
||||
from flask_restx import fields, marshal, marshal_with
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
@ -104,7 +104,12 @@ class AppSiteInfo:
|
||||
if tenant.custom_config_dict.get("replace_webapp_logo")
|
||||
else None
|
||||
)
|
||||
self.custom_config = {
|
||||
"remove_webapp_brand": remove_webapp_brand,
|
||||
"replace_webapp_logo": replace_webapp_logo,
|
||||
}
|
||||
self.custom_config = {
|
||||
"remove_webapp_brand": remove_webapp_brand,
|
||||
"replace_webapp_logo": replace_webapp_logo,
|
||||
}
|
||||
|
||||
|
||||
def serialize_site(site: Site) -> dict:
|
||||
"""Serialize Site model using the same schema as AppSiteApi."""
|
||||
return marshal(site, AppSiteApi.site_fields)
|
||||
|
||||
143
api/tests/unit_tests/controllers/web/test_human_input_form.py
Normal file
143
api/tests/unit_tests/controllers/web/test_human_input_form.py
Normal file
@ -0,0 +1,143 @@
|
||||
"""Unit tests for controllers.web.human_input_form endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import controllers.web.human_input_form as human_input_module
|
||||
|
||||
HumanInputFormApi = human_input_module.HumanInputFormApi
|
||||
RecipientType = human_input_module.RecipientType
|
||||
TenantStatus = human_input_module.TenantStatus
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
"""Configure a minimal Flask app for request contexts."""
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
"""Simple stand-in for db.session that returns pre-seeded objects."""
|
||||
|
||||
def __init__(self, mapping: dict[str, Any]):
|
||||
self._mapping = mapping
|
||||
self._model_name: str | None = None
|
||||
|
||||
def query(self, model):
|
||||
self._model_name = model.__name__
|
||||
return self
|
||||
|
||||
def where(self, *args, **kwargs): # noqa: ANN002, ANN003
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
assert self._model_name is not None
|
||||
return self._mapping.get(self._model_name)
|
||||
|
||||
|
||||
class _FakeDB:
|
||||
"""Minimal db stub exposing engine and session."""
|
||||
|
||||
def __init__(self, session: _FakeSession):
|
||||
self.session = session
|
||||
self.engine = object()
|
||||
|
||||
|
||||
def test_get_form_includes_site(monkeypatch: pytest.MonkeyPatch, app: Flask):
|
||||
"""GET returns form definition merged with site payload."""
|
||||
|
||||
class _FakeDefinition:
|
||||
def model_dump(self):
|
||||
return {"form_content": "hello"}
|
||||
|
||||
class _FakeForm:
|
||||
workflow_run_id = "workflow-1"
|
||||
|
||||
def get_definition(self):
|
||||
return _FakeDefinition()
|
||||
|
||||
form = _FakeForm()
|
||||
|
||||
tenant = SimpleNamespace(status=TenantStatus.NORMAL)
|
||||
app_model = SimpleNamespace(id="app-1", tenant=tenant)
|
||||
workflow_run = SimpleNamespace(app_id="app-1")
|
||||
site_model = SimpleNamespace(
|
||||
title="My Site",
|
||||
icon_type="emoji",
|
||||
icon=None,
|
||||
icon_background="#fff",
|
||||
description="desc",
|
||||
default_language="en",
|
||||
chat_color_theme="light",
|
||||
chat_color_theme_inverted=False,
|
||||
copyright=None,
|
||||
privacy_policy=None,
|
||||
custom_disclaimer=None,
|
||||
prompt_public=False,
|
||||
show_workflow_steps=True,
|
||||
use_icon_as_answer_icon=False,
|
||||
)
|
||||
|
||||
# Patch service to return fake form.
|
||||
service_mock = MagicMock()
|
||||
service_mock.get_form_definition_by_token.return_value = form
|
||||
monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock)
|
||||
|
||||
# Patch db session.
|
||||
db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": site_model}))
|
||||
monkeypatch.setattr(human_input_module, "db", db_stub)
|
||||
|
||||
# Patch serialize_site to a predictable value.
|
||||
monkeypatch.setattr(human_input_module, "serialize_site", lambda site: {"title": site.title})
|
||||
|
||||
with app.test_request_context("/api/form/human_input/token-1", method="GET"):
|
||||
response = HumanInputFormApi().get("token-1")
|
||||
|
||||
body = json.loads(response.get_data(as_text=True))
|
||||
assert body["form_content"] == "hello"
|
||||
assert body["site"] == {"title": "My Site"}
|
||||
service_mock.get_form_definition_by_token.assert_called_once_with(
|
||||
RecipientType.STANDALONE_WEB_APP,
|
||||
"token-1",
|
||||
)
|
||||
|
||||
|
||||
def test_get_form_raises_forbidden_when_site_missing(monkeypatch: pytest.MonkeyPatch, app: Flask):
|
||||
"""GET raises Forbidden if site cannot be resolved."""
|
||||
|
||||
class _FakeDefinition:
|
||||
def model_dump(self):
|
||||
return {"form_content": "hello"}
|
||||
|
||||
class _FakeForm:
|
||||
workflow_run_id = "workflow-1"
|
||||
|
||||
def get_definition(self):
|
||||
return _FakeDefinition()
|
||||
|
||||
form = _FakeForm()
|
||||
tenant = SimpleNamespace(status=TenantStatus.NORMAL)
|
||||
app_model = SimpleNamespace(id="app-1", tenant=tenant)
|
||||
workflow_run = SimpleNamespace(app_id="app-1")
|
||||
|
||||
service_mock = MagicMock()
|
||||
service_mock.get_form_definition_by_token.return_value = form
|
||||
monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock)
|
||||
|
||||
db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": None}))
|
||||
monkeypatch.setattr(human_input_module, "db", db_stub)
|
||||
|
||||
with app.test_request_context("/api/form/human_input/token-1", method="GET"):
|
||||
with pytest.raises(Forbidden):
|
||||
HumanInputFormApi().get("token-1")
|
||||
Reference in New Issue
Block a user