Merge branch 'feat/rag-2' into feat/add-dataset-service-api-enable

# Conflicts:
#	api/controllers/console/datasets/datasets.py
#	api/controllers/service_api/wraps.py
#	api/services/dataset_service.py
jyong
2025-09-16 15:21:23 +08:00
843 changed files with 25061 additions and 16010 deletions

View File

@ -1,12 +1,10 @@
from typing import Optional
from core.app.app_config.entities import SensitiveWordAvoidanceEntity
from core.moderation.factory import ModerationFactory
class SensitiveWordAvoidanceConfigManager:
@classmethod
def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]:
def convert(cls, config: dict) -> SensitiveWordAvoidanceEntity | None:
sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
if not sensitive_word_avoidance_dict:
return None
@ -21,7 +19,7 @@ class SensitiveWordAvoidanceConfigManager:
@classmethod
def validate_and_set_defaults(
cls, tenant_id, config: dict, only_structure_validate: bool = False
cls, tenant_id: str, config: dict, only_structure_validate: bool = False
) -> tuple[dict, list[str]]:
if not config.get("sensitive_word_avoidance"):
config["sensitive_word_avoidance"] = {"enabled": False}
@ -38,7 +36,14 @@ class SensitiveWordAvoidanceConfigManager:
if not only_structure_validate:
typ = config["sensitive_word_avoidance"]["type"]
sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"]
if not isinstance(typ, str):
raise ValueError("sensitive_word_avoidance.type must be a string")
sensitive_word_avoidance_config = config["sensitive_word_avoidance"].get("config")
if sensitive_word_avoidance_config is None:
sensitive_word_avoidance_config = {}
if not isinstance(sensitive_word_avoidance_config, dict):
raise ValueError("sensitive_word_avoidance.config must be a dict")
ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)
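Aside: the hunk above combines two patterns that recur throughout this commit. `Optional[X]` gives way to the PEP 604 `X | None` spelling, and the moderation section is validated at runtime before it reaches `ModerationFactory`. A condensed, self-contained sketch of that validation (the function name is hypothetical):

from typing import Any

def validate_moderation_section(config: dict[str, Any]) -> tuple[str, dict[str, Any]]:
    section = config["sensitive_word_avoidance"]
    typ = section["type"]
    if not isinstance(typ, str):
        raise ValueError("sensitive_word_avoidance.type must be a string")
    moderation_config = section.get("config")
    if moderation_config is None:
        moderation_config = {}  # tolerate a missing config block
    if not isinstance(moderation_config, dict):
        raise ValueError("sensitive_word_avoidance.config must be a dict")
    return typ, moderation_config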

View File

@ -1,12 +1,10 @@
from typing import Optional
from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity
from core.agent.prompt.template import REACT_PROMPT_TEMPLATES
class AgentConfigManager:
@classmethod
def convert(cls, config: dict) -> Optional[AgentEntity]:
def convert(cls, config: dict) -> AgentEntity | None:
"""
Convert model config to agent entity

View File

@ -1,5 +1,4 @@
import uuid
from typing import Optional
from core.app.app_config.entities import (
DatasetEntity,
@ -14,7 +13,7 @@ from services.dataset_service import DatasetService
class DatasetConfigManager:
@classmethod
def convert(cls, config: dict) -> Optional[DatasetEntity]:
def convert(cls, config: dict) -> DatasetEntity | None:
"""
Convert model config to dataset entity

View File

@ -25,10 +25,14 @@ class PromptTemplateConfigManager:
if chat_prompt_config:
chat_prompt_messages = []
for message in chat_prompt_config.get("prompt", []):
text = message.get("text")
if not isinstance(text, str):
raise ValueError("message text must be a string")
role = message.get("role")
if not isinstance(role, str):
raise ValueError("message role must be a string")
chat_prompt_messages.append(
AdvancedChatMessageEntity(
**{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
)
AdvancedChatMessageEntity(text=text, role=PromptMessageRole.value_of(role))
)
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
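The rewrite above also swaps `**{"text": ..., "role": ...}` unpacking for direct keyword arguments, which lets a type checker verify argument names and types. A minimal sketch with a simplified entity (the real `role` field is a `PromptMessageRole` enum):

from pydantic import BaseModel

class AdvancedChatMessageEntity(BaseModel):
    text: str
    role: str  # simplified; the real field is a PromptMessageRole enum

# Dict unpacking hides the argument names from static analysis:
m1 = AdvancedChatMessageEntity(**{"text": "hi", "role": "user"})
# Direct keywords let the checker validate names and types:
m2 = AdvancedChatMessageEntity(text="hi", role="user")
assert m1 == m2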
@ -66,7 +70,7 @@ class PromptTemplateConfigManager:
:param config: app model config args
"""
if not config.get("prompt_type"):
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE
prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
if config["prompt_type"] not in prompt_type_vals:
@ -86,7 +90,7 @@ class PromptTemplateConfigManager:
if not isinstance(config["completion_prompt_config"], dict):
raise ValueError("completion_prompt_config must be of object type")
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED:
if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
raise ValueError(
"chat_prompt_config or completion_prompt_config is required when prompt_type is advanced"

View File

@ -1,6 +1,6 @@
from collections.abc import Sequence
from enum import Enum, StrEnum
from typing import Any, Literal, Optional
from enum import StrEnum, auto
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
@ -17,7 +17,7 @@ class ModelConfigEntity(BaseModel):
provider: str
model: str
mode: Optional[str] = None
mode: str | None = None
parameters: dict[str, Any] = Field(default_factory=dict)
stop: list[str] = Field(default_factory=list)
@ -53,7 +53,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
assistant: str
prompt: str
role_prefix: Optional[RolePrefixEntity] = None
role_prefix: RolePrefixEntity | None = None
class PromptTemplateEntity(BaseModel):
@ -61,14 +61,14 @@ class PromptTemplateEntity(BaseModel):
Prompt Template Entity.
"""
class PromptType(Enum):
class PromptType(StrEnum):
"""
Prompt Type.
'simple', 'advanced'
"""
SIMPLE = "simple"
ADVANCED = "advanced"
SIMPLE = auto()
ADVANCED = auto()
@classmethod
def value_of(cls, value: str):
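The switch from `Enum` to `StrEnum` with `auto()` is why later hunks can drop the `.value` suffix when assigning or comparing (`PromptType.SIMPLE` instead of `PromptType.SIMPLE.value`): `auto()` in a `StrEnum` derives the value from the lowercased member name, and members are themselves `str` instances. A quick check (Python 3.11+):

from enum import StrEnum, auto

class PromptType(StrEnum):
    SIMPLE = auto()    # value becomes "simple"
    ADVANCED = auto()  # value becomes "advanced"

assert PromptType.SIMPLE == "simple"           # members compare equal to plain strings
assert f"{PromptType.ADVANCED}" == "advanced"  # str(member) is the value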
@ -84,9 +84,9 @@ class PromptTemplateEntity(BaseModel):
raise ValueError(f"invalid prompt type value {value}")
prompt_type: PromptType
simple_prompt_template: Optional[str] = None
advanced_chat_prompt_template: Optional[AdvancedChatPromptTemplateEntity] = None
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
simple_prompt_template: str | None = None
advanced_chat_prompt_template: AdvancedChatPromptTemplateEntity | None = None
advanced_completion_prompt_template: AdvancedCompletionPromptTemplateEntity | None = None
class VariableEntityType(StrEnum):
@ -112,7 +112,7 @@ class VariableEntity(BaseModel):
type: VariableEntityType
required: bool = False
hide: bool = False
max_length: Optional[int] = None
max_length: int | None = None
options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Optional[Sequence[FileType]] = Field(default_factory=list)
allowed_file_extensions: Optional[Sequence[str]] = Field(default_factory=list)
@ -183,7 +183,7 @@ class ModelConfig(BaseModel):
class Condition(BaseModel):
"""
Conditon detail
Condition detail
"""
name: str
@ -196,8 +196,8 @@ class MetadataFilteringCondition(BaseModel):
Metadata Filtering Condition.
"""
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
logical_operator: Literal["and", "or"] | None = "and"
conditions: list[Condition] | None = Field(default=None, deprecated=True)
class DatasetRetrieveConfigEntity(BaseModel):
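`Field(default=None, deprecated=True)` keeps the legacy `conditions` field readable while steering callers away from it: assuming Pydantic >= 2.7, accessing a deprecated field emits a `DeprecationWarning` at runtime. A simplified sketch (element type reduced to `str` for brevity):

from pydantic import BaseModel, Field

class MetadataFilteringCondition(BaseModel):
    logical_operator: str | None = "and"
    # Reading this field on an instance emits a DeprecationWarning (Pydantic >= 2.7):
    conditions: list[str] | None = Field(default=None, deprecated=True)

m = MetadataFilteringCondition()
_ = m.conditions  # DeprecationWarning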
@ -205,14 +205,14 @@ class DatasetRetrieveConfigEntity(BaseModel):
Dataset Retrieve Config Entity.
"""
class RetrieveStrategy(Enum):
class RetrieveStrategy(StrEnum):
"""
Dataset Retrieve Strategy.
'single' or 'multiple'
"""
SINGLE = "single"
MULTIPLE = "multiple"
SINGLE = auto()
MULTIPLE = auto()
@classmethod
def value_of(cls, value: str):
@ -227,18 +227,18 @@ class DatasetRetrieveConfigEntity(BaseModel):
return mode
raise ValueError(f"invalid retrieve strategy value {value}")
query_variable: Optional[str] = None # Only when app mode is completion
query_variable: str | None = None # Only when app mode is completion
retrieve_strategy: RetrieveStrategy
top_k: Optional[int] = None
score_threshold: Optional[float] = 0.0
rerank_mode: Optional[str] = "reranking_model"
reranking_model: Optional[dict] = None
weights: Optional[dict] = None
reranking_enabled: Optional[bool] = True
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
metadata_model_config: Optional[ModelConfig] = None
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
top_k: int | None = None
score_threshold: float | None = 0.0
rerank_mode: str | None = "reranking_model"
reranking_model: dict | None = None
weights: dict | None = None
reranking_enabled: bool | None = True
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
metadata_model_config: ModelConfig | None = None
metadata_filtering_conditions: MetadataFilteringCondition | None = None
class DatasetEntity(BaseModel):
@ -265,8 +265,8 @@ class TextToSpeechEntity(BaseModel):
"""
enabled: bool
voice: Optional[str] = None
language: Optional[str] = None
voice: str | None = None
language: str | None = None
class TracingConfigEntity(BaseModel):
@ -279,15 +279,15 @@ class TracingConfigEntity(BaseModel):
class AppAdditionalFeatures(BaseModel):
file_upload: Optional[FileUploadConfig] = None
opening_statement: Optional[str] = None
file_upload: FileUploadConfig | None = None
opening_statement: str | None = None
suggested_questions: list[str] = []
suggested_questions_after_answer: bool = False
show_retrieve_source: bool = False
more_like_this: bool = False
speech_to_text: bool = False
text_to_speech: Optional[TextToSpeechEntity] = None
trace_config: Optional[TracingConfigEntity] = None
text_to_speech: TextToSpeechEntity | None = None
trace_config: TracingConfigEntity | None = None
class AppConfig(BaseModel):
@ -300,15 +300,15 @@ class AppConfig(BaseModel):
app_mode: AppMode
additional_features: Optional[AppAdditionalFeatures] = None
variables: list[VariableEntity] = []
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None
class EasyUIBasedAppModelConfigFrom(Enum):
class EasyUIBasedAppModelConfigFrom(StrEnum):
"""
App Model Config From.
"""
ARGS = "args"
ARGS = auto()
APP_LATEST_CONFIG = "app-latest-config"
CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"
@ -323,7 +323,7 @@ class EasyUIBasedAppConfig(AppConfig):
app_model_config_dict: dict
model: ModelConfigEntity
prompt_template: PromptTemplateEntity
dataset: Optional[DatasetEntity] = None
dataset: DatasetEntity | None = None
external_data_variables: list[ExternalDataVariableEntity] = []

View File

@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Optional, Union, overload
from typing import Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@ -390,7 +390,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
conversation: Optional[Conversation] = None,
conversation: Conversation | None = None,
stream: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:

View File

@ -1,7 +1,7 @@
import logging
import time
from collections.abc import Mapping
from typing import Any, Optional, cast
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -249,7 +249,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
def query_app_annotations_to_reply(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
) -> MessageAnnotation | None:
"""
Query app annotations to reply
:param app_record: app record

View File

@ -71,7 +71,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, Any] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
@ -82,7 +82,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@ -102,7 +102,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, Any] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
@ -110,7 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
@ -118,8 +118,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
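These converter hunks replace the custom `to_dict()` helper (removed from `StreamResponse` later in this commit) with Pydantic v2's built-in `model_dump(mode="json")`, which coerces enums, datetimes, and nested models into JSON-ready primitives. A minimal illustration with hypothetical models:

from datetime import datetime
from enum import StrEnum, auto

from pydantic import BaseModel

class StreamEvent(StrEnum):
    MESSAGE = auto()

class Chunk(BaseModel):
    event: StreamEvent
    created_at: datetime

chunk = Chunk(event=StreamEvent.MESSAGE, created_at=datetime(2025, 9, 16))
print(chunk.model_dump(mode="json"))
# {'event': 'message', 'created_at': '2025-09-16T00:00:00'}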

View File

@ -4,7 +4,7 @@ import time
from collections.abc import Callable, Generator, Mapping
from contextlib import contextmanager
from threading import Thread
from typing import Any, Optional, Union
from typing import Any, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -169,7 +169,7 @@ class AdvancedChatAppGenerateTaskPipeline:
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._base_task_pipeline._stream:
if self._base_task_pipeline.stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
@ -228,7 +228,7 @@ class AdvancedChatAppGenerateTaskPipeline:
return None
def _wrapper_process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
self, trace_manager: TraceQueueManager | None = None
) -> Generator[StreamResponse, None, None]:
tts_publisher = None
task_id = self._application_generate_entity.task_id
@ -289,7 +289,7 @@ class AdvancedChatAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState:
def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState:
"""Fluent validation for graph runtime state."""
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
@ -297,13 +297,13 @@ class AdvancedChatAppGenerateTaskPipeline:
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
"""Handle ping events."""
yield self._base_task_pipeline._ping_stream_response()
yield self._base_task_pipeline.ping_stream_response()
def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
"""Handle error events."""
with self._database_session() as session:
err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id)
yield self._base_task_pipeline._error_to_stream_response(err)
err = self._base_task_pipeline.handle_error(event=event, session=session, message_id=self._message_id)
yield self._base_task_pipeline.error_to_stream_response(err)
def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle workflow started events."""
@ -404,8 +404,8 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueTextChunkEvent,
*,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle text chunk events."""
@ -505,8 +505,8 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueWorkflowSucceededEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow succeeded events."""
@ -536,8 +536,8 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueWorkflowPartialSuccessEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow partial success events."""
@ -568,8 +568,8 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueWorkflowFailedEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow failed events."""
@ -594,17 +594,17 @@ class AdvancedChatAppGenerateTaskPipeline:
workflow_execution=workflow_execution,
)
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
err = self._base_task_pipeline._handle_error(event=err_event, session=session, message_id=self._message_id)
err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id)
yield workflow_finish_resp
yield self._base_task_pipeline._error_to_stream_response(err)
yield self._base_task_pipeline.error_to_stream_response(err)
def _handle_stop_event(
self,
event: QueueStopEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle stop events."""
@ -644,13 +644,13 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueAdvancedChatMessageEndEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
graph_runtime_state: GraphRuntimeState | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle advanced chat message end events."""
self._ensure_graph_runtime_initialized(graph_runtime_state)
output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
output_moderation_answer = self._base_task_pipeline.handle_output_moderation_when_task_finished(
self._task_state.answer
)
if output_moderation_answer:
@ -740,10 +740,10 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: Any,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
graph_runtime_state: GraphRuntimeState | None = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
) -> Generator[StreamResponse, None, None]:
"""Dispatch events using elegant pattern matching."""
handlers = self._get_event_handlers()
@ -782,15 +782,15 @@ class AdvancedChatAppGenerateTaskPipeline:
def _process_stream_response(
self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
) -> Generator[StreamResponse, None, None]:
"""
Process stream response using elegant Fluent Python patterns.
Maintains exactly the same functionality as the original 57-if-statement version.
"""
# Initialize graph runtime state
graph_runtime_state: Optional[GraphRuntimeState] = None
graph_runtime_state: GraphRuntimeState | None = None
for queue_message in self._base_task_pipeline.queue_manager.listen():
event = queue_message.event
@ -835,7 +835,7 @@ class AdvancedChatAppGenerateTaskPipeline:
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None):
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
message = self._get_message(session=session)
# If there are assistant files, remove markdown image links from answer
@ -846,7 +846,7 @@ class AdvancedChatAppGenerateTaskPipeline:
message.answer = answer_text
message.updated_at = naive_utc_now()
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline.start_at
message.message_metadata = self._task_state.metadata.model_dump_json()
message_files = [
MessageFile(
@ -902,9 +902,9 @@ class AdvancedChatAppGenerateTaskPipeline:
:param text: text
:return: True if output moderation should direct output, otherwise False
"""
if self._base_task_pipeline._output_moderation_handler:
if self._base_task_pipeline._output_moderation_handler.should_direct_output():
self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
if self._base_task_pipeline.output_moderation_handler:
if self._base_task_pipeline.output_moderation_handler.should_direct_output():
self._task_state.answer = self._base_task_pipeline.output_moderation_handler.get_final_output()
self._base_task_pipeline.queue_manager.publish(
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
)
@ -914,7 +914,7 @@ class AdvancedChatAppGenerateTaskPipeline:
)
return True
else:
self._base_task_pipeline._output_moderation_handler.append_new_token(text)
self._base_task_pipeline.output_moderation_handler.append_new_token(text)
return False

View File

@ -1,6 +1,6 @@
import uuid
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any, cast
from core.agent.entities import AgentEntity
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
@ -30,7 +30,7 @@ class AgentChatAppConfig(EasyUIBasedAppConfig):
Agent Chatbot App Config Entity.
"""
agent: Optional[AgentEntity] = None
agent: AgentEntity | None = None
class AgentChatAppConfigManager(BaseAppConfigManager):
@ -39,8 +39,8 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
cls,
app_model: App,
app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None,
conversation: Conversation | None = None,
override_config_dict: dict | None = None,
) -> AgentChatAppConfig:
"""
Convert app model config to agent chat app config
@ -160,7 +160,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
return filtered_config
@classmethod
def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
def validate_agent_mode_and_set_defaults(
cls, tenant_id: str, config: dict[str, Any]
) -> tuple[dict[str, Any], list[str]]:
"""
Validate agent_mode and set defaults for agent feature
@ -170,30 +172,32 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
if not config.get("agent_mode"):
config["agent_mode"] = {"enabled": False, "tools": []}
if not isinstance(config["agent_mode"], dict):
agent_mode = config["agent_mode"]
if not isinstance(agent_mode, dict):
raise ValueError("agent_mode must be of object type")
if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]:
config["agent_mode"]["enabled"] = False
# FIXME(-LAN-): Cast needed due to basedpyright limitation with dict type narrowing
agent_mode = cast(dict[str, Any], agent_mode)
if not isinstance(config["agent_mode"]["enabled"], bool):
if "enabled" not in agent_mode or not agent_mode["enabled"]:
agent_mode["enabled"] = False
if not isinstance(agent_mode["enabled"], bool):
raise ValueError("enabled in agent_mode must be of boolean type")
if not config["agent_mode"].get("strategy"):
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
if not agent_mode.get("strategy"):
agent_mode["strategy"] = PlanningStrategy.ROUTER.value
if config["agent_mode"]["strategy"] not in [
member.value for member in list(PlanningStrategy.__members__.values())
]:
if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
raise ValueError("strategy in agent_mode must be in the specified strategy list")
if not config["agent_mode"].get("tools"):
config["agent_mode"]["tools"] = []
if not agent_mode.get("tools"):
agent_mode["tools"] = []
if not isinstance(config["agent_mode"]["tools"], list):
if not isinstance(agent_mode["tools"], list):
raise ValueError("tools in agent_mode must be a list of objects")
for tool in config["agent_mode"]["tools"]:
for tool in agent_mode["tools"]:
key = list(tool.keys())[0]
if key in OLD_TOOLS:
# old style, use tool name as key
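The `cast()` added above (flagged by the inline FIXME) works around the checker narrowing `agent_mode` to a bare `dict` after `isinstance`; `cast` restores the precise `dict[str, Any]` type at zero runtime cost. The pattern in isolation:

from typing import Any, cast

config: dict[str, Any] = {"agent_mode": {"enabled": True, "tools": []}}

agent_mode = config["agent_mode"]  # statically typed as Any
if not isinstance(agent_mode, dict):
    raise ValueError("agent_mode must be of object type")

# isinstance narrows only to an unparameterized dict under basedpyright;
# cast() is a runtime no-op that restores the key/value types.
agent_mode = cast(dict[str, Any], agent_mode)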

View File

@ -46,7 +46,10 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
else:
response["metadata"] = {}
return response
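The same defensive guard is applied in the chat and completion converters below: `response.get("metadata", {})` is typed as `Any` and could hold a non-dict at runtime, so the branch falls back to an empty dict rather than passing a surprise type into `_get_simple_metadata`. Distilled into a hypothetical helper:

from typing import Any

def safe_metadata(response: dict[str, Any]) -> dict[str, Any]:
    metadata = response.get("metadata", {})
    # Fall back to {} when metadata is present but not a mapping.
    return metadata if isinstance(metadata, dict) else {}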
@ -78,7 +81,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@ -106,7 +109,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
@ -114,6 +117,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union, final
from typing import TYPE_CHECKING, Any, Union, final
from sqlalchemy.orm import Session
@ -25,7 +25,7 @@ class BaseAppGenerator:
def _prepare_user_inputs(
self,
*,
user_inputs: Optional[Mapping[str, Any]],
user_inputs: Mapping[str, Any] | None,
variables: Sequence["VariableEntity"],
tenant_id: str,
strict_type_validation: bool = False,

View File

@ -2,7 +2,7 @@ import queue
import time
from abc import abstractmethod
from enum import IntEnum, auto
from typing import Any, Optional
from typing import Any
from sqlalchemy.orm import DeclarativeMeta
@ -32,6 +32,7 @@ class AppQueueManager:
self._task_id = task_id
self._user_id = user_id
self._invoke_from = invoke_from
self.invoke_from = invoke_from # Public accessor for invoke_from
user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
redis_client.setex(
@ -115,7 +116,7 @@ class AppQueueManager:
Set task stop flag
:return:
"""
result: Optional[Any] = redis_client.get(cls._generate_task_belong_cache_key(task_id))
result: Any | None = redis_client.get(cls._generate_task_belong_cache_key(task_id))
if result is None:
return
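Here `invoke_from` is mirrored into a public attribute so that `WorkflowAppGenerateTaskPipeline` (see its diff below) no longer reaches into `queue_manager._invoke_from`. A read-only property would be an alternative sketch that avoids holding the same state twice (with a simplified `InvokeFrom` stub):

from enum import StrEnum, auto

class InvokeFrom(StrEnum):  # stub; the real enum lives in the app entities module
    EXPLORE = auto()
    DEBUGGER = auto()

class AppQueueManager:
    def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None:
        self._task_id = task_id
        self._user_id = user_id
        self._invoke_from = invoke_from

    @property
    def invoke_from(self) -> InvokeFrom:
        # Read-only public view; no second copy of the state to keep in sync.
        return self._invoke_from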

View File

@ -1,7 +1,7 @@
import logging
import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -82,11 +82,11 @@ class AppRunner:
prompt_template_entity: PromptTemplateEntity,
inputs: Mapping[str, str],
files: Sequence["File"],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None,
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
query: str | None = None,
context: str | None = None,
memory: TokenBufferMemory | None = None,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> tuple[list[PromptMessage], list[str] | None]:
"""
Organize prompt messages
:param context:
@ -161,7 +161,7 @@ class AppRunner:
prompt_messages: list,
text: str,
stream: bool,
usage: Optional[LLMUsage] = None,
usage: LLMUsage | None = None,
):
"""
Direct output
@ -375,7 +375,7 @@ class AppRunner:
def query_app_annotations_to_reply(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
) -> MessageAnnotation | None:
"""
Query app annotations to reply
:param app_record: app record

View File

@ -1,5 +1,3 @@
from typing import Optional
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
@ -32,8 +30,8 @@ class ChatAppConfigManager(BaseAppConfigManager):
cls,
app_model: App,
app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None,
conversation: Conversation | None = None,
override_config_dict: dict | None = None,
) -> ChatAppConfig:
"""
Convert app model config to chat app config

View File

@ -46,7 +46,10 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
else:
response["metadata"] = {}
return response
@ -78,7 +81,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@ -106,7 +109,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
@ -114,6 +117,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -1,7 +1,7 @@
import time
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, Union
from typing import Any, Union
from sqlalchemy.orm import Session
@ -135,7 +135,7 @@ class WorkflowResponseConverter:
event: QueueNodeStartedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeStartStreamResponse]:
) -> NodeStartStreamResponse | None:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_execution_id:
@ -190,7 +190,7 @@ class WorkflowResponseConverter:
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
) -> NodeFinishStreamResponse | None:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_execution_id:
@ -235,7 +235,7 @@ class WorkflowResponseConverter:
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
) -> Union[NodeRetryStreamResponse, NodeFinishStreamResponse] | None:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_execution_id:

View File

@ -1,5 +1,3 @@
from typing import Optional
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
@ -24,7 +22,7 @@ class CompletionAppConfig(EasyUIBasedAppConfig):
class CompletionAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: Optional[dict] = None
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: dict | None = None
) -> CompletionAppConfig:
"""
Convert app model config to completion app config

View File

@ -271,6 +271,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
raise MoreLikeThisDisabledError()
app_model_config = message.app_model_config
if not app_model_config:
raise ValueError("Message app_model_config is None")
override_model_config_dict = app_model_config.to_dict()
model_dict = override_model_config_dict["model"]
completion_params = model_dict.get("completion_params")

View File

@ -45,7 +45,10 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
else:
response["metadata"] = {}
return response
@ -76,7 +79,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@ -103,14 +106,16 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
if not isinstance(metadata, dict):
metadata = {}
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -1,7 +1,7 @@
import json
import logging
from collections.abc import Generator
from typing import Optional, Union, cast
from typing import Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -84,7 +84,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
logger.exception("Failed to handle response, conversation_id: %s", conversation.id)
raise e
def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
def _get_app_model_config(self, app_model: App, conversation: Conversation | None = None) -> AppModelConfig:
if conversation:
stmt = select(AppModelConfig).where(
AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
@ -112,7 +112,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
],
conversation: Optional[Conversation] = None,
conversation: Conversation | None = None,
) -> tuple[Conversation, Message]:
"""
Initialize generate records

View File

@ -425,6 +425,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
context: contextvars.Context,
variable_loader: VariableLoader,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
with preserve_flask_contexts(flask_app, context_vars=context):
with Session(db.engine, expire_on_commit=False) as session:
workflow = session.scalar(

View File

@ -23,7 +23,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
:param blocking_response: blocking response
:return:
"""
return dict(blocking_response.to_dict())
return blocking_response.model_dump()
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
@ -51,7 +51,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, object] = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@ -80,7 +80,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, object] = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -2,7 +2,7 @@ import logging
import time
from collections.abc import Callable, Generator
from contextlib import contextmanager
from typing import Optional, Union
from typing import Union
from sqlalchemy.orm import Session
@ -133,7 +133,7 @@ class WorkflowAppGenerateTaskPipeline:
self._application_generate_entity = application_generate_entity
self._workflow_features_dict = workflow.features_dict
self._workflow_run_id = ""
self._invoke_from = queue_manager._invoke_from
self._invoke_from = queue_manager.invoke_from
self._draft_var_saver_factory = draft_var_saver_factory
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@ -142,7 +142,7 @@ class WorkflowAppGenerateTaskPipeline:
:return:
"""
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._base_task_pipeline._stream:
if self._base_task_pipeline.stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
@ -202,7 +202,7 @@ class WorkflowAppGenerateTaskPipeline:
return None
def _wrapper_process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
self, trace_manager: TraceQueueManager | None = None
) -> Generator[StreamResponse, None, None]:
tts_publisher = None
task_id = self._application_generate_entity.task_id
@ -264,7 +264,7 @@ class WorkflowAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState:
def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState:
"""Fluent validation for graph runtime state."""
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
@ -272,12 +272,12 @@ class WorkflowAppGenerateTaskPipeline:
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
"""Handle ping events."""
yield self._base_task_pipeline._ping_stream_response()
yield self._base_task_pipeline.ping_stream_response()
def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
"""Handle error events."""
err = self._base_task_pipeline._handle_error(event=event)
yield self._base_task_pipeline._error_to_stream_response(err)
err = self._base_task_pipeline.handle_error(event=event)
yield self._base_task_pipeline.error_to_stream_response(err)
def _handle_workflow_started_event(
self, event: QueueWorkflowStartedEvent, **kwargs
@ -442,8 +442,8 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: QueueWorkflowSucceededEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow succeeded events."""
@ -476,8 +476,8 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: QueueWorkflowPartialSuccessEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow partial success events."""
@ -511,8 +511,8 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow failed and stop events."""
@ -549,8 +549,8 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: QueueTextChunkEvent,
*,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle text chunk events."""
@ -601,10 +601,10 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: AppQueueEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
graph_runtime_state: GraphRuntimeState | None = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
) -> Generator[StreamResponse, None, None]:
"""Dispatch events using elegant pattern matching."""
handlers = self._get_event_handlers()
@ -654,8 +654,8 @@ class WorkflowAppGenerateTaskPipeline:
def _process_stream_response(
self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
) -> Generator[StreamResponse, None, None]:
"""
Process stream response using elegant Fluent Python patterns.
@ -722,7 +722,7 @@ class WorkflowAppGenerateTaskPipeline:
session.commit()
def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: Optional[list[str]] = None
self, text: str, from_variable_selector: list[str] | None = None
) -> TextChunkStreamResponse:
"""
Handle completed event.

View File

@ -99,8 +99,8 @@ class AppGenerateEntity(BaseModel):
task_id: str
# app config
app_config: Any
file_upload_config: Optional[FileUploadConfig] = None
app_config: Any = None
file_upload_config: FileUploadConfig | None = None
inputs: Mapping[str, Any]
files: Sequence[File]
@ -126,10 +126,10 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
"""
# app config
app_config: EasyUIBasedAppConfig
app_config: EasyUIBasedAppConfig = None # type: ignore
model_conf: ModelConfigWithCredentialsEntity
query: Optional[str] = None
query: str | None = None
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@ -140,8 +140,8 @@ class ConversationAppGenerateEntity(AppGenerateEntity):
Base entity for conversation-based app generation.
"""
conversation_id: Optional[str] = None
parent_message_id: Optional[str] = Field(
conversation_id: str | None = None
parent_message_id: str | None = Field(
default=None,
description=(
"Starting from v0.9.0, parent_message_id is used to support message regeneration for internal chat API."
@ -189,9 +189,9 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
"""
# app config
app_config: WorkflowUIBasedAppConfig
app_config: WorkflowUIBasedAppConfig = None # type: ignore
workflow_run_id: Optional[str] = None
workflow_run_id: str | None = None
query: str
class SingleIterationRunEntity(BaseModel):
@ -202,7 +202,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
node_id: str
inputs: Mapping
single_iteration_run: Optional[SingleIterationRunEntity] = None
single_iteration_run: SingleIterationRunEntity | None = None
class SingleLoopRunEntity(BaseModel):
"""
@ -212,7 +212,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
node_id: str
inputs: Mapping
single_loop_run: Optional[SingleLoopRunEntity] = None
single_loop_run: SingleLoopRunEntity | None = None
class WorkflowAppGenerateEntity(AppGenerateEntity):
@ -221,7 +221,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
# app config
app_config: WorkflowUIBasedAppConfig
app_config: WorkflowUIBasedAppConfig = None # type: ignore
workflow_execution_id: str
class SingleIterationRunEntity(BaseModel):
@ -232,7 +232,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
node_id: str
inputs: dict
single_iteration_run: Optional[SingleIterationRunEntity] = None
single_iteration_run: SingleIterationRunEntity | None = None
class SingleLoopRunEntity(BaseModel):
"""

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import Enum, StrEnum
from enum import StrEnum, auto
from typing import Any, Optional
from pydantic import BaseModel
@ -79,9 +79,9 @@ class QueueIterationStartEvent(AppQueueEvent):
start_at: datetime
node_run_index: int
inputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, Any] | None = None
predecessor_node_id: str | None = None
metadata: Mapping[str, Any] | None = None
class QueueIterationNextEvent(AppQueueEvent):
@ -114,12 +114,12 @@ class QueueIterationCompletedEvent(AppQueueEvent):
start_at: datetime
node_run_index: int
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
metadata: Mapping[str, Any] | None = None
steps: int = 0
error: Optional[str] = None
error: str | None = None
class QueueLoopStartEvent(AppQueueEvent):
@ -132,20 +132,20 @@ class QueueLoopStartEvent(AppQueueEvent):
node_id: str
node_type: NodeType
node_title: str
parallel_id: Optional[str] = None
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
inputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, Any] | None = None
predecessor_node_id: str | None = None
metadata: Mapping[str, Any] | None = None
class QueueLoopNextEvent(AppQueueEvent):
@ -160,15 +160,15 @@ class QueueLoopNextEvent(AppQueueEvent):
node_id: str
node_type: NodeType
node_title: str
parallel_id: Optional[str] = None
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
parallel_mode_run_id: str | None = None
"""iteration run in parallel mode run id"""
node_run_index: int
output: Optional[Any] = None # output for the current loop
@ -187,21 +187,21 @@ class QueueLoopCompletedEvent(AppQueueEvent):
node_title: str
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
metadata: Mapping[str, Any] | None = None
steps: int = 0
error: Optional[str] = None
error: str | None = None
class QueueTextChunkEvent(AppQueueEvent):
@ -211,11 +211,11 @@ class QueueTextChunkEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.TEXT_CHUNK
text: str
from_variable_selector: Optional[list[str]] = None
from_variable_selector: list[str] | None = None
"""from variable selector"""
in_iteration_id: Optional[str] = None
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
in_loop_id: str | None = None
"""loop id if node is in loop"""
@ -252,9 +252,9 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
retriever_resources: Sequence[RetrievalSourceMetadata]
in_iteration_id: Optional[str] = None
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
in_loop_id: str | None = None
"""loop id if node is in loop"""
@ -273,7 +273,7 @@ class QueueMessageEndEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.MESSAGE_END
llm_result: Optional[LLMResult] = None
llm_result: LLMResult | None = None
class QueueAdvancedChatMessageEndEvent(AppQueueEvent):
@ -299,7 +299,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
outputs: Optional[dict[str, Any]] = None
outputs: dict[str, Any] | None = None
class QueueWorkflowFailedEvent(AppQueueEvent):
@ -319,7 +319,7 @@ class QueueWorkflowPartialSuccessEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
exceptions_count: int
outputs: Optional[dict[str, Any]] = None
outputs: dict[str, Any] | None = None
class QueueNodeStartedEvent(AppQueueEvent):
@ -362,22 +362,22 @@ class QueueNodeSucceededEvent(AppQueueEvent):
node_type: NodeType
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
inputs: Mapping[str, Any] | None = None
process_data: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
error: Optional[str] = None
@ -391,11 +391,11 @@ class QueueAgentLogEvent(AppQueueEvent):
id: str
label: str
node_execution_id: str
parent_id: str | None
error: str | None
parent_id: str | None = None
error: str | None = None
status: str
data: Mapping[str, Any]
metadata: Optional[Mapping[str, Any]] = None
metadata: Mapping[str, Any] | None = None
node_id: str
@ -404,10 +404,10 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent):
event: QueueEvent = QueueEvent.RETRY
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
inputs: Mapping[str, Any] | None = None
process_data: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
error: str
retry_index: int # retry index
@ -425,22 +425,22 @@ class QueueNodeExceptionEvent(AppQueueEvent):
node_type: NodeType
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
inputs: Mapping[str, Any] | None = None
process_data: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
error: str
@ -458,14 +458,14 @@ class QueueNodeFailedEvent(AppQueueEvent):
parallel_id: Optional[str] = None
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
inputs: Mapping[str, Any] | None = None
process_data: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
error: str
@ -494,7 +494,7 @@ class QueueErrorEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.ERROR
error: Optional[Any] = None
error: Any | None = None
class QueuePingEvent(AppQueueEvent):
@ -510,15 +510,15 @@ class QueueStopEvent(AppQueueEvent):
QueueStopEvent entity
"""
class StopBy(Enum):
class StopBy(StrEnum):
"""
Stop by enum
"""
USER_MANUAL = "user-manual"
ANNOTATION_REPLY = "annotation-reply"
OUTPUT_MODERATION = "output-moderation"
INPUT_MODERATION = "input-moderation"
USER_MANUAL = auto()
ANNOTATION_REPLY = auto()
OUTPUT_MODERATION = auto()
INPUT_MODERATION = auto()
event: QueueEvent = QueueEvent.STOP
stopped_by: StopBy
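
Note that auto() on a StrEnum yields the lowercased member name, so the StopBy values become underscored strings ("user_manual") rather than the previous hyphenated literals ("user-manual"). Member-to-member comparisons are unaffected, but any code or stored data relying on the old literals would see a change. A quick sketch of what the new members evaluate to:

    from enum import StrEnum, auto

    class StopBy(StrEnum):
        USER_MANUAL = auto()   # value is the lowercased member name

    assert StopBy.USER_MANUAL == "user_manual"   # previously "user-manual"
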

View File

@ -1,11 +1,10 @@
from collections.abc import Mapping, Sequence
from enum import Enum
from enum import StrEnum
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@ -51,7 +50,7 @@ class WorkflowTaskState(TaskState):
answer: str = ""
class StreamEvent(Enum):
class StreamEvent(StrEnum):
"""
Stream event
"""
@ -90,9 +89,6 @@ class StreamResponse(BaseModel):
event: StreamEvent
task_id: str
def to_dict(self):
return jsonable_encoder(self)
class ErrorStreamResponse(StreamResponse):
"""
@ -112,7 +108,7 @@ class MessageStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.MESSAGE
id: str
answer: str
from_variable_selector: Optional[list[str]] = None
from_variable_selector: list[str] | None = None
class MessageAudioStreamResponse(StreamResponse):
@ -141,7 +137,7 @@ class MessageEndStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.MESSAGE_END
id: str
metadata: dict = Field(default_factory=dict)
files: Optional[Sequence[Mapping[str, Any]]] = None
files: Sequence[Mapping[str, Any]] | None = None
class MessageFileStreamResponse(StreamResponse):
@ -174,12 +170,12 @@ class AgentThoughtStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.AGENT_THOUGHT
id: str
position: int
thought: Optional[str] = None
observation: Optional[str] = None
tool: Optional[str] = None
tool_labels: Optional[dict] = None
tool_input: Optional[str] = None
message_files: Optional[list[str]] = None
thought: str | None = None
observation: str | None = None
tool: str | None = None
tool_labels: dict | None = None
tool_input: str | None = None
message_files: list[str] | None = None
class AgentMessageStreamResponse(StreamResponse):
@ -225,16 +221,16 @@ class WorkflowFinishStreamResponse(StreamResponse):
id: str
workflow_id: str
status: str
outputs: Optional[Mapping[str, Any]] = None
error: Optional[str] = None
outputs: Mapping[str, Any] | None = None
error: str | None = None
elapsed_time: float
total_tokens: int
total_steps: int
created_by: Optional[dict] = None
created_by: dict | None = None
created_at: int
finished_at: int
exceptions_count: Optional[int] = 0
files: Optional[Sequence[Mapping[str, Any]]] = []
exceptions_count: int | None = 0
files: Sequence[Mapping[str, Any]] | None = []
event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
workflow_run_id: str
@ -261,14 +257,14 @@ class NodeStartStreamResponse(StreamResponse):
inputs_truncated: bool = False
created_at: int
extras: dict = Field(default_factory=dict)
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
loop_id: Optional[str] = None
parallel_run_id: Optional[str] = None
agent_strategy: Optional[AgentNodeStrategyInit] = None
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
parallel_run_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None
event: StreamEvent = StreamEvent.NODE_STARTED
workflow_run_id: str
@ -322,18 +318,18 @@ class NodeFinishStreamResponse(StreamResponse):
outputs: Optional[Mapping[str, Any]] = None
outputs_truncated: bool = True
status: str
error: Optional[str] = None
error: str | None = None
elapsed_time: float
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
created_at: int
finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
loop_id: Optional[str] = None
files: Sequence[Mapping[str, Any]] | None = []
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
event: StreamEvent = StreamEvent.NODE_FINISHED
workflow_run_id: str
@ -394,18 +390,18 @@ class NodeRetryStreamResponse(StreamResponse):
outputs: Optional[Mapping[str, Any]] = None
outputs_truncated: bool = False
status: str
error: Optional[str] = None
error: str | None = None
elapsed_time: float
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
created_at: int
finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
loop_id: Optional[str] = None
files: Sequence[Mapping[str, Any]] | None = []
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
retry_index: int = 0
event: StreamEvent = StreamEvent.NODE_RETRY
@ -514,10 +510,10 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
inputs: Optional[Mapping] = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus
error: Optional[str] = None
error: str | None = None
elapsed_time: float
total_tokens: int
execution_metadata: Optional[Mapping] = None
execution_metadata: Mapping | None = None
finished_at: int
steps: int
@ -569,7 +565,7 @@ class LoopNodeNextStreamResponse(StreamResponse):
title: str
index: int
created_at: int
pre_loop_output: Optional[Any] = None
pre_loop_output: Any | None = None
extras: dict = Field(default_factory=dict)
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
@ -601,14 +597,14 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
inputs: Optional[Mapping] = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus
error: Optional[str] = None
error: str | None = None
elapsed_time: float
total_tokens: int
execution_metadata: Optional[Mapping] = None
execution_metadata: Mapping | None = None
finished_at: int
steps: int
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parallel_id: str | None = None
parallel_start_node_id: str | None = None
event: StreamEvent = StreamEvent.LOOP_COMPLETED
workflow_run_id: str
@ -626,7 +622,7 @@ class TextChunkStreamResponse(StreamResponse):
"""
text: str
from_variable_selector: Optional[list[str]] = None
from_variable_selector: list[str] | None = None
event: StreamEvent = StreamEvent.TEXT_CHUNK
data: Data
@ -688,7 +684,7 @@ class WorkflowAppStreamResponse(AppStreamResponse):
WorkflowAppStreamResponse entity
"""
workflow_run_id: Optional[str] = None
workflow_run_id: str | None = None
class AppBlockingResponse(BaseModel):
@ -698,9 +694,6 @@ class AppBlockingResponse(BaseModel):
task_id: str
def to_dict(self):
return jsonable_encoder(self)
class ChatbotAppBlockingResponse(AppBlockingResponse):
"""
@ -756,8 +749,8 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
id: str
workflow_id: str
status: str
outputs: Optional[Mapping[str, Any]] = None
error: Optional[str] = None
outputs: Mapping[str, Any] | None = None
error: str | None = None
elapsed_time: float
total_tokens: int
total_steps: int
@ -781,11 +774,11 @@ class AgentLogStreamResponse(StreamResponse):
node_execution_id: str
id: str
label: str
parent_id: str | None
error: str | None
parent_id: str | None = None
error: str | None = None
status: str
data: Mapping[str, Any]
metadata: Optional[Mapping[str, Any]] = None
metadata: Mapping[str, Any] | None = None
node_id: str
event: StreamEvent = StreamEvent.AGENT_LOG

View File

@ -1,5 +1,4 @@
import logging
from typing import Optional
from sqlalchemy import select
@ -17,7 +16,7 @@ logger = logging.getLogger(__name__)
class AnnotationReplyFeature:
def query(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
) -> MessageAnnotation | None:
"""
Query app annotations to reply
:param app_record: app record
@ -35,6 +34,9 @@ class AnnotationReplyFeature:
collection_binding_detail = annotation_setting.collection_binding_detail
if not collection_binding_detail:
return None
try:
score_threshold = annotation_setting.score_threshold or 1
embedding_provider_name = collection_binding_detail.provider_name
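
One subtlety in the fallback above: "score_threshold or 1" coalesces every falsy value, so an explicit threshold of 0 is silently promoted to 1. A sketch of the behavior (helper name hypothetical):

    def effective_threshold(score_threshold: float | None) -> float:
        return score_threshold or 1        # None -> 1, but also 0.0 -> 1

    assert effective_threshold(None) == 1
    assert effective_threshold(0.0) == 1   # an explicit zero is lost

A "score_threshold if score_threshold is not None else 1" guard would preserve an intentional zero.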

View File

@ -1 +1,3 @@
from .rate_limit import RateLimit
__all__ = ["RateLimit"]

View File

@ -3,7 +3,7 @@ import time
import uuid
from collections.abc import Generator, Mapping
from datetime import timedelta
from typing import Any, Optional, Union
from typing import Any, Union
from core.errors.error import AppInvokeQuotaExceededError
from extensions.ext_redis import redis_client
@ -19,7 +19,7 @@ class RateLimit:
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
_instance_dict: dict[str, "RateLimit"] = {}
def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
def __new__(cls, client_id: str, max_active_requests: int):
if client_id not in cls._instance_dict:
instance = super().__new__(cls)
cls._instance_dict[client_id] = instance
@ -63,7 +63,7 @@ class RateLimit:
if timeout_requests:
redis_client.hdel(self.active_requests_key, *timeout_requests)
def enter(self, request_id: Optional[str] = None) -> str:
def enter(self, request_id: str | None = None) -> str:
if self.disabled():
return RateLimit._UNLIMITED_REQUEST_ID
if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL:
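
For context on the __new__ changes: RateLimit interns one instance per client_id in _instance_dict, so repeated construction with the same id returns the same object, and annotating cls is redundant since it is inferred. A minimal sketch of the per-key singleton pattern, without the Redis bookkeeping:

    class PerKeySingleton:
        _instances: dict[str, "PerKeySingleton"] = {}

        def __new__(cls, key: str):
            # Reuse the instance interned for this key, if any.
            if key not in cls._instances:
                cls._instances[key] = super().__new__(cls)
            return cls._instances[key]

    a = PerKeySingleton("app-1")
    b = PerKeySingleton("app-1")
    assert a is b

As with any __new__-based cache, __init__ still runs on every construction, so the real class must tolerate re-initialization.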

View File

@ -1,6 +1,5 @@
import logging
import time
from typing import Optional
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -38,11 +37,11 @@ class BasedGenerateTaskPipeline:
):
self._application_generate_entity = application_generate_entity
self.queue_manager = queue_manager
self._start_at = time.perf_counter()
self._output_moderation_handler = self._init_output_moderation()
self._stream = stream
self.start_at = time.perf_counter()
self.output_moderation_handler = self._init_output_moderation()
self.stream = stream
def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
def handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
logger.debug("error: %s", event.error)
e = event.error
err: Exception
@ -86,7 +85,7 @@ class BasedGenerateTaskPipeline:
return message
def _error_to_stream_response(self, e: Exception):
def error_to_stream_response(self, e: Exception):
"""
Error to stream response.
:param e: exception
@ -94,14 +93,14 @@ class BasedGenerateTaskPipeline:
"""
return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e)
def _ping_stream_response(self) -> PingStreamResponse:
def ping_stream_response(self) -> PingStreamResponse:
"""
Ping stream response.
:return:
"""
return PingStreamResponse(task_id=self._application_generate_entity.task_id)
def _init_output_moderation(self) -> Optional[OutputModeration]:
def _init_output_moderation(self) -> OutputModeration | None:
"""
Init output moderation.
:return: an OutputModeration handler, or None when no moderation is configured
@ -118,21 +117,21 @@ class BasedGenerateTaskPipeline:
)
return None
def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
def handle_output_moderation_when_task_finished(self, completion: str) -> str | None:
"""
Handle output moderation when task finished.
:param completion: completion
:return: the moderated completion when flagged, otherwise None
"""
# response moderation
if self._output_moderation_handler:
self._output_moderation_handler.stop_thread()
if self.output_moderation_handler:
self.output_moderation_handler.stop_thread()
completion, flagged = self._output_moderation_handler.moderation_completion(
completion, flagged = self.output_moderation_handler.moderation_completion(
completion=completion, public_event=False
)
self._output_moderation_handler = None
self.output_moderation_handler = None
if flagged:
return completion

View File

@ -2,7 +2,7 @@ import logging
import time
from collections.abc import Generator
from threading import Thread
from typing import Optional, Union, cast
from typing import Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -109,7 +109,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
task_state=self._task_state,
)
self._conversation_name_generate_thread: Optional[Thread] = None
self._conversation_name_generate_thread: Thread | None = None
def process(
self,
@ -125,7 +125,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
if self.stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
@ -145,7 +145,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata.model_dump()
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation_mode == AppMode.COMPLETION.value:
if self._conversation_mode == AppMode.COMPLETION:
response = CompletionAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
data=CompletionAppBlockingResponse.Data(
@ -209,7 +209,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
return None
def _wrapper_process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
self, trace_manager: TraceQueueManager | None = None
) -> Generator[StreamResponse, None, None]:
tenant_id = self._application_generate_entity.app_config.tenant_id
task_id = self._application_generate_entity.task_id
@ -252,7 +252,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self, publisher: Optional[AppGeneratorTTSPublisher], trace_manager: Optional[TraceQueueManager] = None
self, publisher: AppGeneratorTTSPublisher | None, trace_manager: TraceQueueManager | None = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
@ -265,9 +265,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if isinstance(event, QueueErrorEvent):
with Session(db.engine) as session:
err = self._handle_error(event=event, session=session, message_id=self._message_id)
err = self.handle_error(event=event, session=session, message_id=self._message_id)
session.commit()
yield self._error_to_stream_response(err)
yield self.error_to_stream_response(err)
break
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
if isinstance(event, QueueMessageEndEvent):
@ -277,7 +277,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
self._handle_stop(event)
# handle output moderation
output_moderation_answer = self._handle_output_moderation_when_task_finished(
output_moderation_answer = self.handle_output_moderation_when_task_finished(
cast(str, self._task_state.llm_result.message.content)
)
if output_moderation_answer:
@ -354,7 +354,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
yield self.ping_stream_response()
else:
continue
if publisher:
@ -362,7 +362,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None):
def _save_message(self, *, session: Session, trace_manager: TraceQueueManager | None = None):
"""
Save message.
:return:
@ -394,7 +394,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.answer_tokens = usage.completion_tokens
message.answer_unit_price = usage.completion_unit_price
message.answer_price_unit = usage.completion_price_unit
message.provider_response_latency = time.perf_counter() - self._start_at
message.provider_response_latency = time.perf_counter() - self.start_at
message.total_price = usage.total_price
message.currency = usage.currency
self._task_state.llm_result.usage.latency = message.provider_response_latency
@ -438,7 +438,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
# transform usage
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
self._task_state.llm_result.usage = model_type_instance.calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
@ -466,14 +466,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
task_id=self._application_generate_entity.task_id, id=message_id, answer=answer
)
def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]:
def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> AgentThoughtStreamResponse | None:
"""
Agent thought to stream response.
:param event: agent thought event
:return: the stream response, or None if no matching agent thought is found
"""
with Session(db.engine, expire_on_commit=False) as session:
agent_thought: Optional[MessageAgentThought] = (
agent_thought: MessageAgentThought | None = (
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)
@ -498,10 +498,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:param text: text
:return: True if output moderation should direct output, otherwise False
"""
if self._output_moderation_handler:
if self._output_moderation_handler.should_direct_output():
if self.output_moderation_handler:
if self.output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
self._task_state.llm_result.message.content = self.output_moderation_handler.get_final_output()
self.queue_manager.publish(
QueueLLMChunkEvent(
chunk=LLMResultChunk(
@ -521,6 +521,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
)
return True
else:
self._output_moderation_handler.append_new_token(text)
self.output_moderation_handler.append_new_token(text)
return False
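
Dropping .value in the AppMode comparisons (here and in MessageCycleManager below) relies on AppMode being a StrEnum: a StrEnum member is a str and compares equal to its plain-string value. A quick sketch, assuming that definition:

    from enum import StrEnum

    class AppMode(StrEnum):          # assumed shape of the real enum
        COMPLETION = "completion"
        CHAT = "chat"

    assert AppMode.COMPLETION == "completion"
    assert isinstance(AppMode.CHAT, str)
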

View File

@ -1,6 +1,6 @@
import logging
from threading import Thread
from typing import Optional, Union
from typing import Union
from flask import Flask, current_app
from sqlalchemy import select
@ -52,7 +52,7 @@ class MessageCycleManager:
self._application_generate_entity = application_generate_entity
self._task_state = task_state
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None:
"""
Generate conversation name.
:param conversation_id: conversation id
@ -92,7 +92,7 @@ class MessageCycleManager:
if not conversation:
return
if conversation.mode != AppMode.COMPLETION.value:
if conversation.mode != AppMode.COMPLETION:
app_model = conversation.app
if not app_model:
return
@ -111,7 +111,7 @@ class MessageCycleManager:
db.session.commit()
db.session.close()
def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> MessageAnnotation | None:
"""
Handle annotation reply.
:param event: event
@ -141,7 +141,7 @@ class MessageCycleManager:
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
self._task_state.metadata.retriever_resources = event.retriever_resources
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> MessageFileStreamResponse | None:
"""
Message file to stream response.
:param event: event
@ -180,7 +180,7 @@ class MessageCycleManager:
return None
def message_to_stream_response(
self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None
self, answer: str, message_id: str, from_variable_selector: list[str] | None = None
) -> MessageStreamResponse:
"""
Message to stream response.