improve: mordernizing validation by migrating pydantic from 1.x to 2.x (#4592)

This commit is contained in:
Bowen Liang
2024-06-14 01:05:37 +08:00
committed by GitHub
parent e8afc416dd
commit f976740b57
87 changed files with 697 additions and 300 deletions

View File

@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
@ -62,6 +62,9 @@ class ModelConfigWithCredentialsEntity(BaseModel):
parameters: dict[str, Any] = {}
stop: list[str] = []
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
class AppGenerateEntity(BaseModel):
"""
@ -93,10 +96,13 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
"""
# app config
app_config: EasyUIBasedAppConfig
model_config: ModelConfigWithCredentialsEntity
model_conf: ModelConfigWithCredentialsEntity
query: Optional[str] = None
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
"""

View File

@ -1,14 +1,14 @@
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel, validator
from pydantic import BaseModel, field_validator
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
class QueueEvent(Enum):
class QueueEvent(str, Enum):
"""
QueueEvent enum
"""
@ -47,14 +47,14 @@ class QueueLLMChunkEvent(AppQueueEvent):
"""
QueueLLMChunkEvent entity
"""
event = QueueEvent.LLM_CHUNK
event: QueueEvent = QueueEvent.LLM_CHUNK
chunk: LLMResultChunk
class QueueIterationStartEvent(AppQueueEvent):
"""
QueueIterationStartEvent entity
"""
event = QueueEvent.ITERATION_START
event: QueueEvent = QueueEvent.ITERATION_START
node_id: str
node_type: NodeType
node_data: BaseNodeData
@ -68,16 +68,17 @@ class QueueIterationNextEvent(AppQueueEvent):
"""
QueueIterationNextEvent entity
"""
event = QueueEvent.ITERATION_NEXT
event: QueueEvent = QueueEvent.ITERATION_NEXT
index: int
node_id: str
node_type: NodeType
node_run_index: int
output: Optional[Any] # output for the current iteration
output: Optional[Any] = None # output for the current iteration
@validator('output', pre=True, always=True)
@classmethod
@field_validator('output', mode='before')
def set_output(cls, v):
"""
Set output
@ -92,7 +93,7 @@ class QueueIterationCompletedEvent(AppQueueEvent):
"""
QueueIterationCompletedEvent entity
"""
event = QueueEvent.ITERATION_COMPLETED
event:QueueEvent = QueueEvent.ITERATION_COMPLETED
node_id: str
node_type: NodeType
@ -104,7 +105,7 @@ class QueueTextChunkEvent(AppQueueEvent):
"""
QueueTextChunkEvent entity
"""
event = QueueEvent.TEXT_CHUNK
event: QueueEvent = QueueEvent.TEXT_CHUNK
text: str
metadata: Optional[dict] = None
@ -113,7 +114,7 @@ class QueueAgentMessageEvent(AppQueueEvent):
"""
QueueMessageEvent entity
"""
event = QueueEvent.AGENT_MESSAGE
event: QueueEvent = QueueEvent.AGENT_MESSAGE
chunk: LLMResultChunk
@ -121,7 +122,7 @@ class QueueMessageReplaceEvent(AppQueueEvent):
"""
QueueMessageReplaceEvent entity
"""
event = QueueEvent.MESSAGE_REPLACE
event: QueueEvent = QueueEvent.MESSAGE_REPLACE
text: str
@ -129,7 +130,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
"""
QueueRetrieverResourcesEvent entity
"""
event = QueueEvent.RETRIEVER_RESOURCES
event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
retriever_resources: list[dict]
@ -137,7 +138,7 @@ class QueueAnnotationReplyEvent(AppQueueEvent):
"""
QueueAnnotationReplyEvent entity
"""
event = QueueEvent.ANNOTATION_REPLY
event: QueueEvent = QueueEvent.ANNOTATION_REPLY
message_annotation_id: str
@ -145,7 +146,7 @@ class QueueMessageEndEvent(AppQueueEvent):
"""
QueueMessageEndEvent entity
"""
event = QueueEvent.MESSAGE_END
event: QueueEvent = QueueEvent.MESSAGE_END
llm_result: Optional[LLMResult] = None
@ -153,28 +154,28 @@ class QueueAdvancedChatMessageEndEvent(AppQueueEvent):
"""
QueueAdvancedChatMessageEndEvent entity
"""
event = QueueEvent.ADVANCED_CHAT_MESSAGE_END
event: QueueEvent = QueueEvent.ADVANCED_CHAT_MESSAGE_END
class QueueWorkflowStartedEvent(AppQueueEvent):
"""
QueueWorkflowStartedEvent entity
"""
event = QueueEvent.WORKFLOW_STARTED
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
class QueueWorkflowSucceededEvent(AppQueueEvent):
"""
QueueWorkflowSucceededEvent entity
"""
event = QueueEvent.WORKFLOW_SUCCEEDED
event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
class QueueWorkflowFailedEvent(AppQueueEvent):
"""
QueueWorkflowFailedEvent entity
"""
event = QueueEvent.WORKFLOW_FAILED
event: QueueEvent = QueueEvent.WORKFLOW_FAILED
error: str
@ -182,7 +183,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
"""
QueueNodeStartedEvent entity
"""
event = QueueEvent.NODE_STARTED
event: QueueEvent = QueueEvent.NODE_STARTED
node_id: str
node_type: NodeType
@ -195,7 +196,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
"""
QueueNodeSucceededEvent entity
"""
event = QueueEvent.NODE_SUCCEEDED
event: QueueEvent = QueueEvent.NODE_SUCCEEDED
node_id: str
node_type: NodeType
@ -213,7 +214,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
"""
QueueNodeFailedEvent entity
"""
event = QueueEvent.NODE_FAILED
event: QueueEvent = QueueEvent.NODE_FAILED
node_id: str
node_type: NodeType
@ -230,7 +231,7 @@ class QueueAgentThoughtEvent(AppQueueEvent):
"""
QueueAgentThoughtEvent entity
"""
event = QueueEvent.AGENT_THOUGHT
event: QueueEvent = QueueEvent.AGENT_THOUGHT
agent_thought_id: str
@ -238,7 +239,7 @@ class QueueMessageFileEvent(AppQueueEvent):
"""
QueueAgentThoughtEvent entity
"""
event = QueueEvent.MESSAGE_FILE
event: QueueEvent = QueueEvent.MESSAGE_FILE
message_file_id: str
@ -246,15 +247,15 @@ class QueueErrorEvent(AppQueueEvent):
"""
QueueErrorEvent entity
"""
event = QueueEvent.ERROR
error: Any
event: QueueEvent = QueueEvent.ERROR
error: Any = None
class QueuePingEvent(AppQueueEvent):
"""
QueuePingEvent entity
"""
event = QueueEvent.PING
event: QueueEvent = QueueEvent.PING
class QueueStopEvent(AppQueueEvent):
@ -270,7 +271,7 @@ class QueueStopEvent(AppQueueEvent):
OUTPUT_MODERATION = "output-moderation"
INPUT_MODERATION = "input-moderation"
event = QueueEvent.STOP
event: QueueEvent = QueueEvent.STOP
stopped_by: StopBy

View File

@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
@ -118,9 +118,7 @@ class ErrorStreamResponse(StreamResponse):
"""
event: StreamEvent = StreamEvent.ERROR
err: Exception
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)
class MessageStreamResponse(StreamResponse):
@ -360,7 +358,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
title: str
index: int
created_at: int
pre_iteration_output: Optional[Any]
pre_iteration_output: Optional[Any] = None
extras: dict = {}
event: StreamEvent = StreamEvent.ITERATION_NEXT
@ -379,12 +377,12 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
node_id: str
node_type: str
title: str
outputs: Optional[dict]
outputs: Optional[dict] = None
created_at: int
extras: dict = None
inputs: dict = None
status: WorkflowNodeExecutionStatus
error: Optional[str]
error: Optional[str] = None
elapsed_time: float
total_tokens: int
finished_at: int