feat: chat messages api support parent message id

This commit is contained in:
yyh
2025-12-25 15:21:44 +08:00
parent fb14644a79
commit df09acb74b
14 changed files with 213 additions and 26 deletions

View File

@ -4,7 +4,7 @@ from uuid import UUID
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
@ -33,8 +33,11 @@ from libs import helper
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
from services.conversation_service import ConversationService
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.errors.message import MessageNotExistsError
from services.message_service import MessageService
logger = logging.getLogger(__name__)
@ -53,14 +56,18 @@ class ChatRequestPayload(BaseModel):
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None
conversation_id: str | None = Field(default=None, description="Conversation UUID")
parent_message_id: str | None = Field(default=None, description="Parent message UUID")
retriever_from: str = Field(default="dev")
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
@field_validator("conversation_id", mode="before")
@field_validator("conversation_id", "parent_message_id", mode="before")
@classmethod
def normalize_conversation_id(cls, value: str | UUID | None) -> str | None:
"""Allow missing or blank conversation IDs; enforce UUID format when provided."""
def normalize_uuid_fields(cls, value: str | UUID | None, info: ValidationInfo) -> str | None:
"""Allow missing or blank UUID fields; enforce UUID format when provided."""
if isinstance(value, UUID):
return str(value)
if isinstance(value, str):
value = value.strip()
@ -70,7 +77,36 @@ class ChatRequestPayload(BaseModel):
try:
return helper.uuid_value(value)
except ValueError as exc:
raise ValueError("conversation_id must be a valid UUID") from exc
raise ValueError(f"{info.field_name} must be a valid UUID") from exc
def _validate_parent_message_request(
*,
app_model: App,
end_user: EndUser,
conversation_id: str | None,
parent_message_id: str | None,
) -> None:
if not parent_message_id:
return
if not conversation_id:
raise BadRequest("conversation_id is required when parent_message_id is provided.")
try:
conversation = ConversationService.get_conversation(
app_model=app_model, conversation_id=conversation_id, user=end_user
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
try:
parent_message = MessageService.get_message(app_model=app_model, user=end_user, message_id=parent_message_id)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
if parent_message.conversation_id != conversation.id:
raise BadRequest("parent_message_id does not belong to the conversation.")
register_schema_models(service_api_ns, CompletionRequestPayload, ChatRequestPayload)
@ -205,6 +241,13 @@ class ChatApi(Resource):
streaming = payload.response_mode == "streaming"
_validate_parent_message_request(
app_model=app_model,
end_user=end_user,
conversation_id=args.get("conversation_id"),
parent_message_id=args.get("parent_message_id"),
)
try:
response = AppGenerateService.generate(
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming

View File

@ -12,7 +12,6 @@ from sqlalchemy.orm import Session, sessionmaker
import contexts
from configs import dify_config
from constants import UUID_NIL
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
@ -168,7 +167,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
parent_message_id=self._resolve_parent_message_id(args, invoke_from),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,

View File

@ -9,7 +9,6 @@ from flask import Flask, current_app
from pydantic import ValidationError
from configs import dify_config
from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
@ -163,7 +162,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
parent_message_id=self._resolve_parent_message_id(args, invoke_from),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,

View File

@ -8,7 +8,6 @@ from flask import Flask, copy_current_request_context, current_app
from pydantic import ValidationError
from configs import dify_config
from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -156,7 +155,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
parent_message_id=self._resolve_parent_message_id(args, invoke_from),
user_id=user.id,
invoke_from=invoke_from,
extras=extras,

View File

@ -1,11 +1,12 @@
import json
import logging
from collections.abc import Generator
from typing import Union, cast
from collections.abc import Generator, Mapping
from typing import Any, Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from constants import UUID_NIL
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -84,6 +85,12 @@ class MessageBasedAppGenerator(BaseAppGenerator):
logger.exception("Failed to handle response, conversation_id: %s", conversation.id)
raise e
def _resolve_parent_message_id(self, args: Mapping[str, Any], invoke_from: InvokeFrom) -> str | None:
parent_message_id = args.get("parent_message_id")
if invoke_from == InvokeFrom.SERVICE_API and not parent_message_id:
return UUID_NIL
return parent_message_id
def _get_app_model_config(self, app_model: App, conversation: Conversation | None = None) -> AppModelConfig:
if conversation:
stmt = select(AppModelConfig).where(

View File

@ -2,9 +2,8 @@ from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
from pydantic import BaseModel, ConfigDict, Field
from constants import UUID_NIL
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
from core.file import File, FileUploadConfig
@ -158,20 +157,12 @@ class ConversationAppGenerateEntity(AppGenerateEntity):
parent_message_id: str | None = Field(
default=None,
description=(
"Starting from v0.9.0, parent_message_id is used to support message regeneration for internal chat API."
"For service API, we need to ensure its forward compatibility, "
"so passing in the parent_message_id as request arg is not supported for now. "
"It needs to be set to UUID_NIL so that the subsequent processing will treat it as legacy messages."
"Starting from v0.9.0, parent_message_id is used to support message regeneration "
"and branching in chat APIs."
"For service API, when it is omitted, the system treats it as UUID_NIL to preserve legacy linear history."
),
)
@field_validator("parent_message_id")
@classmethod
def validate_parent_message_id(cls, v, info: ValidationInfo):
if info.data.get("invoke_from") == InvokeFrom.SERVICE_API and v != UUID_NIL:
raise ValueError("parent_message_id should be UUID_NIL for service API")
return v
class ChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity):
"""

View File

@ -0,0 +1,110 @@
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from werkzeug.exceptions import BadRequest, NotFound
from controllers.service_api.app.completion import _validate_parent_message_request
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError
def test_validate_parent_message_skips_when_missing():
app_model = object()
end_user = object()
with (
patch("controllers.service_api.app.completion.ConversationService.get_conversation") as get_conversation,
patch("controllers.service_api.app.completion.MessageService.get_message") as get_message,
):
_validate_parent_message_request(
app_model=app_model, end_user=end_user, conversation_id=None, parent_message_id=None
)
get_conversation.assert_not_called()
get_message.assert_not_called()
def test_validate_parent_message_requires_conversation_id():
app_model = object()
end_user = object()
with pytest.raises(BadRequest):
_validate_parent_message_request(
app_model=app_model, end_user=end_user, conversation_id=None, parent_message_id="parent-id"
)
def test_validate_parent_message_missing_conversation_raises_not_found():
app_model = object()
end_user = object()
with patch(
"controllers.service_api.app.completion.ConversationService.get_conversation",
side_effect=ConversationNotExistsError(),
):
with pytest.raises(NotFound):
_validate_parent_message_request(
app_model=app_model,
end_user=end_user,
conversation_id="conversation-id",
parent_message_id="parent-id",
)
def test_validate_parent_message_missing_message_raises_not_found():
app_model = object()
end_user = object()
conversation = SimpleNamespace(id="conversation-id")
with (
patch("controllers.service_api.app.completion.ConversationService.get_conversation", return_value=conversation),
patch(
"controllers.service_api.app.completion.MessageService.get_message",
side_effect=MessageNotExistsError(),
),
):
with pytest.raises(NotFound):
_validate_parent_message_request(
app_model=app_model,
end_user=end_user,
conversation_id="conversation-id",
parent_message_id="parent-id",
)
def test_validate_parent_message_mismatch_conversation_raises_bad_request():
app_model = object()
end_user = object()
conversation = SimpleNamespace(id="conversation-id")
message = SimpleNamespace(conversation_id="different-id")
with (
patch("controllers.service_api.app.completion.ConversationService.get_conversation", return_value=conversation),
patch("controllers.service_api.app.completion.MessageService.get_message", return_value=message),
):
with pytest.raises(BadRequest):
_validate_parent_message_request(
app_model=app_model,
end_user=end_user,
conversation_id="conversation-id",
parent_message_id="parent-id",
)
def test_validate_parent_message_matches_conversation():
app_model = object()
end_user = object()
conversation = SimpleNamespace(id="conversation-id")
message = SimpleNamespace(conversation_id="conversation-id")
with (
patch("controllers.service_api.app.completion.ConversationService.get_conversation", return_value=conversation),
patch("controllers.service_api.app.completion.MessageService.get_message", return_value=message),
):
_validate_parent_message_request(
app_model=app_model,
end_user=end_user,
conversation_id="conversation-id",
parent_message_id="parent-id",
)

View File

@ -23,3 +23,24 @@ def test_chat_request_payload_validates_uuid():
def test_chat_request_payload_rejects_invalid_uuid():
with pytest.raises(ValidationError):
ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "conversation_id": "invalid"})
def test_chat_request_payload_accepts_blank_parent_message_id():
payload = ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "parent_message_id": ""})
assert payload.parent_message_id is None
def test_chat_request_payload_validates_parent_message_id_uuid():
parent_message_id = str(uuid.uuid4())
payload = ChatRequestPayload.model_validate(
{"inputs": {}, "query": "hello", "parent_message_id": parent_message_id}
)
assert payload.parent_message_id == parent_message_id
def test_chat_request_payload_rejects_invalid_parent_message_id():
with pytest.raises(ValidationError):
ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "parent_message_id": "invalid"})