mirror of
https://github.com/langgenius/dify.git
synced 2026-03-05 07:37:07 +08:00
Merge remote-tracking branch 'upstream/main' into feat/human-input-merge-again
This commit is contained in:
@ -1,78 +1,140 @@
|
||||
from flask_restx import Api, Namespace, fields
|
||||
from __future__ import annotations
|
||||
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from libs.helper import TimestampField
|
||||
from datetime import datetime
|
||||
from typing import TypeAlias
|
||||
from uuid import uuid4
|
||||
|
||||
from .raws import FilesContainedField
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
feedback_fields = {
|
||||
"rating": fields.String,
|
||||
}
|
||||
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
|
||||
from core.file import File
|
||||
from fields.conversation_fields import AgentThought, JSONValue, MessageFile
|
||||
|
||||
JSONValueType: TypeAlias = JSONValue
|
||||
|
||||
|
||||
def build_feedback_model(api_or_ns: Api | Namespace):
|
||||
"""Build the feedback model for the API or Namespace."""
|
||||
return api_or_ns.model("Feedback", feedback_fields)
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
||||
|
||||
|
||||
agent_thought_fields = {
|
||||
"id": fields.String,
|
||||
"chain_id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"thought": fields.String,
|
||||
"tool": fields.String,
|
||||
"tool_labels": fields.Raw,
|
||||
"tool_input": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"observation": fields.String,
|
||||
"files": fields.List(fields.String),
|
||||
}
|
||||
class SimpleFeedback(ResponseModel):
|
||||
rating: str | None = None
|
||||
|
||||
|
||||
def build_agent_thought_model(api_or_ns: Api | Namespace):
|
||||
"""Build the agent thought model for the API or Namespace."""
|
||||
return api_or_ns.model("AgentThought", agent_thought_fields)
|
||||
class RetrieverResource(ResponseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
message_id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
position: int
|
||||
dataset_id: str | None = None
|
||||
dataset_name: str | None = None
|
||||
document_id: str | None = None
|
||||
document_name: str | None = None
|
||||
data_source_type: str | None = None
|
||||
segment_id: str | None = None
|
||||
score: float | None = None
|
||||
hit_count: int | None = None
|
||||
word_count: int | None = None
|
||||
segment_position: int | None = None
|
||||
index_node_hash: str | None = None
|
||||
content: str | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
|
||||
retriever_resource_fields = {
|
||||
"id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"dataset_id": fields.String,
|
||||
"dataset_name": fields.String,
|
||||
"document_id": fields.String,
|
||||
"document_name": fields.String,
|
||||
"data_source_type": fields.String,
|
||||
"segment_id": fields.String,
|
||||
"score": fields.Float,
|
||||
"hit_count": fields.Integer,
|
||||
"word_count": fields.Integer,
|
||||
"segment_position": fields.Integer,
|
||||
"index_node_hash": fields.String,
|
||||
"content": fields.String,
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
class MessageListItem(ResponseModel):
|
||||
id: str
|
||||
conversation_id: str
|
||||
parent_message_id: str | None = None
|
||||
inputs: dict[str, JSONValueType]
|
||||
query: str
|
||||
answer: str = Field(validation_alias="re_sign_file_url_answer")
|
||||
feedback: SimpleFeedback | None = Field(default=None, validation_alias="user_feedback")
|
||||
retriever_resources: list[RetrieverResource]
|
||||
created_at: int | None = None
|
||||
agent_thoughts: list[AgentThought]
|
||||
message_files: list[MessageFile]
|
||||
status: str
|
||||
error: str | None = None
|
||||
extra_contents: list[ExecutionExtraContentDomainModel]
|
||||
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
|
||||
"extra_contents": fields.List(cls_or_instance=fields.Raw),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
}
|
||||
@field_validator("inputs", mode="before")
|
||||
@classmethod
|
||||
def _normalize_inputs(cls, value: JSONValueType) -> JSONValueType:
|
||||
return format_files_contained(value)
|
||||
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_fields)),
|
||||
}
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
|
||||
class WebMessageListItem(MessageListItem):
|
||||
metadata: JSONValueType | None = Field(default=None, validation_alias="message_metadata_dict")
|
||||
|
||||
|
||||
class MessageInfiniteScrollPagination(ResponseModel):
|
||||
limit: int
|
||||
has_more: bool
|
||||
data: list[MessageListItem]
|
||||
|
||||
|
||||
class WebMessageInfiniteScrollPagination(ResponseModel):
|
||||
limit: int
|
||||
has_more: bool
|
||||
data: list[WebMessageListItem]
|
||||
|
||||
|
||||
class SavedMessageItem(ResponseModel):
|
||||
id: str
|
||||
inputs: dict[str, JSONValueType]
|
||||
query: str
|
||||
answer: str
|
||||
message_files: list[MessageFile]
|
||||
feedback: SimpleFeedback | None = Field(default=None, validation_alias="user_feedback")
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("inputs", mode="before")
|
||||
@classmethod
|
||||
def _normalize_inputs(cls, value: JSONValueType) -> JSONValueType:
|
||||
return format_files_contained(value)
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
|
||||
|
||||
class SavedMessageInfiniteScrollPagination(ResponseModel):
|
||||
limit: int
|
||||
has_more: bool
|
||||
data: list[SavedMessageItem]
|
||||
|
||||
|
||||
class SuggestedQuestionsResponse(ResponseModel):
|
||||
data: list[str]
|
||||
|
||||
|
||||
def to_timestamp(value: datetime | None) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
def format_files_contained(value: JSONValueType) -> JSONValueType:
|
||||
if isinstance(value, File):
|
||||
return value.model_dump()
|
||||
if isinstance(value, dict):
|
||||
return {k: format_files_contained(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [format_files_contained(v) for v in value]
|
||||
return value
|
||||
|
||||
Reference in New Issue
Block a user