Merge branch 'feat/rag-2' into feat/add-dataset-service-api-enable
# Conflicts:
#	api/controllers/console/datasets/datasets.py
#	api/controllers/service_api/wraps.py
#	api/services/dataset_service.py
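Most of this merge is a mechanical migration of annotations from typing.Optional[X] to the PEP 604 union syntax X | None (Python 3.10+), plus a switch of several string enums to StrEnum with auto(). A minimal sketch of the annotation change, using illustrative names that are not from the codebase:

    from typing import Optional

    def convert_old(config: dict) -> Optional[str]:
        # Pre-merge style: Optional[str] means exactly str | None.
        return config.get("name")

    def convert_new(config: dict) -> str | None:
        # Post-merge style (PEP 604): same runtime meaning, no typing import needed.
        return config.get("name")

Both annotations describe the same type; the union form simply drops the typing dependency.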
@@ -1,12 +1,10 @@
-from typing import Optional
-
 from core.app.app_config.entities import SensitiveWordAvoidanceEntity
 from core.moderation.factory import ModerationFactory


 class SensitiveWordAvoidanceConfigManager:
     @classmethod
-    def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]:
+    def convert(cls, config: dict) -> SensitiveWordAvoidanceEntity | None:
         sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
         if not sensitive_word_avoidance_dict:
             return None
@@ -21,7 +19,7 @@ class SensitiveWordAvoidanceConfigManager:

     @classmethod
     def validate_and_set_defaults(
-        cls, tenant_id, config: dict, only_structure_validate: bool = False
+        cls, tenant_id: str, config: dict, only_structure_validate: bool = False
     ) -> tuple[dict, list[str]]:
         if not config.get("sensitive_word_avoidance"):
             config["sensitive_word_avoidance"] = {"enabled": False}
@@ -38,7 +36,14 @@ class SensitiveWordAvoidanceConfigManager:

         if not only_structure_validate:
             typ = config["sensitive_word_avoidance"]["type"]
-            sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"]
+            if not isinstance(typ, str):
+                raise ValueError("sensitive_word_avoidance.type must be a string")
+
+            sensitive_word_avoidance_config = config["sensitive_word_avoidance"].get("config")
+            if sensitive_word_avoidance_config is None:
+                sensitive_word_avoidance_config = {}
+            if not isinstance(sensitive_word_avoidance_config, dict):
+                raise ValueError("sensitive_word_avoidance.config must be a dict")

             ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)
@@ -1,12 +1,10 @@
-from typing import Optional
-
 from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity
 from core.agent.prompt.template import REACT_PROMPT_TEMPLATES


 class AgentConfigManager:
     @classmethod
-    def convert(cls, config: dict) -> Optional[AgentEntity]:
+    def convert(cls, config: dict) -> AgentEntity | None:
         """
         Convert model config to model config
@@ -1,5 +1,4 @@
 import uuid
-from typing import Optional

 from core.app.app_config.entities import (
     DatasetEntity,
@@ -14,7 +13,7 @@ from services.dataset_service import DatasetService

 class DatasetConfigManager:
     @classmethod
-    def convert(cls, config: dict) -> Optional[DatasetEntity]:
+    def convert(cls, config: dict) -> DatasetEntity | None:
         """
         Convert model config to model config
@@ -25,10 +25,14 @@ class PromptTemplateConfigManager:
         if chat_prompt_config:
             chat_prompt_messages = []
             for message in chat_prompt_config.get("prompt", []):
+                text = message.get("text")
+                if not isinstance(text, str):
+                    raise ValueError("message text must be a string")
+                role = message.get("role")
+                if not isinstance(role, str):
+                    raise ValueError("message role must be a string")
                 chat_prompt_messages.append(
-                    AdvancedChatMessageEntity(
-                        **{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
-                    )
+                    AdvancedChatMessageEntity(text=text, role=PromptMessageRole.value_of(role))
                 )

             advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
@@ -66,7 +70,7 @@ class PromptTemplateConfigManager:
         :param config: app model config args
         """
         if not config.get("prompt_type"):
-            config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
+            config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE

         prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
         if config["prompt_type"] not in prompt_type_vals:
@@ -86,7 +90,7 @@ class PromptTemplateConfigManager:
         if not isinstance(config["completion_prompt_config"], dict):
             raise ValueError("completion_prompt_config must be of object type")

-        if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
+        if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED:
            if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
                raise ValueError(
                    "chat_prompt_config or completion_prompt_config is required when prompt_type is advanced"
@@ -1,6 +1,6 @@
 from collections.abc import Sequence
-from enum import Enum, StrEnum
-from typing import Any, Literal, Optional
+from enum import StrEnum, auto
+from typing import Any, Literal

 from pydantic import BaseModel, Field, field_validator

@@ -17,7 +17,7 @@ class ModelConfigEntity(BaseModel):

     provider: str
     model: str
-    mode: Optional[str] = None
+    mode: str | None = None
     parameters: dict[str, Any] = Field(default_factory=dict)
     stop: list[str] = Field(default_factory=list)

@@ -53,7 +53,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
         assistant: str

     prompt: str
-    role_prefix: Optional[RolePrefixEntity] = None
+    role_prefix: RolePrefixEntity | None = None


 class PromptTemplateEntity(BaseModel):
@@ -61,14 +61,14 @@ class PromptTemplateEntity(BaseModel):
     Prompt Template Entity.
     """

-    class PromptType(Enum):
+    class PromptType(StrEnum):
         """
         Prompt Type.
         'simple', 'advanced'
         """

-        SIMPLE = "simple"
-        ADVANCED = "advanced"
+        SIMPLE = auto()
+        ADVANCED = auto()

         @classmethod
         def value_of(cls, value: str):
@@ -84,9 +84,9 @@ class PromptTemplateEntity(BaseModel):
             raise ValueError(f"invalid prompt type value {value}")

     prompt_type: PromptType
-    simple_prompt_template: Optional[str] = None
-    advanced_chat_prompt_template: Optional[AdvancedChatPromptTemplateEntity] = None
-    advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
+    simple_prompt_template: str | None = None
+    advanced_chat_prompt_template: AdvancedChatPromptTemplateEntity | None = None
+    advanced_completion_prompt_template: AdvancedCompletionPromptTemplateEntity | None = None


 class VariableEntityType(StrEnum):
@@ -112,7 +112,7 @@ class VariableEntity(BaseModel):
     type: VariableEntityType
     required: bool = False
     hide: bool = False
-    max_length: Optional[int] = None
+    max_length: int | None = None
     options: Sequence[str] = Field(default_factory=list)
     allowed_file_types: Optional[Sequence[FileType]] = Field(default_factory=list)
     allowed_file_extensions: Optional[Sequence[str]] = Field(default_factory=list)
@@ -183,7 +183,7 @@ class ModelConfig(BaseModel):

 class Condition(BaseModel):
     """
-    Conditon detail
+    Condition detail
     """

     name: str
@@ -196,8 +196,8 @@ class MetadataFilteringCondition(BaseModel):
     Metadata Filtering Condition.
     """

-    logical_operator: Optional[Literal["and", "or"]] = "and"
-    conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
+    logical_operator: Literal["and", "or"] | None = "and"
+    conditions: list[Condition] | None = Field(default=None, deprecated=True)


 class DatasetRetrieveConfigEntity(BaseModel):
@@ -205,14 +205,14 @@ class DatasetRetrieveConfigEntity(BaseModel):
     Dataset Retrieve Config Entity.
     """

-    class RetrieveStrategy(Enum):
+    class RetrieveStrategy(StrEnum):
         """
         Dataset Retrieve Strategy.
         'single' or 'multiple'
         """

-        SINGLE = "single"
-        MULTIPLE = "multiple"
+        SINGLE = auto()
+        MULTIPLE = auto()

         @classmethod
         def value_of(cls, value: str):
@@ -227,18 +227,18 @@ class DatasetRetrieveConfigEntity(BaseModel):
             return mode
         raise ValueError(f"invalid retrieve strategy value {value}")

-    query_variable: Optional[str] = None  # Only when app mode is completion
+    query_variable: str | None = None  # Only when app mode is completion

     retrieve_strategy: RetrieveStrategy
-    top_k: Optional[int] = None
-    score_threshold: Optional[float] = 0.0
-    rerank_mode: Optional[str] = "reranking_model"
-    reranking_model: Optional[dict] = None
-    weights: Optional[dict] = None
-    reranking_enabled: Optional[bool] = True
-    metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
-    metadata_model_config: Optional[ModelConfig] = None
-    metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
+    top_k: int | None = None
+    score_threshold: float | None = 0.0
+    rerank_mode: str | None = "reranking_model"
+    reranking_model: dict | None = None
+    weights: dict | None = None
+    reranking_enabled: bool | None = True
+    metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
+    metadata_model_config: ModelConfig | None = None
+    metadata_filtering_conditions: MetadataFilteringCondition | None = None


 class DatasetEntity(BaseModel):
@@ -265,8 +265,8 @@ class TextToSpeechEntity(BaseModel):
     """

     enabled: bool
-    voice: Optional[str] = None
-    language: Optional[str] = None
+    voice: str | None = None
+    language: str | None = None


 class TracingConfigEntity(BaseModel):
@@ -279,15 +279,15 @@ class TracingConfigEntity(BaseModel):


 class AppAdditionalFeatures(BaseModel):
-    file_upload: Optional[FileUploadConfig] = None
-    opening_statement: Optional[str] = None
+    file_upload: FileUploadConfig | None = None
+    opening_statement: str | None = None
     suggested_questions: list[str] = []
     suggested_questions_after_answer: bool = False
     show_retrieve_source: bool = False
     more_like_this: bool = False
     speech_to_text: bool = False
-    text_to_speech: Optional[TextToSpeechEntity] = None
-    trace_config: Optional[TracingConfigEntity] = None
+    text_to_speech: TextToSpeechEntity | None = None
+    trace_config: TracingConfigEntity | None = None


 class AppConfig(BaseModel):
@@ -300,15 +300,15 @@ class AppConfig(BaseModel):
     app_mode: AppMode
     additional_features: Optional[AppAdditionalFeatures] = None
     variables: list[VariableEntity] = []
-    sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
+    sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None


-class EasyUIBasedAppModelConfigFrom(Enum):
+class EasyUIBasedAppModelConfigFrom(StrEnum):
     """
     App Model Config From.
     """

-    ARGS = "args"
+    ARGS = auto()
     APP_LATEST_CONFIG = "app-latest-config"
     CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"

@@ -323,7 +323,7 @@ class EasyUIBasedAppConfig(AppConfig):
     app_model_config_dict: dict
     model: ModelConfigEntity
     prompt_template: PromptTemplateEntity
-    dataset: Optional[DatasetEntity] = None
+    dataset: DatasetEntity | None = None
     external_data_variables: list[ExternalDataVariableEntity] = []
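The Enum-to-StrEnum switches above are value-preserving because auto() on a StrEnum yields the lower-cased member name, so SIMPLE = auto() is still the string "simple"; that is also why earlier hunks can compare against PromptType.SIMPLE without .value. A standalone sketch (standard library only, Python 3.11+):

    from enum import StrEnum, auto

    class PromptType(StrEnum):
        SIMPLE = auto()    # auto() on StrEnum lower-cases the member name
        ADVANCED = auto()

    assert PromptType.SIMPLE == "simple"            # members are real str instances
    assert PromptType.ADVANCED.value == "advanced"
    assert str(PromptType.SIMPLE) == "simple"       # safe to embed directly in strings

Members whose value is not the lower-cased name, such as APP_LATEST_CONFIG = "app-latest-config", must keep an explicit string, which is why only ARGS switches to auto() in EasyUIBasedAppModelConfigFrom.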
@@ -3,7 +3,7 @@ import logging
 import threading
 import uuid
 from collections.abc import Generator, Mapping
-from typing import Any, Literal, Optional, Union, overload
+from typing import Any, Literal, Union, overload

 from flask import Flask, current_app
 from pydantic import ValidationError
@@ -390,7 +390,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         application_generate_entity: AdvancedChatAppGenerateEntity,
         workflow_execution_repository: WorkflowExecutionRepository,
         workflow_node_execution_repository: WorkflowNodeExecutionRepository,
-        conversation: Optional[Conversation] = None,
+        conversation: Conversation | None = None,
         stream: bool = True,
         variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
     ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
@@ -1,7 +1,7 @@
 import logging
 import time
 from collections.abc import Mapping
-from typing import Any, Optional, cast
+from typing import Any, cast

 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -249,7 +249,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):

     def query_app_annotations_to_reply(
         self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
-    ) -> Optional[MessageAnnotation]:
+    ) -> MessageAnnotation | None:
         """
         Query app annotations to reply
         :param app_record: app record
@@ -71,7 +71,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 yield "ping"
                 continue

-            response_chunk = {
+            response_chunk: dict[str, Any] = {
                 "event": sub_stream_response.event.value,
                 "conversation_id": chunk.conversation_id,
                 "message_id": chunk.message_id,
@@ -82,7 +82,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump(mode="json"))
             yield response_chunk

     @classmethod
@@ -102,7 +102,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 yield "ping"
                 continue

-            response_chunk = {
+            response_chunk: dict[str, Any] = {
                 "event": sub_stream_response.event.value,
                 "conversation_id": chunk.conversation_id,
                 "message_id": chunk.message_id,
@@ -110,7 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             }

             if isinstance(sub_stream_response, MessageEndStreamResponse):
-                sub_stream_response_dict = sub_stream_response.to_dict()
+                sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
                 metadata = sub_stream_response_dict.get("metadata", {})
                 sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
@@ -118,8 +118,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
-                response_chunk.update(sub_stream_response.to_ignore_detail_dict())  # ty: ignore [unresolved-attribute]
+                response_chunk.update(sub_stream_response.to_ignore_detail_dict())
             else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump(mode="json"))

             yield response_chunk
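The to_dict() to model_dump(mode="json") replacements in this and the following converter hunks rely on Pydantic v2 (already in use here, per model_dump_json elsewhere in the diff): with mode="json", nested models are serialized recursively and non-JSON-native values such as datetimes and enums come back as JSON-compatible ones. A sketch with a made-up model:

    from datetime import datetime
    from pydantic import BaseModel

    class Chunk(BaseModel):
        event: str
        created_at: datetime

    chunk = Chunk(event="message", created_at=datetime(2025, 1, 1))
    print(chunk.model_dump())             # created_at stays a datetime object
    print(chunk.model_dump(mode="json"))  # created_at becomes "2025-01-01T00:00:00"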
@@ -4,7 +4,7 @@ import time
 from collections.abc import Callable, Generator, Mapping
 from contextlib import contextmanager
 from threading import Thread
-from typing import Any, Optional, Union
+from typing import Any, Union

 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -169,7 +169,7 @@ class AdvancedChatAppGenerateTaskPipeline:

         generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)

-        if self._base_task_pipeline._stream:
+        if self._base_task_pipeline.stream:
             return self._to_stream_response(generator)
         else:
             return self._to_blocking_response(generator)
@@ -228,7 +228,7 @@ class AdvancedChatAppGenerateTaskPipeline:
         return None

     def _wrapper_process_stream_response(
-        self, trace_manager: Optional[TraceQueueManager] = None
+        self, trace_manager: TraceQueueManager | None = None
     ) -> Generator[StreamResponse, None, None]:
         tts_publisher = None
         task_id = self._application_generate_entity.task_id
@@ -289,7 +289,7 @@ class AdvancedChatAppGenerateTaskPipeline:
         if not self._workflow_run_id:
             raise ValueError("workflow run not initialized.")

-    def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState:
+    def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState:
         """Fluent validation for graph runtime state."""
         if not graph_runtime_state:
             raise ValueError("graph runtime state not initialized.")
@@ -297,13 +297,13 @@ class AdvancedChatAppGenerateTaskPipeline:

     def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
         """Handle ping events."""
-        yield self._base_task_pipeline._ping_stream_response()
+        yield self._base_task_pipeline.ping_stream_response()

     def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
         """Handle error events."""
         with self._database_session() as session:
-            err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id)
-        yield self._base_task_pipeline._error_to_stream_response(err)
+            err = self._base_task_pipeline.handle_error(event=event, session=session, message_id=self._message_id)
+        yield self._base_task_pipeline.error_to_stream_response(err)

     def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]:
         """Handle workflow started events."""
@@ -404,8 +404,8 @@ class AdvancedChatAppGenerateTaskPipeline:
         self,
         event: QueueTextChunkEvent,
         *,
-        tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
-        queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
+        tts_publisher: AppGeneratorTTSPublisher | None = None,
+        queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
         **kwargs,
     ) -> Generator[StreamResponse, None, None]:
         """Handle text chunk events."""
@@ -505,8 +505,8 @@ class AdvancedChatAppGenerateTaskPipeline:
         self,
         event: QueueWorkflowSucceededEvent,
         *,
-        graph_runtime_state: Optional[GraphRuntimeState] = None,
-        trace_manager: Optional[TraceQueueManager] = None,
+        graph_runtime_state: GraphRuntimeState | None = None,
+        trace_manager: TraceQueueManager | None = None,
         **kwargs,
     ) -> Generator[StreamResponse, None, None]:
         """Handle workflow succeeded events."""
@@ -536,8 +536,8 @@ class AdvancedChatAppGenerateTaskPipeline:
         self,
         event: QueueWorkflowPartialSuccessEvent,
         *,
-        graph_runtime_state: Optional[GraphRuntimeState] = None,
-        trace_manager: Optional[TraceQueueManager] = None,
+        graph_runtime_state: GraphRuntimeState | None = None,
+        trace_manager: TraceQueueManager | None = None,
         **kwargs,
     ) -> Generator[StreamResponse, None, None]:
         """Handle workflow partial success events."""
@@ -568,8 +568,8 @@ class AdvancedChatAppGenerateTaskPipeline:
         self,
         event: QueueWorkflowFailedEvent,
         *,
-        graph_runtime_state: Optional[GraphRuntimeState] = None,
-        trace_manager: Optional[TraceQueueManager] = None,
+        graph_runtime_state: GraphRuntimeState | None = None,
+        trace_manager: TraceQueueManager | None = None,
         **kwargs,
     ) -> Generator[StreamResponse, None, None]:
         """Handle workflow failed events."""
@@ -594,17 +594,17 @@ class AdvancedChatAppGenerateTaskPipeline:
                 workflow_execution=workflow_execution,
             )
             err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
-            err = self._base_task_pipeline._handle_error(event=err_event, session=session, message_id=self._message_id)
+            err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id)

         yield workflow_finish_resp
-        yield self._base_task_pipeline._error_to_stream_response(err)
+        yield self._base_task_pipeline.error_to_stream_response(err)

     def _handle_stop_event(
         self,
         event: QueueStopEvent,
         *,
-        graph_runtime_state: Optional[GraphRuntimeState] = None,
-        trace_manager: Optional[TraceQueueManager] = None,
+        graph_runtime_state: GraphRuntimeState | None = None,
+        trace_manager: TraceQueueManager | None = None,
         **kwargs,
     ) -> Generator[StreamResponse, None, None]:
         """Handle stop events."""
@@ -644,13 +644,13 @@ class AdvancedChatAppGenerateTaskPipeline:
         self,
         event: QueueAdvancedChatMessageEndEvent,
         *,
-        graph_runtime_state: Optional[GraphRuntimeState] = None,
+        graph_runtime_state: GraphRuntimeState | None = None,
         **kwargs,
     ) -> Generator[StreamResponse, None, None]:
         """Handle advanced chat message end events."""
         self._ensure_graph_runtime_initialized(graph_runtime_state)

-        output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
+        output_moderation_answer = self._base_task_pipeline.handle_output_moderation_when_task_finished(
             self._task_state.answer
         )
         if output_moderation_answer:
@@ -740,10 +740,10 @@ class AdvancedChatAppGenerateTaskPipeline:
         self,
         event: Any,
         *,
-        graph_runtime_state: Optional[GraphRuntimeState] = None,
-        tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
-        trace_manager: Optional[TraceQueueManager] = None,
-        queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
+        graph_runtime_state: GraphRuntimeState | None = None,
+        tts_publisher: AppGeneratorTTSPublisher | None = None,
+        trace_manager: TraceQueueManager | None = None,
+        queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
     ) -> Generator[StreamResponse, None, None]:
         """Dispatch events using elegant pattern matching."""
         handlers = self._get_event_handlers()
@@ -782,15 +782,15 @@ class AdvancedChatAppGenerateTaskPipeline:

     def _process_stream_response(
         self,
-        tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
-        trace_manager: Optional[TraceQueueManager] = None,
+        tts_publisher: AppGeneratorTTSPublisher | None = None,
+        trace_manager: TraceQueueManager | None = None,
     ) -> Generator[StreamResponse, None, None]:
         """
         Process stream response using elegant Fluent Python patterns.
         Maintains exact same functionality as original 57-if-statement version.
         """
         # Initialize graph runtime state
-        graph_runtime_state: Optional[GraphRuntimeState] = None
+        graph_runtime_state: GraphRuntimeState | None = None

         for queue_message in self._base_task_pipeline.queue_manager.listen():
             event = queue_message.event
@@ -835,7 +835,7 @@ class AdvancedChatAppGenerateTaskPipeline:
         if self._conversation_name_generate_thread:
             self._conversation_name_generate_thread.join()

-    def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None):
+    def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
         message = self._get_message(session=session)

         # If there are assistant files, remove markdown image links from answer
@@ -846,7 +846,7 @@ class AdvancedChatAppGenerateTaskPipeline:

         message.answer = answer_text
         message.updated_at = naive_utc_now()
-        message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
+        message.provider_response_latency = time.perf_counter() - self._base_task_pipeline.start_at
         message.message_metadata = self._task_state.metadata.model_dump_json()
         message_files = [
             MessageFile(
@@ -902,9 +902,9 @@ class AdvancedChatAppGenerateTaskPipeline:
         :param text: text
         :return: True if output moderation should direct output, otherwise False
         """
-        if self._base_task_pipeline._output_moderation_handler:
-            if self._base_task_pipeline._output_moderation_handler.should_direct_output():
-                self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
+        if self._base_task_pipeline.output_moderation_handler:
+            if self._base_task_pipeline.output_moderation_handler.should_direct_output():
+                self._task_state.answer = self._base_task_pipeline.output_moderation_handler.get_final_output()
                 self._base_task_pipeline.queue_manager.publish(
                     QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
                 )
@@ -914,7 +914,7 @@ class AdvancedChatAppGenerateTaskPipeline:
             )
             return True
         else:
-            self._base_task_pipeline._output_moderation_handler.append_new_token(text)
+            self._base_task_pipeline.output_moderation_handler.append_new_token(text)

         return False
@@ -1,6 +1,6 @@
 import uuid
 from collections.abc import Mapping
-from typing import Any, Optional
+from typing import Any, cast

 from core.agent.entities import AgentEntity
 from core.app.app_config.base_app_config_manager import BaseAppConfigManager
@@ -30,7 +30,7 @@ class AgentChatAppConfig(EasyUIBasedAppConfig):
     Agent Chatbot App Config Entity.
     """

-    agent: Optional[AgentEntity] = None
+    agent: AgentEntity | None = None


 class AgentChatAppConfigManager(BaseAppConfigManager):
@@ -39,8 +39,8 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
         cls,
         app_model: App,
         app_model_config: AppModelConfig,
-        conversation: Optional[Conversation] = None,
-        override_config_dict: Optional[dict] = None,
+        conversation: Conversation | None = None,
+        override_config_dict: dict | None = None,
     ) -> AgentChatAppConfig:
         """
         Convert app model config to agent chat app config
@@ -160,7 +160,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
         return filtered_config

     @classmethod
-    def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
+    def validate_agent_mode_and_set_defaults(
+        cls, tenant_id: str, config: dict[str, Any]
+    ) -> tuple[dict[str, Any], list[str]]:
         """
         Validate agent_mode and set defaults for agent feature
@@ -170,30 +172,32 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
         if not config.get("agent_mode"):
             config["agent_mode"] = {"enabled": False, "tools": []}

-        if not isinstance(config["agent_mode"], dict):
+        agent_mode = config["agent_mode"]
+        if not isinstance(agent_mode, dict):
             raise ValueError("agent_mode must be of object type")

-        if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]:
-            config["agent_mode"]["enabled"] = False
+        # FIXME(-LAN-): Cast needed due to basedpyright limitation with dict type narrowing
+        agent_mode = cast(dict[str, Any], agent_mode)

-        if not isinstance(config["agent_mode"]["enabled"], bool):
+        if "enabled" not in agent_mode or not agent_mode["enabled"]:
+            agent_mode["enabled"] = False
+
+        if not isinstance(agent_mode["enabled"], bool):
             raise ValueError("enabled in agent_mode must be of boolean type")

-        if not config["agent_mode"].get("strategy"):
-            config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
+        if not agent_mode.get("strategy"):
+            agent_mode["strategy"] = PlanningStrategy.ROUTER.value

-        if config["agent_mode"]["strategy"] not in [
-            member.value for member in list(PlanningStrategy.__members__.values())
-        ]:
+        if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
             raise ValueError("strategy in agent_mode must be in the specified strategy list")

-        if not config["agent_mode"].get("tools"):
-            config["agent_mode"]["tools"] = []
+        if not agent_mode.get("tools"):
+            agent_mode["tools"] = []

-        if not isinstance(config["agent_mode"]["tools"], list):
+        if not isinstance(agent_mode["tools"], list):
             raise ValueError("tools in agent_mode must be a list of objects")

-        for tool in config["agent_mode"]["tools"]:
+        for tool in agent_mode["tools"]:
             key = list(tool.keys())[0]
             if key in OLD_TOOLS:
                 # old style, use tool name as key
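The agent_mode rewrite above binds the nested dict to a local variable and then applies typing.cast after the isinstance check; as the FIXME says, the cast exists only to satisfy the type checker and has no runtime effect. A reduced sketch of the pattern, with a hypothetical helper name:

    from typing import Any, cast

    def normalize_agent_mode(config: dict[str, Any]) -> dict[str, Any]:
        agent_mode = config["agent_mode"]
        if not isinstance(agent_mode, dict):
            raise ValueError("agent_mode must be of object type")
        # isinstance() narrows the type only to a bare dict under some checkers;
        # cast() restores dict[str, Any] and is a no-op at runtime.
        agent_mode = cast(dict[str, Any], agent_mode)
        agent_mode.setdefault("enabled", False)
        return config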
@@ -46,7 +46,10 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         response = cls.convert_blocking_full_response(blocking_response)

         metadata = response.get("metadata", {})
-        response["metadata"] = cls._get_simple_metadata(metadata)
+        if isinstance(metadata, dict):
+            response["metadata"] = cls._get_simple_metadata(metadata)
+        else:
+            response["metadata"] = {}

         return response
@@ -78,7 +81,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump(mode="json"))
             yield response_chunk

     @classmethod
@@ -106,7 +109,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             }

             if isinstance(sub_stream_response, MessageEndStreamResponse):
-                sub_stream_response_dict = sub_stream_response.to_dict()
+                sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
                 metadata = sub_stream_response_dict.get("metadata", {})
                 sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
@@ -114,6 +117,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump(mode="json"))

             yield response_chunk
@@ -1,5 +1,5 @@
 from collections.abc import Generator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Optional, Union, final
+from typing import TYPE_CHECKING, Any, Union, final

 from sqlalchemy.orm import Session

@@ -25,7 +25,7 @@ class BaseAppGenerator:
     def _prepare_user_inputs(
         self,
         *,
-        user_inputs: Optional[Mapping[str, Any]],
+        user_inputs: Mapping[str, Any] | None,
         variables: Sequence["VariableEntity"],
         tenant_id: str,
         strict_type_validation: bool = False,
@@ -2,7 +2,7 @@ import queue
 import time
 from abc import abstractmethod
 from enum import IntEnum, auto
-from typing import Any, Optional
+from typing import Any

 from sqlalchemy.orm import DeclarativeMeta

@@ -32,6 +32,7 @@ class AppQueueManager:
         self._task_id = task_id
         self._user_id = user_id
         self._invoke_from = invoke_from
+        self.invoke_from = invoke_from  # Public accessor for invoke_from

         user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
         redis_client.setex(
@@ -115,7 +116,7 @@ class AppQueueManager:
         Set task stop flag
         :return:
         """
-        result: Optional[Any] = redis_client.get(cls._generate_task_belong_cache_key(task_id))
+        result: Any | None = redis_client.get(cls._generate_task_belong_cache_key(task_id))
         if result is None:
             return
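The hunk above exposes invoke_from by assigning a second, public attribute next to the private _invoke_from, so that WorkflowAppGenerateTaskPipeline (later in this diff) can stop reaching into the private name. A read-only property is a common alternative that avoids keeping two independently assignable copies of the same state; a hypothetical sketch, not what this commit does (str stands in for the InvokeFrom enum to keep it self-contained):

    class AppQueueManager:
        def __init__(self, invoke_from: str) -> None:
            self._invoke_from = invoke_from

        @property
        def invoke_from(self) -> str:
            # Read-only view of the private attribute.
            return self._invoke_from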
@@ -1,7 +1,7 @@
 import logging
 import time
 from collections.abc import Generator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Union

 from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -82,11 +82,11 @@ class AppRunner:
         prompt_template_entity: PromptTemplateEntity,
         inputs: Mapping[str, str],
         files: Sequence["File"],
-        query: Optional[str] = None,
-        context: Optional[str] = None,
-        memory: Optional[TokenBufferMemory] = None,
-        image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
-    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
+        query: str | None = None,
+        context: str | None = None,
+        memory: TokenBufferMemory | None = None,
+        image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
+    ) -> tuple[list[PromptMessage], list[str] | None]:
         """
         Organize prompt messages
         :param context:
@@ -161,7 +161,7 @@ class AppRunner:
         prompt_messages: list,
         text: str,
         stream: bool,
-        usage: Optional[LLMUsage] = None,
+        usage: LLMUsage | None = None,
     ):
         """
         Direct output
@@ -375,7 +375,7 @@ class AppRunner:

     def query_app_annotations_to_reply(
         self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
-    ) -> Optional[MessageAnnotation]:
+    ) -> MessageAnnotation | None:
         """
         Query app annotations to reply
         :param app_record: app record
@@ -1,5 +1,3 @@
-from typing import Optional
-
 from core.app.app_config.base_app_config_manager import BaseAppConfigManager
 from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
 from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
@@ -32,8 +30,8 @@ class ChatAppConfigManager(BaseAppConfigManager):
         cls,
         app_model: App,
         app_model_config: AppModelConfig,
-        conversation: Optional[Conversation] = None,
-        override_config_dict: Optional[dict] = None,
+        conversation: Conversation | None = None,
+        override_config_dict: dict | None = None,
     ) -> ChatAppConfig:
         """
         Convert app model config to chat app config
@@ -46,7 +46,10 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         response = cls.convert_blocking_full_response(blocking_response)

         metadata = response.get("metadata", {})
-        response["metadata"] = cls._get_simple_metadata(metadata)
+        if isinstance(metadata, dict):
+            response["metadata"] = cls._get_simple_metadata(metadata)
+        else:
+            response["metadata"] = {}

         return response
@@ -78,7 +81,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump(mode="json"))
             yield response_chunk

     @classmethod
@@ -106,7 +109,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             }

             if isinstance(sub_stream_response, MessageEndStreamResponse):
-                sub_stream_response_dict = sub_stream_response.to_dict()
+                sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
                 metadata = sub_stream_response_dict.get("metadata", {})
                 sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
@@ -114,6 +117,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump(mode="json"))

             yield response_chunk
@@ -1,7 +1,7 @@
 import time
 from collections.abc import Mapping, Sequence
 from datetime import UTC, datetime
-from typing import Any, Optional, Union
+from typing import Any, Union

 from sqlalchemy.orm import Session

@@ -135,7 +135,7 @@ class WorkflowResponseConverter:
         event: QueueNodeStartedEvent,
         task_id: str,
         workflow_node_execution: WorkflowNodeExecution,
-    ) -> Optional[NodeStartStreamResponse]:
+    ) -> NodeStartStreamResponse | None:
         if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
             return None
         if not workflow_node_execution.workflow_execution_id:
@@ -190,7 +190,7 @@ class WorkflowResponseConverter:
         event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
         task_id: str,
         workflow_node_execution: WorkflowNodeExecution,
-    ) -> Optional[NodeFinishStreamResponse]:
+    ) -> NodeFinishStreamResponse | None:
         if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
             return None
         if not workflow_node_execution.workflow_execution_id:
@@ -235,7 +235,7 @@ class WorkflowResponseConverter:
         event: QueueNodeRetryEvent,
         task_id: str,
         workflow_node_execution: WorkflowNodeExecution,
-    ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
+    ) -> Union[NodeRetryStreamResponse, NodeFinishStreamResponse] | None:
         if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
             return None
         if not workflow_node_execution.workflow_execution_id:
@@ -1,5 +1,3 @@
-from typing import Optional
-
 from core.app.app_config.base_app_config_manager import BaseAppConfigManager
 from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
 from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
@@ -24,7 +22,7 @@ class CompletionAppConfig(EasyUIBasedAppConfig):
 class CompletionAppConfigManager(BaseAppConfigManager):
     @classmethod
     def get_app_config(
-        cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: Optional[dict] = None
+        cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: dict | None = None
     ) -> CompletionAppConfig:
         """
         Convert app model config to completion app config
@@ -271,6 +271,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             raise MoreLikeThisDisabledError()

         app_model_config = message.app_model_config
+        if not app_model_config:
+            raise ValueError("Message app_model_config is None")
         override_model_config_dict = app_model_config.to_dict()
         model_dict = override_model_config_dict["model"]
         completion_params = model_dict.get("completion_params")
@@ -45,7 +45,10 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
         response = cls.convert_blocking_full_response(blocking_response)

         metadata = response.get("metadata", {})
-        response["metadata"] = cls._get_simple_metadata(metadata)
+        if isinstance(metadata, dict):
+            response["metadata"] = cls._get_simple_metadata(metadata)
+        else:
+            response["metadata"] = {}

         return response
@@ -76,7 +79,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump(mode="json"))
             yield response_chunk

     @classmethod
@@ -103,14 +106,16 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
             }

             if isinstance(sub_stream_response, MessageEndStreamResponse):
-                sub_stream_response_dict = sub_stream_response.to_dict()
+                sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
                 metadata = sub_stream_response_dict.get("metadata", {})
+                if not isinstance(metadata, dict):
+                    metadata = {}
                 sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
             if isinstance(sub_stream_response, ErrorStreamResponse):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump(mode="json"))

             yield response_chunk
@@ -1,7 +1,7 @@
 import json
 import logging
 from collections.abc import Generator
-from typing import Optional, Union, cast
+from typing import Union, cast

 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -84,7 +84,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             logger.exception("Failed to handle response, conversation_id: %s", conversation.id)
             raise e

-    def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
+    def _get_app_model_config(self, app_model: App, conversation: Conversation | None = None) -> AppModelConfig:
         if conversation:
             stmt = select(AppModelConfig).where(
                 AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
@@ -112,7 +112,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             AgentChatAppGenerateEntity,
             AdvancedChatAppGenerateEntity,
         ],
-        conversation: Optional[Conversation] = None,
+        conversation: Conversation | None = None,
     ) -> tuple[Conversation, Message]:
         """
         Initialize generate records
@@ -425,6 +425,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
         context: contextvars.Context,
         variable_loader: VariableLoader,
     ) -> None:
+        """
+        Generate worker in a new thread.
+        :param flask_app: Flask app
+        :param application_generate_entity: application generate entity
+        :param queue_manager: queue manager
+        :param workflow_thread_pool_id: workflow thread pool id
+        :return:
+        """
         with preserve_flask_contexts(flask_app, context_vars=context):
             with Session(db.engine, expire_on_commit=False) as session:
                 workflow = session.scalar(
@@ -23,7 +23,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
         :param blocking_response: blocking response
         :return:
         """
-        return dict(blocking_response.to_dict())
+        return blocking_response.model_dump()

     @classmethod
     def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse):  # type: ignore[override]
@@ -51,7 +51,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
                 yield "ping"
                 continue

-            response_chunk = {
+            response_chunk: dict[str, object] = {
                 "event": sub_stream_response.event.value,
                 "workflow_run_id": chunk.workflow_run_id,
             }
@@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump(mode="json"))
             yield response_chunk

     @classmethod
@@ -80,7 +80,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
                 yield "ping"
                 continue

-            response_chunk = {
+            response_chunk: dict[str, object] = {
                 "event": sub_stream_response.event.value,
                 "workflow_run_id": chunk.workflow_run_id,
             }
@@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
             elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
                 response_chunk.update(sub_stream_response.to_ignore_detail_dict())  # ty: ignore [unresolved-attribute]
             else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump(mode="json"))
             yield response_chunk
@@ -2,7 +2,7 @@ import logging
 import time
 from collections.abc import Callable, Generator
 from contextlib import contextmanager
-from typing import Optional, Union
+from typing import Union

 from sqlalchemy.orm import Session

@@ -133,7 +133,7 @@ class WorkflowAppGenerateTaskPipeline:
         self._application_generate_entity = application_generate_entity
         self._workflow_features_dict = workflow.features_dict
         self._workflow_run_id = ""
-        self._invoke_from = queue_manager._invoke_from
+        self._invoke_from = queue_manager.invoke_from
         self._draft_var_saver_factory = draft_var_saver_factory

     def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@@ -142,7 +142,7 @@ class WorkflowAppGenerateTaskPipeline:
         :return:
         """
         generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
-        if self._base_task_pipeline._stream:
+        if self._base_task_pipeline.stream:
             return self._to_stream_response(generator)
         else:
             return self._to_blocking_response(generator)
@@ -202,7 +202,7 @@ class WorkflowAppGenerateTaskPipeline:
         return None

     def _wrapper_process_stream_response(
-        self, trace_manager: Optional[TraceQueueManager] = None
+        self, trace_manager: TraceQueueManager | None = None
     ) -> Generator[StreamResponse, None, None]:
         tts_publisher = None
         task_id = self._application_generate_entity.task_id
@@ -264,7 +264,7 @@ class WorkflowAppGenerateTaskPipeline:
         if not self._workflow_run_id:
             raise ValueError("workflow run not initialized.")

-    def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState:
+    def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState:
         """Fluent validation for graph runtime state."""
         if not graph_runtime_state:
             raise ValueError("graph runtime state not initialized.")
@@ -272,12 +272,12 @@ class WorkflowAppGenerateTaskPipeline:

     def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
         """Handle ping events."""
-        yield self._base_task_pipeline._ping_stream_response()
+        yield self._base_task_pipeline.ping_stream_response()

     def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
         """Handle error events."""
-        err = self._base_task_pipeline._handle_error(event=event)
-        yield self._base_task_pipeline._error_to_stream_response(err)
+        err = self._base_task_pipeline.handle_error(event=event)
+        yield self._base_task_pipeline.error_to_stream_response(err)

     def _handle_workflow_started_event(
         self, event: QueueWorkflowStartedEvent, **kwargs
@@ -442,8 +442,8 @@ class WorkflowAppGenerateTaskPipeline:
         self,
         event: QueueWorkflowSucceededEvent,
         *,
-        graph_runtime_state: Optional[GraphRuntimeState] = None,
-        trace_manager: Optional[TraceQueueManager] = None,
+        graph_runtime_state: GraphRuntimeState | None = None,
+        trace_manager: TraceQueueManager | None = None,
         **kwargs,
     ) -> Generator[StreamResponse, None, None]:
         """Handle workflow succeeded events."""
@@ -476,8 +476,8 @@ class WorkflowAppGenerateTaskPipeline:
         self,
         event: QueueWorkflowPartialSuccessEvent,
         *,
-        graph_runtime_state: Optional[GraphRuntimeState] = None,
-        trace_manager: Optional[TraceQueueManager] = None,
+        graph_runtime_state: GraphRuntimeState | None = None,
+        trace_manager: TraceQueueManager | None = None,
         **kwargs,
     ) -> Generator[StreamResponse, None, None]:
         """Handle workflow partial success events."""
@@ -511,8 +511,8 @@ class WorkflowAppGenerateTaskPipeline:
         self,
         event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
         *,
-        graph_runtime_state: Optional[GraphRuntimeState] = None,
-        trace_manager: Optional[TraceQueueManager] = None,
+        graph_runtime_state: GraphRuntimeState | None = None,
+        trace_manager: TraceQueueManager | None = None,
         **kwargs,
     ) -> Generator[StreamResponse, None, None]:
         """Handle workflow failed and stop events."""
@@ -549,8 +549,8 @@ class WorkflowAppGenerateTaskPipeline:
         self,
         event: QueueTextChunkEvent,
         *,
-        tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
-        queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
+        tts_publisher: AppGeneratorTTSPublisher | None = None,
+        queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
         **kwargs,
     ) -> Generator[StreamResponse, None, None]:
         """Handle text chunk events."""
@@ -601,10 +601,10 @@ class WorkflowAppGenerateTaskPipeline:
         self,
         event: AppQueueEvent,
         *,
-        graph_runtime_state: Optional[GraphRuntimeState] = None,
-        tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
-        trace_manager: Optional[TraceQueueManager] = None,
-        queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
+        graph_runtime_state: GraphRuntimeState | None = None,
+        tts_publisher: AppGeneratorTTSPublisher | None = None,
+        trace_manager: TraceQueueManager | None = None,
+        queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
     ) -> Generator[StreamResponse, None, None]:
         """Dispatch events using elegant pattern matching."""
         handlers = self._get_event_handlers()
@@ -654,8 +654,8 @@ class WorkflowAppGenerateTaskPipeline:

     def _process_stream_response(
         self,
-        tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
-        trace_manager: Optional[TraceQueueManager] = None,
+        tts_publisher: AppGeneratorTTSPublisher | None = None,
+        trace_manager: TraceQueueManager | None = None,
     ) -> Generator[StreamResponse, None, None]:
         """
         Process stream response using elegant Fluent Python patterns.
@@ -722,7 +722,7 @@ class WorkflowAppGenerateTaskPipeline:
         session.commit()

     def _text_chunk_to_stream_response(
-        self, text: str, from_variable_selector: Optional[list[str]] = None
+        self, text: str, from_variable_selector: list[str] | None = None
     ) -> TextChunkStreamResponse:
         """
         Handle completed event.
@@ -99,8 +99,8 @@ class AppGenerateEntity(BaseModel):
     task_id: str

     # app config
-    app_config: Any
-    file_upload_config: Optional[FileUploadConfig] = None
+    app_config: Any = None
+    file_upload_config: FileUploadConfig | None = None

     inputs: Mapping[str, Any]
     files: Sequence[File]
@@ -126,10 +126,10 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
     """

     # app config
-    app_config: EasyUIBasedAppConfig
+    app_config: EasyUIBasedAppConfig = None  # type: ignore
     model_conf: ModelConfigWithCredentialsEntity

-    query: Optional[str] = None
+    query: str | None = None

     # pydantic configs
     model_config = ConfigDict(protected_namespaces=())
@@ -140,8 +140,8 @@ class ConversationAppGenerateEntity(AppGenerateEntity):
     Base entity for conversation-based app generation.
     """

-    conversation_id: Optional[str] = None
-    parent_message_id: Optional[str] = Field(
+    conversation_id: str | None = None
+    parent_message_id: str | None = Field(
         default=None,
         description=(
             "Starting from v0.9.0, parent_message_id is used to support message regeneration for internal chat API."
@@ -189,9 +189,9 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
     """

     # app config
-    app_config: WorkflowUIBasedAppConfig
+    app_config: WorkflowUIBasedAppConfig = None  # type: ignore

-    workflow_run_id: Optional[str] = None
+    workflow_run_id: str | None = None
     query: str

     class SingleIterationRunEntity(BaseModel):
@@ -202,7 +202,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
         node_id: str
         inputs: Mapping

-    single_iteration_run: Optional[SingleIterationRunEntity] = None
+    single_iteration_run: SingleIterationRunEntity | None = None

     class SingleLoopRunEntity(BaseModel):
         """
@@ -212,7 +212,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
         node_id: str
         inputs: Mapping

-    single_loop_run: Optional[SingleLoopRunEntity] = None
+    single_loop_run: SingleLoopRunEntity | None = None


 class WorkflowAppGenerateEntity(AppGenerateEntity):
@@ -221,7 +221,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
     """

     # app config
-    app_config: WorkflowUIBasedAppConfig
+    app_config: WorkflowUIBasedAppConfig = None  # type: ignore
     workflow_execution_id: str

     class SingleIterationRunEntity(BaseModel):
@@ -232,7 +232,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
         node_id: str
         inputs: dict

-    single_iteration_run: Optional[SingleIterationRunEntity] = None
+    single_iteration_run: SingleIterationRunEntity | None = None

     class SingleLoopRunEntity(BaseModel):
         """
@ -1,6 +1,6 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum, StrEnum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -79,9 +79,9 @@ class QueueIterationStartEvent(AppQueueEvent):
start_at: datetime

node_run_index: int
inputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, Any] | None = None
predecessor_node_id: str | None = None
metadata: Mapping[str, Any] | None = None


class QueueIterationNextEvent(AppQueueEvent):
@ -114,12 +114,12 @@ class QueueIterationCompletedEvent(AppQueueEvent):
start_at: datetime

node_run_index: int
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
metadata: Mapping[str, Any] | None = None
steps: int = 0

error: Optional[str] = None
error: str | None = None


class QueueLoopStartEvent(AppQueueEvent):
@ -132,20 +132,20 @@ class QueueLoopStartEvent(AppQueueEvent):
node_id: str
node_type: NodeType
node_title: str
parallel_id: Optional[str] = None
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime

node_run_index: int
inputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, Any] | None = None
predecessor_node_id: str | None = None
metadata: Mapping[str, Any] | None = None


class QueueLoopNextEvent(AppQueueEvent):
@ -160,15 +160,15 @@ class QueueLoopNextEvent(AppQueueEvent):
node_id: str
node_type: NodeType
node_title: str
parallel_id: Optional[str] = None
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
parallel_mode_run_id: str | None = None
"""iteration run in parallel mode run id"""
node_run_index: int
output: Optional[Any] = None  # output for the current loop
@ -187,21 +187,21 @@ class QueueLoopCompletedEvent(AppQueueEvent):
node_title: str
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime

node_run_index: int
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
metadata: Mapping[str, Any] | None = None
steps: int = 0

error: Optional[str] = None
error: str | None = None


class QueueTextChunkEvent(AppQueueEvent):
@ -211,11 +211,11 @@ class QueueTextChunkEvent(AppQueueEvent):

event: QueueEvent = QueueEvent.TEXT_CHUNK
text: str
from_variable_selector: Optional[list[str]] = None
from_variable_selector: list[str] | None = None
"""from variable selector"""
in_iteration_id: Optional[str] = None
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
in_loop_id: str | None = None
"""loop id if node is in loop"""


@ -252,9 +252,9 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):

event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
retriever_resources: Sequence[RetrievalSourceMetadata]
in_iteration_id: Optional[str] = None
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
in_loop_id: str | None = None
"""loop id if node is in loop"""


@ -273,7 +273,7 @@ class QueueMessageEndEvent(AppQueueEvent):
"""

event: QueueEvent = QueueEvent.MESSAGE_END
llm_result: Optional[LLMResult] = None
llm_result: LLMResult | None = None


class QueueAdvancedChatMessageEndEvent(AppQueueEvent):
@ -299,7 +299,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent):
"""

event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
outputs: Optional[dict[str, Any]] = None
outputs: dict[str, Any] | None = None


class QueueWorkflowFailedEvent(AppQueueEvent):
@ -319,7 +319,7 @@ class QueueWorkflowPartialSuccessEvent(AppQueueEvent):

event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
exceptions_count: int
outputs: Optional[dict[str, Any]] = None
outputs: dict[str, Any] | None = None


class QueueNodeStartedEvent(AppQueueEvent):
@ -362,22 +362,22 @@ class QueueNodeSucceededEvent(AppQueueEvent):
node_type: NodeType
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime

inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
inputs: Mapping[str, Any] | None = None
process_data: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None

error: Optional[str] = None

@ -391,11 +391,11 @@ class QueueAgentLogEvent(AppQueueEvent):
id: str
label: str
node_execution_id: str
parent_id: str | None
error: str | None
parent_id: str | None = None
error: str | None = None
status: str
data: Mapping[str, Any]
metadata: Optional[Mapping[str, Any]] = None
metadata: Mapping[str, Any] | None = None
node_id: str
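The `parent_id` and `error` changes above are more than style: in Pydantic, `str | None` with no default is still a required field, while adding `= None` makes it truly optional. A minimal sketch, with illustrative model names:

from pydantic import BaseModel, ValidationError

class Before(BaseModel):
    parent_id: str | None         # nullable, but must be passed explicitly

class After(BaseModel):
    parent_id: str | None = None  # nullable and omittable

try:
    Before()
except ValidationError:
    pass                          # missing required field
assert After().parent_id is None  # default kicks in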
@ -404,10 +404,10 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent):

event: QueueEvent = QueueEvent.RETRY

inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
inputs: Mapping[str, Any] | None = None
process_data: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None

error: str
retry_index: int  # retry index
@ -425,22 +425,22 @@ class QueueNodeExceptionEvent(AppQueueEvent):
node_type: NodeType
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime

inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
inputs: Mapping[str, Any] | None = None
process_data: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None

error: str

@ -458,14 +458,14 @@ class QueueNodeFailedEvent(AppQueueEvent):
parallel_id: Optional[str] = None
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime

inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
inputs: Mapping[str, Any] | None = None
process_data: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None

error: str

@ -494,7 +494,7 @@ class QueueErrorEvent(AppQueueEvent):
"""

event: QueueEvent = QueueEvent.ERROR
error: Optional[Any] = None
error: Any | None = None

class QueuePingEvent(AppQueueEvent):
@ -510,15 +510,15 @@ class QueueStopEvent(AppQueueEvent):
QueueStopEvent entity
"""

class StopBy(Enum):
class StopBy(StrEnum):
"""
Stop by enum
"""

USER_MANUAL = "user-manual"
ANNOTATION_REPLY = "annotation-reply"
OUTPUT_MODERATION = "output-moderation"
INPUT_MODERATION = "input-moderation"
USER_MANUAL = auto()
ANNOTATION_REPLY = auto()
OUTPUT_MODERATION = auto()
INPUT_MODERATION = auto()

event: QueueEvent = QueueEvent.STOP
stopped_by: StopBy
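The `StopBy` rewrite changes the member values, not just the base class: with `StrEnum` (Python 3.11+), `auto()` derives each value from the lower-cased member name, so the old hyphenated literals become underscored. A minimal sketch of the difference:

from enum import StrEnum, auto

class StopBy(StrEnum):
    USER_MANUAL = auto()

# auto() on a StrEnum yields the lower-cased member name
assert StopBy.USER_MANUAL == "user_manual"
# the previous literal was "user-manual", so anything persisted or
# compared against the hyphenated form would see a different string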
@ -1,11 +1,10 @@
from collections.abc import Mapping, Sequence
from enum import Enum
from enum import StrEnum
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, Field

from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@ -51,7 +50,7 @@ class WorkflowTaskState(TaskState):
answer: str = ""


class StreamEvent(Enum):
class StreamEvent(StrEnum):
"""
Stream event
"""
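Moving `StreamEvent` from `Enum` to `StrEnum` matters for comparisons: a `StrEnum` member is a `str` and compares equal to its plain-string value, while a plain `Enum` member does not. This is also why later hunks can drop `.value` when comparing against `AppMode.COMPLETION`, assuming `AppMode` is likewise str-backed. A minimal sketch, with illustrative enum names:

from enum import Enum, StrEnum

class OldMode(Enum):
    COMPLETION = "completion"

class NewMode(StrEnum):
    COMPLETION = "completion"

assert OldMode.COMPLETION != "completion"  # plain Enum: no str equality
assert NewMode.COMPLETION == "completion"  # StrEnum subclasses str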
@ -90,9 +89,6 @@ class StreamResponse(BaseModel):
event: StreamEvent
task_id: str

def to_dict(self):
return jsonable_encoder(self)


class ErrorStreamResponse(StreamResponse):
"""
@ -112,7 +108,7 @@ class MessageStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.MESSAGE
id: str
answer: str
from_variable_selector: Optional[list[str]] = None
from_variable_selector: list[str] | None = None


class MessageAudioStreamResponse(StreamResponse):
@ -141,7 +137,7 @@ class MessageEndStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.MESSAGE_END
id: str
metadata: dict = Field(default_factory=dict)
files: Optional[Sequence[Mapping[str, Any]]] = None
files: Sequence[Mapping[str, Any]] | None = None


class MessageFileStreamResponse(StreamResponse):
@ -174,12 +170,12 @@ class AgentThoughtStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.AGENT_THOUGHT
id: str
position: int
thought: Optional[str] = None
observation: Optional[str] = None
tool: Optional[str] = None
tool_labels: Optional[dict] = None
tool_input: Optional[str] = None
message_files: Optional[list[str]] = None
thought: str | None = None
observation: str | None = None
tool: str | None = None
tool_labels: dict | None = None
tool_input: str | None = None
message_files: list[str] | None = None


class AgentMessageStreamResponse(StreamResponse):
@ -225,16 +221,16 @@ class WorkflowFinishStreamResponse(StreamResponse):
id: str
workflow_id: str
status: str
outputs: Optional[Mapping[str, Any]] = None
error: Optional[str] = None
outputs: Mapping[str, Any] | None = None
error: str | None = None
elapsed_time: float
total_tokens: int
total_steps: int
created_by: Optional[dict] = None
created_by: dict | None = None
created_at: int
finished_at: int
exceptions_count: Optional[int] = 0
files: Optional[Sequence[Mapping[str, Any]]] = []
exceptions_count: int | None = 0
files: Sequence[Mapping[str, Any]] | None = []

event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
workflow_run_id: str
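`files: Sequence[Mapping[str, Any]] | None = []` keeps a mutable default, which is safe here: unlike plain classes or dataclasses, Pydantic copies field defaults for every new instance, so instances never share the list. A minimal sketch, assuming Pydantic and an illustrative model name:

from pydantic import BaseModel

class Resp(BaseModel):
    files: list[dict] | None = []  # mutable default

a = Resp()
b = Resp()
assert a.files is not None
a.files.append({"id": 1})
assert b.files == []  # defaults are copied per instance, never shared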
@ -261,14 +257,14 @@ class NodeStartStreamResponse(StreamResponse):
inputs_truncated: bool = False
created_at: int
extras: dict = Field(default_factory=dict)
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
loop_id: Optional[str] = None
parallel_run_id: Optional[str] = None
agent_strategy: Optional[AgentNodeStrategyInit] = None
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
parallel_run_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None

event: StreamEvent = StreamEvent.NODE_STARTED
workflow_run_id: str
@ -322,18 +318,18 @@ class NodeFinishStreamResponse(StreamResponse):
outputs: Optional[Mapping[str, Any]] = None
outputs_truncated: bool = True
status: str
error: Optional[str] = None
error: str | None = None
elapsed_time: float
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
created_at: int
finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
loop_id: Optional[str] = None
files: Sequence[Mapping[str, Any]] | None = []
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None

event: StreamEvent = StreamEvent.NODE_FINISHED
workflow_run_id: str
@ -394,18 +390,18 @@ class NodeRetryStreamResponse(StreamResponse):
outputs: Optional[Mapping[str, Any]] = None
outputs_truncated: bool = False
status: str
error: Optional[str] = None
error: str | None = None
elapsed_time: float
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
created_at: int
finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
loop_id: Optional[str] = None
files: Sequence[Mapping[str, Any]] | None = []
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
retry_index: int = 0

event: StreamEvent = StreamEvent.NODE_RETRY
@ -514,10 +510,10 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
inputs: Optional[Mapping] = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus
error: Optional[str] = None
error: str | None = None
elapsed_time: float
total_tokens: int
execution_metadata: Optional[Mapping] = None
execution_metadata: Mapping | None = None
finished_at: int
steps: int

@ -569,7 +565,7 @@ class LoopNodeNextStreamResponse(StreamResponse):
title: str
index: int
created_at: int
pre_loop_output: Optional[Any] = None
pre_loop_output: Any | None = None
extras: dict = Field(default_factory=dict)
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
@ -601,14 +597,14 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
inputs: Optional[Mapping] = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus
error: Optional[str] = None
error: str | None = None
elapsed_time: float
total_tokens: int
execution_metadata: Optional[Mapping] = None
execution_metadata: Mapping | None = None
finished_at: int
steps: int
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parallel_id: str | None = None
parallel_start_node_id: str | None = None

event: StreamEvent = StreamEvent.LOOP_COMPLETED
workflow_run_id: str
@ -626,7 +622,7 @@ class TextChunkStreamResponse(StreamResponse):
"""

text: str
from_variable_selector: Optional[list[str]] = None
from_variable_selector: list[str] | None = None

event: StreamEvent = StreamEvent.TEXT_CHUNK
data: Data
@ -688,7 +684,7 @@ class WorkflowAppStreamResponse(AppStreamResponse):
WorkflowAppStreamResponse entity
"""

workflow_run_id: Optional[str] = None
workflow_run_id: str | None = None


class AppBlockingResponse(BaseModel):
@ -698,9 +694,6 @@ class AppBlockingResponse(BaseModel):

task_id: str

def to_dict(self):
return jsonable_encoder(self)


class ChatbotAppBlockingResponse(AppBlockingResponse):
"""
@ -756,8 +749,8 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
id: str
workflow_id: str
status: str
outputs: Optional[Mapping[str, Any]] = None
error: Optional[str] = None
outputs: Mapping[str, Any] | None = None
error: str | None = None
elapsed_time: float
total_tokens: int
total_steps: int
@ -781,11 +774,11 @@ class AgentLogStreamResponse(StreamResponse):
node_execution_id: str
id: str
label: str
parent_id: str | None
error: str | None
parent_id: str | None = None
error: str | None = None
status: str
data: Mapping[str, Any]
metadata: Optional[Mapping[str, Any]] = None
metadata: Mapping[str, Any] | None = None
node_id: str

event: StreamEvent = StreamEvent.AGENT_LOG
@ -1,5 +1,4 @@
import logging
from typing import Optional

from sqlalchemy import select

@ -17,7 +16,7 @@ logger = logging.getLogger(__name__)
class AnnotationReplyFeature:
def query(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
) -> MessageAnnotation | None:
"""
Query app annotations to reply
:param app_record: app record
@ -35,6 +34,9 @@ class AnnotationReplyFeature:

collection_binding_detail = annotation_setting.collection_binding_detail

if not collection_binding_detail:
return None

try:
score_threshold = annotation_setting.score_threshold or 1
embedding_provider_name = collection_binding_detail.provider_name
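The inserted guard returns early instead of letting the later `collection_binding_detail.provider_name` access raise an AttributeError on a missing binding. The shape of the fix, with illustrative stand-ins:

class BindingDetail:
    provider_name = "openai"

def provider_of(detail: BindingDetail | None) -> str | None:
    if not detail:  # the added early return
        return None
    return detail.provider_name

assert provider_of(None) is None
assert provider_of(BindingDetail()) == "openai"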
@ -1 +1,3 @@
from .rate_limit import RateLimit

__all__ = ["RateLimit"]
@ -3,7 +3,7 @@ import time
import uuid
from collections.abc import Generator, Mapping
from datetime import timedelta
from typing import Any, Optional, Union
from typing import Any, Union

from core.errors.error import AppInvokeQuotaExceededError
from extensions.ext_redis import redis_client
@ -19,7 +19,7 @@ class RateLimit:
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60  # recalculate request_count from request_detail every 5 minutes
_instance_dict: dict[str, "RateLimit"] = {}

def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
def __new__(cls, client_id: str, max_active_requests: int):
if client_id not in cls._instance_dict:
instance = super().__new__(cls)
cls._instance_dict[client_id] = instance
@ -63,7 +63,7 @@ class RateLimit:
if timeout_requests:
redis_client.hdel(self.active_requests_key, *timeout_requests)

def enter(self, request_id: Optional[str] = None) -> str:
def enter(self, request_id: str | None = None) -> str:
if self.disabled():
return RateLimit._UNLIMITED_REQUEST_ID
if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL:
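Dropping the explicit `cls: type["RateLimit"]` annotation is safe because `cls` is inferred on `__new__`. The surrounding pattern is a per-key singleton: one shared `RateLimit` per `client_id`. A compressed sketch of that pattern (the real class also manages Redis state, omitted here):

class RateLimit:
    _instance_dict: dict[str, "RateLimit"] = {}

    def __new__(cls, client_id: str, max_active_requests: int):
        # reuse the existing instance for a known client_id
        if client_id not in cls._instance_dict:
            cls._instance_dict[client_id] = super().__new__(cls)
        return cls._instance_dict[client_id]

a = RateLimit("tenant-1", 10)
b = RateLimit("tenant-1", 99)
assert a is b  # same client_id, same instance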
@ -1,6 +1,5 @@
import logging
import time
from typing import Optional

from sqlalchemy import select
from sqlalchemy.orm import Session
@ -38,11 +37,11 @@ class BasedGenerateTaskPipeline:
):
self._application_generate_entity = application_generate_entity
self.queue_manager = queue_manager
self._start_at = time.perf_counter()
self._output_moderation_handler = self._init_output_moderation()
self._stream = stream
self.start_at = time.perf_counter()
self.output_moderation_handler = self._init_output_moderation()
self.stream = stream

def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
def handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
logger.debug("error: %s", event.error)
e = event.error
err: Exception
@ -86,7 +85,7 @@ class BasedGenerateTaskPipeline:

return message

def _error_to_stream_response(self, e: Exception):
def error_to_stream_response(self, e: Exception):
"""
Error to stream response.
:param e: exception
@ -94,14 +93,14 @@ class BasedGenerateTaskPipeline:
"""
return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e)

def _ping_stream_response(self) -> PingStreamResponse:
def ping_stream_response(self) -> PingStreamResponse:
"""
Ping stream response.
:return:
"""
return PingStreamResponse(task_id=self._application_generate_entity.task_id)

def _init_output_moderation(self) -> Optional[OutputModeration]:
def _init_output_moderation(self) -> OutputModeration | None:
"""
Init output moderation.
:return:
@ -118,21 +117,21 @@ class BasedGenerateTaskPipeline:
)
return None

def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
def handle_output_moderation_when_task_finished(self, completion: str) -> str | None:
"""
Handle output moderation when task finished.
:param completion: completion
:return:
"""
# response moderation
if self._output_moderation_handler:
self._output_moderation_handler.stop_thread()
if self.output_moderation_handler:
self.output_moderation_handler.stop_thread()

completion, flagged = self._output_moderation_handler.moderation_completion(
completion, flagged = self.output_moderation_handler.moderation_completion(
completion=completion, public_event=False
)

self._output_moderation_handler = None
self.output_moderation_handler = None
if flagged:
return completion
@ -2,7 +2,7 @@ import logging
import time
from collections.abc import Generator
from threading import Thread
from typing import Optional, Union, cast
from typing import Union, cast

from sqlalchemy import select
from sqlalchemy.orm import Session
@ -109,7 +109,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
task_state=self._task_state,
)

self._conversation_name_generate_thread: Optional[Thread] = None
self._conversation_name_generate_thread: Thread | None = None

def process(
self,
@ -125,7 +125,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
)

generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
if self.stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
@ -145,7 +145,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata.model_dump()
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation_mode == AppMode.COMPLETION.value:
if self._conversation_mode == AppMode.COMPLETION:
response = CompletionAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
data=CompletionAppBlockingResponse.Data(
@ -209,7 +209,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
return None

def _wrapper_process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
self, trace_manager: TraceQueueManager | None = None
) -> Generator[StreamResponse, None, None]:
tenant_id = self._application_generate_entity.app_config.tenant_id
task_id = self._application_generate_entity.task_id
@ -252,7 +252,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)

def _process_stream_response(
self, publisher: Optional[AppGeneratorTTSPublisher], trace_manager: Optional[TraceQueueManager] = None
self, publisher: AppGeneratorTTSPublisher | None, trace_manager: TraceQueueManager | None = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
@ -265,9 +265,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):

if isinstance(event, QueueErrorEvent):
with Session(db.engine) as session:
err = self._handle_error(event=event, session=session, message_id=self._message_id)
err = self.handle_error(event=event, session=session, message_id=self._message_id)
session.commit()
yield self._error_to_stream_response(err)
yield self.error_to_stream_response(err)
break
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
if isinstance(event, QueueMessageEndEvent):
@ -277,7 +277,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
self._handle_stop(event)

# handle output moderation
output_moderation_answer = self._handle_output_moderation_when_task_finished(
output_moderation_answer = self.handle_output_moderation_when_task_finished(
cast(str, self._task_state.llm_result.message.content)
)
if output_moderation_answer:
@ -354,7 +354,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
yield self.ping_stream_response()
else:
continue
if publisher:
@ -362,7 +362,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()

def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None):
def _save_message(self, *, session: Session, trace_manager: TraceQueueManager | None = None):
"""
Save message.
:return:
@ -394,7 +394,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.answer_tokens = usage.completion_tokens
message.answer_unit_price = usage.completion_unit_price
message.answer_price_unit = usage.completion_price_unit
message.provider_response_latency = time.perf_counter() - self._start_at
message.provider_response_latency = time.perf_counter() - self.start_at
message.total_price = usage.total_price
message.currency = usage.currency
self._task_state.llm_result.usage.latency = message.provider_response_latency
@ -438,7 +438,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
# transform usage
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
self._task_state.llm_result.usage = model_type_instance.calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)

@ -466,14 +466,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
task_id=self._application_generate_entity.task_id, id=message_id, answer=answer
)

def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]:
def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> AgentThoughtStreamResponse | None:
"""
Agent thought to stream response.
:param event: agent thought event
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
agent_thought: Optional[MessageAgentThought] = (
agent_thought: MessageAgentThought | None = (
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)

@ -498,10 +498,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:param text: text
:return: True if output moderation should direct output, otherwise False
"""
if self._output_moderation_handler:
if self._output_moderation_handler.should_direct_output():
if self.output_moderation_handler:
if self.output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
self._task_state.llm_result.message.content = self.output_moderation_handler.get_final_output()
self.queue_manager.publish(
QueueLLMChunkEvent(
chunk=LLMResultChunk(
@ -521,6 +521,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
)
return True
else:
self._output_moderation_handler.append_new_token(text)
self.output_moderation_handler.append_new_token(text)

return False
@ -1,6 +1,6 @@
import logging
from threading import Thread
from typing import Optional, Union
from typing import Union

from flask import Flask, current_app
from sqlalchemy import select
@ -52,7 +52,7 @@ class MessageCycleManager:
self._application_generate_entity = application_generate_entity
self._task_state = task_state

def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None:
"""
Generate conversation name.
:param conversation_id: conversation id
@ -92,7 +92,7 @@ class MessageCycleManager:
if not conversation:
return

if conversation.mode != AppMode.COMPLETION.value:
if conversation.mode != AppMode.COMPLETION:
app_model = conversation.app
if not app_model:
return
@ -111,7 +111,7 @@ class MessageCycleManager:
db.session.commit()
db.session.close()

def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> MessageAnnotation | None:
"""
Handle annotation reply.
:param event: event
@ -141,7 +141,7 @@ class MessageCycleManager:
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
self._task_state.metadata.retriever_resources = event.retriever_resources

def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> MessageFileStreamResponse | None:
"""
Message file to stream response.
:param event: event
@ -180,7 +180,7 @@ class MessageCycleManager:
return None

def message_to_stream_response(
self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None
self, answer: str, message_id: str, from_variable_selector: list[str] | None = None
) -> MessageStreamResponse:
"""
Message to stream response.