Merge remote-tracking branch 'origin/main' into feat/trigger

This commit is contained in:
yessenia
2025-09-25 17:14:24 +08:00
3013 changed files with 148826 additions and 44294 deletions

View File

@ -1,12 +1,10 @@
from typing import Optional
from core.app.app_config.entities import SensitiveWordAvoidanceEntity
from core.moderation.factory import ModerationFactory
class SensitiveWordAvoidanceConfigManager:
@classmethod
def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]:
def convert(cls, config: dict) -> SensitiveWordAvoidanceEntity | None:
sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
if not sensitive_word_avoidance_dict:
return None
@ -21,7 +19,7 @@ class SensitiveWordAvoidanceConfigManager:
@classmethod
def validate_and_set_defaults(
cls, tenant_id, config: dict, only_structure_validate: bool = False
cls, tenant_id: str, config: dict, only_structure_validate: bool = False
) -> tuple[dict, list[str]]:
if not config.get("sensitive_word_avoidance"):
config["sensitive_word_avoidance"] = {"enabled": False}
@ -38,7 +36,14 @@ class SensitiveWordAvoidanceConfigManager:
if not only_structure_validate:
typ = config["sensitive_word_avoidance"]["type"]
sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"]
if not isinstance(typ, str):
raise ValueError("sensitive_word_avoidance.type must be a string")
sensitive_word_avoidance_config = config["sensitive_word_avoidance"].get("config")
if sensitive_word_avoidance_config is None:
sensitive_word_avoidance_config = {}
if not isinstance(sensitive_word_avoidance_config, dict):
raise ValueError("sensitive_word_avoidance.config must be a dict")
ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)

View File

@ -1,12 +1,10 @@
from typing import Optional
from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity
from core.agent.prompt.template import REACT_PROMPT_TEMPLATES
class AgentConfigManager:
@classmethod
def convert(cls, config: dict) -> Optional[AgentEntity]:
def convert(cls, config: dict) -> AgentEntity | None:
"""
Convert model config to model config

View File

@ -1,5 +1,4 @@
import uuid
from typing import Optional
from core.app.app_config.entities import (
DatasetEntity,
@ -14,7 +13,7 @@ from services.dataset_service import DatasetService
class DatasetConfigManager:
@classmethod
def convert(cls, config: dict) -> Optional[DatasetEntity]:
def convert(cls, config: dict) -> DatasetEntity | None:
"""
Convert model config to model config
@ -158,7 +157,7 @@ class DatasetConfigManager:
return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]
@classmethod
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict) -> dict:
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict):
"""
Extract dataset config for legacy compatibility

View File

@ -4,8 +4,8 @@ from typing import Any
from core.app.app_config.entities import ModelConfigEntity
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from models.provider_ids import ModelProviderID
class ModelConfigManager:
@ -105,7 +105,7 @@ class ModelConfigManager:
return dict(config), ["model"]
@classmethod
def validate_model_completion_params(cls, cp: dict) -> dict:
def validate_model_completion_params(cls, cp: dict):
# model.completion_params
if not isinstance(cp, dict):
raise ValueError("model.completion_params must be of object type")

View File

@ -25,10 +25,14 @@ class PromptTemplateConfigManager:
if chat_prompt_config:
chat_prompt_messages = []
for message in chat_prompt_config.get("prompt", []):
text = message.get("text")
if not isinstance(text, str):
raise ValueError("message text must be a string")
role = message.get("role")
if not isinstance(role, str):
raise ValueError("message role must be a string")
chat_prompt_messages.append(
AdvancedChatMessageEntity(
**{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
)
AdvancedChatMessageEntity(text=text, role=PromptMessageRole.value_of(role))
)
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
@ -66,7 +70,7 @@ class PromptTemplateConfigManager:
:param config: app model config args
"""
if not config.get("prompt_type"):
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE
prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
if config["prompt_type"] not in prompt_type_vals:
@ -86,7 +90,7 @@ class PromptTemplateConfigManager:
if not isinstance(config["completion_prompt_config"], dict):
raise ValueError("completion_prompt_config must be of object type")
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED:
if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
raise ValueError(
"chat_prompt_config or completion_prompt_config is required when prompt_type is advanced"
@ -122,7 +126,7 @@ class PromptTemplateConfigManager:
return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]
@classmethod
def validate_post_prompt_and_set_defaults(cls, config: dict) -> dict:
def validate_post_prompt_and_set_defaults(cls, config: dict):
"""
Validate post_prompt and set defaults for prompt feature

View File

@ -1,6 +1,6 @@
from collections.abc import Sequence
from enum import Enum, StrEnum
from typing import Any, Literal, Optional
from enum import StrEnum, auto
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
@ -17,7 +17,7 @@ class ModelConfigEntity(BaseModel):
provider: str
model: str
mode: Optional[str] = None
mode: str | None = None
parameters: dict[str, Any] = Field(default_factory=dict)
stop: list[str] = Field(default_factory=list)
@ -53,7 +53,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
assistant: str
prompt: str
role_prefix: Optional[RolePrefixEntity] = None
role_prefix: RolePrefixEntity | None = None
class PromptTemplateEntity(BaseModel):
@ -61,14 +61,14 @@ class PromptTemplateEntity(BaseModel):
Prompt Template Entity.
"""
class PromptType(Enum):
class PromptType(StrEnum):
"""
Prompt Type.
'simple', 'advanced'
"""
SIMPLE = "simple"
ADVANCED = "advanced"
SIMPLE = auto()
ADVANCED = auto()
@classmethod
def value_of(cls, value: str):
@ -84,9 +84,9 @@ class PromptTemplateEntity(BaseModel):
raise ValueError(f"invalid prompt type value {value}")
prompt_type: PromptType
simple_prompt_template: Optional[str] = None
advanced_chat_prompt_template: Optional[AdvancedChatPromptTemplateEntity] = None
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
simple_prompt_template: str | None = None
advanced_chat_prompt_template: AdvancedChatPromptTemplateEntity | None = None
advanced_completion_prompt_template: AdvancedCompletionPromptTemplateEntity | None = None
class VariableEntityType(StrEnum):
@ -112,11 +112,11 @@ class VariableEntity(BaseModel):
type: VariableEntityType
required: bool = False
hide: bool = False
max_length: Optional[int] = None
max_length: int | None = None
options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
allowed_file_extensions: Sequence[str] = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
@field_validator("description", mode="before")
@classmethod
@ -129,6 +129,16 @@ class VariableEntity(BaseModel):
return v or []
class RagPipelineVariableEntity(VariableEntity):
"""
Rag Pipeline Variable Entity.
"""
tooltips: str | None = None
placeholder: str | None = None
belong_to_node_id: str
class ExternalDataVariableEntity(BaseModel):
"""
External Data Variable Entity.
@ -173,7 +183,7 @@ class ModelConfig(BaseModel):
class Condition(BaseModel):
"""
Conditon detail
Condition detail
"""
name: str
@ -186,8 +196,8 @@ class MetadataFilteringCondition(BaseModel):
Metadata Filtering Condition.
"""
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
logical_operator: Literal["and", "or"] | None = "and"
conditions: list[Condition] | None = Field(default=None, deprecated=True)
class DatasetRetrieveConfigEntity(BaseModel):
@ -195,14 +205,14 @@ class DatasetRetrieveConfigEntity(BaseModel):
Dataset Retrieve Config Entity.
"""
class RetrieveStrategy(Enum):
class RetrieveStrategy(StrEnum):
"""
Dataset Retrieve Strategy.
'single' or 'multiple'
"""
SINGLE = "single"
MULTIPLE = "multiple"
SINGLE = auto()
MULTIPLE = auto()
@classmethod
def value_of(cls, value: str):
@ -217,18 +227,18 @@ class DatasetRetrieveConfigEntity(BaseModel):
return mode
raise ValueError(f"invalid retrieve strategy value {value}")
query_variable: Optional[str] = None # Only when app mode is completion
query_variable: str | None = None # Only when app mode is completion
retrieve_strategy: RetrieveStrategy
top_k: Optional[int] = None
score_threshold: Optional[float] = 0.0
rerank_mode: Optional[str] = "reranking_model"
reranking_model: Optional[dict] = None
weights: Optional[dict] = None
reranking_enabled: Optional[bool] = True
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
metadata_model_config: Optional[ModelConfig] = None
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
top_k: int | None = None
score_threshold: float | None = 0.0
rerank_mode: str | None = "reranking_model"
reranking_model: dict | None = None
weights: dict | None = None
reranking_enabled: bool | None = True
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
metadata_model_config: ModelConfig | None = None
metadata_filtering_conditions: MetadataFilteringCondition | None = None
class DatasetEntity(BaseModel):
@ -255,8 +265,8 @@ class TextToSpeechEntity(BaseModel):
"""
enabled: bool
voice: Optional[str] = None
language: Optional[str] = None
voice: str | None = None
language: str | None = None
class TracingConfigEntity(BaseModel):
@ -269,15 +279,15 @@ class TracingConfigEntity(BaseModel):
class AppAdditionalFeatures(BaseModel):
file_upload: Optional[FileUploadConfig] = None
opening_statement: Optional[str] = None
file_upload: FileUploadConfig | None = None
opening_statement: str | None = None
suggested_questions: list[str] = []
suggested_questions_after_answer: bool = False
show_retrieve_source: bool = False
more_like_this: bool = False
speech_to_text: bool = False
text_to_speech: Optional[TextToSpeechEntity] = None
trace_config: Optional[TracingConfigEntity] = None
text_to_speech: TextToSpeechEntity | None = None
trace_config: TracingConfigEntity | None = None
class AppConfig(BaseModel):
@ -288,17 +298,17 @@ class AppConfig(BaseModel):
tenant_id: str
app_id: str
app_mode: AppMode
additional_features: AppAdditionalFeatures
additional_features: AppAdditionalFeatures | None = None
variables: list[VariableEntity] = []
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None
class EasyUIBasedAppModelConfigFrom(Enum):
class EasyUIBasedAppModelConfigFrom(StrEnum):
"""
App Model Config From.
"""
ARGS = "args"
ARGS = auto()
APP_LATEST_CONFIG = "app-latest-config"
CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"
@ -313,7 +323,7 @@ class EasyUIBasedAppConfig(AppConfig):
app_model_config_dict: dict
model: ModelConfigEntity
prompt_template: PromptTemplateEntity
dataset: Optional[DatasetEntity] = None
dataset: DatasetEntity | None = None
external_data_variables: list[ExternalDataVariableEntity] = []

View File

@ -26,7 +26,7 @@ class MoreLikeThisConfigManager:
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
try:
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
except ValidationError as e:
except ValidationError:
raise ValueError(
"more_like_this must be of dict type and enabled in more_like_this must be of boolean type"
)

View File

@ -1,4 +1,6 @@
from core.app.app_config.entities import VariableEntity
import re
from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
from models.workflow import Workflow
@ -20,3 +22,48 @@ class WorkflowVariablesConfigManager:
variables.append(VariableEntity.model_validate(variable))
return variables
@classmethod
def convert_rag_pipeline_variable(cls, workflow: Workflow, start_node_id: str) -> list[RagPipelineVariableEntity]:
"""
Convert workflow start variables to variables
:param workflow: workflow instance
"""
variables = []
# get second step node
rag_pipeline_variables = workflow.rag_pipeline_variables
if not rag_pipeline_variables:
return []
variables_map = {item["variable"]: item for item in rag_pipeline_variables}
# get datasource node data
datasource_node_data = None
datasource_nodes = workflow.graph_dict.get("nodes", [])
for datasource_node in datasource_nodes:
if datasource_node.get("id") == start_node_id:
datasource_node_data = datasource_node.get("data", {})
break
if datasource_node_data:
datasource_parameters = datasource_node_data.get("datasource_parameters", {})
for _, value in datasource_parameters.items():
if value.get("value") and isinstance(value.get("value"), str):
pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
match = re.match(pattern, value["value"])
if match:
full_path = match.group(1)
last_part = full_path.split(".")[-1]
variables_map.pop(last_part, None)
if value.get("value") and isinstance(value.get("value"), list):
last_part = value.get("value")[-1]
variables_map.pop(last_part, None)
all_second_step_variables = list(variables_map.values())
for item in all_second_step_variables:
if item.get("belong_to_node_id") == start_node_id or item.get("belong_to_node_id") == "shared":
variables.append(RagPipelineVariableEntity.model_validate(item))
return variables

View File

@ -41,7 +41,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False):
"""
Validate for advanced chat app model config

View File

@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Optional, Union, overload
from typing import Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@ -154,7 +154,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
if invoke_from == InvokeFrom.DEBUGGER:
# always enable retriever resource in debugger mode
app_config.additional_features.show_retrieve_source = True
app_config.additional_features.show_retrieve_source = True # type: ignore
workflow_run_id = str(uuid.uuid4())
# init application generate entity
@ -390,7 +390,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
conversation: Optional[Conversation] = None,
conversation: Conversation | None = None,
stream: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
@ -420,7 +420,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
db.session.refresh(conversation)
# get conversation dialogue count
self._dialogue_count = get_thread_messages_length(conversation.id)
# NOTE: dialogue_count should not start from 0,
# because during the first conversation, dialogue_count should be 1.
self._dialogue_count = get_thread_messages_length(conversation.id) + 1
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@ -450,6 +452,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
worker_thread.start()
# release database connection, because the following new thread operations may take a long time
db.session.refresh(workflow)
db.session.refresh(message)
# db.session.refresh(user)
db.session.close()
# return response or stream generator
response = self._handle_advanced_chat_response(
application_generate_entity=application_generate_entity,
@ -461,7 +469,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream,
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from),
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
)
return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
@ -475,7 +483,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id: str,
context: contextvars.Context,
variable_loader: VariableLoader,
) -> None:
):
"""
Generate worker in a new thread.
:param flask_app: Flask app

View File

@ -1,11 +1,11 @@
import logging
import time
from collections.abc import Mapping
from typing import Any, Optional, cast
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@ -23,16 +23,17 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import Workflow
from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable, WorkflowType
from models.workflow import ConversationVariable
logger = logging.getLogger(__name__)
@ -54,7 +55,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow: Workflow,
system_user_id: str,
app: App,
) -> None:
):
super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
@ -68,31 +69,22 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self.system_user_id = system_user_id
self._app = app
def run(self) -> None:
def run(self):
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
with Session(db.engine, expire_on_commit=False) as session:
app_record = session.scalar(select(App).where(App.id == app_config.app_id))
if not app_record:
raise ValueError("App not found")
workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
)
else:
inputs = self.application_generate_entity.inputs
@ -144,16 +136,27 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
)
# init graph
graph = self._init_graph(graph_config=self._workflow.graph_dict)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time())
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
)
db.session.close()
# RUN WORKFLOW
# Create Redis command channel for this workflow execution
task_id = self.application_generate_entity.task_id
channel_key = f"workflow:{task_id}:commands"
command_channel = RedisChannel(redis_client, channel_key)
workflow_entry = WorkflowEntry(
tenant_id=self._workflow.tenant_id,
app_id=self._workflow.app_id,
workflow_id=self._workflow.id,
workflow_type=WorkflowType.value_of(self._workflow.type),
graph=graph,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
@ -165,11 +168,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
command_channel=command_channel,
)
generator = workflow_entry.run(
callbacks=workflow_callbacks,
)
generator = workflow_entry.run()
for event in generator:
self._handle_event(workflow_entry, event)
@ -219,7 +222,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
return False
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy):
"""
Direct output
"""
@ -229,7 +232,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
def query_app_annotations_to_reply(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
) -> MessageAnnotation | None:
"""
Query app annotations to reply
:param app_record: app record

View File

@ -71,7 +71,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, Any] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
@ -82,7 +82,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@ -102,7 +102,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, Any] = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
@ -110,7 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
@ -120,6 +120,6 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -4,7 +4,7 @@ import time
from collections.abc import Callable, Generator, Mapping
from contextlib import contextmanager
from threading import Thread
from typing import Any, Optional, Union
from typing import Any, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -31,14 +31,9 @@ from core.app.entities.queue_entities import (
QueueMessageReplaceEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
@ -65,15 +60,14 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Conversation, EndUser, Message, MessageFile
@ -102,7 +96,7 @@ class AdvancedChatAppGenerateTaskPipeline:
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
) -> None:
):
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
@ -175,7 +169,7 @@ class AdvancedChatAppGenerateTaskPipeline:
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._base_task_pipeline._stream:
if self._base_task_pipeline.stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
@ -234,7 +228,7 @@ class AdvancedChatAppGenerateTaskPipeline:
return None
def _wrapper_process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
self, trace_manager: TraceQueueManager | None = None
) -> Generator[StreamResponse, None, None]:
tts_publisher = None
task_id = self._application_generate_entity.task_id
@ -290,12 +284,12 @@ class AdvancedChatAppGenerateTaskPipeline:
session.rollback()
raise
def _ensure_workflow_initialized(self) -> None:
def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state."""
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState:
def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState:
"""Fluent validation for graph runtime state."""
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
@ -303,21 +297,16 @@ class AdvancedChatAppGenerateTaskPipeline:
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
"""Handle ping events."""
yield self._base_task_pipeline._ping_stream_response()
yield self._base_task_pipeline.ping_stream_response()
def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
"""Handle error events."""
with self._database_session() as session:
err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id)
yield self._base_task_pipeline._error_to_stream_response(err)
err = self._base_task_pipeline.handle_error(event=event, session=session, message_id=self._message_id)
yield self._base_task_pipeline.error_to_stream_response(err)
def _handle_workflow_started_event(
self, event: QueueWorkflowStartedEvent, *, graph_runtime_state: Optional[GraphRuntimeState] = None, **kwargs
) -> Generator[StreamResponse, None, None]:
def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle workflow started events."""
# Override graph runtime state - this is a side effect but necessary
graph_runtime_state = event.graph_runtime_state
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
self._workflow_run_id = workflow_execution.id_
@ -338,15 +327,14 @@ class AdvancedChatAppGenerateTaskPipeline:
"""Handle node retry events."""
self._ensure_workflow_initialized()
with self._database_session() as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id, event=event
)
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id, event=event
)
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_retry_resp:
yield node_retry_resp
@ -380,13 +368,12 @@ class AdvancedChatAppGenerateTaskPipeline:
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
)
with self._database_session() as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
self._save_output_for_event(event, workflow_node_execution.id)
@ -395,9 +382,7 @@ class AdvancedChatAppGenerateTaskPipeline:
def _handle_node_failed_events(
self,
event: Union[
QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
],
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle various node failure events."""
@ -419,8 +404,8 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueTextChunkEvent,
*,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle text chunk events."""
@ -442,32 +427,6 @@ class AdvancedChatAppGenerateTaskPipeline:
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
)
def _handle_parallel_branch_started_event(
self, event: QueueParallelBranchRunStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch started events."""
self._ensure_workflow_initialized()
parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_start_resp
def _handle_parallel_branch_finished_events(
self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch finished events."""
self._ensure_workflow_initialized()
parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_finish_resp
def _handle_iteration_start_event(
self, event: QueueIterationStartEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
@ -546,8 +505,8 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueWorkflowSucceededEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow succeeded events."""
@ -577,8 +536,8 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueWorkflowPartialSuccessEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow partial success events."""
@ -609,8 +568,8 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueWorkflowFailedEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow failed events."""
@ -635,17 +594,17 @@ class AdvancedChatAppGenerateTaskPipeline:
workflow_execution=workflow_execution,
)
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
err = self._base_task_pipeline._handle_error(event=err_event, session=session, message_id=self._message_id)
err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id)
yield workflow_finish_resp
yield self._base_task_pipeline._error_to_stream_response(err)
yield self._base_task_pipeline.error_to_stream_response(err)
def _handle_stop_event(
self,
event: QueueStopEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle stop events."""
@ -685,13 +644,13 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueAdvancedChatMessageEndEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
graph_runtime_state: GraphRuntimeState | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle advanced chat message end events."""
self._ensure_graph_runtime_initialized(graph_runtime_state)
output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
output_moderation_answer = self._base_task_pipeline.handle_output_moderation_when_task_finished(
self._task_state.answer
)
if output_moderation_answer:
@ -759,8 +718,6 @@ class AdvancedChatAppGenerateTaskPipeline:
QueueNodeRetryEvent: self._handle_node_retry_event,
QueueNodeStartedEvent: self._handle_node_started_event,
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
# Parallel branch events
QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
# Iteration events
QueueIterationStartEvent: self._handle_iteration_start_event,
QueueIterationNextEvent: self._handle_iteration_next_event,
@ -783,10 +740,10 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: Any,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
graph_runtime_state: GraphRuntimeState | None = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
) -> Generator[StreamResponse, None, None]:
"""Dispatch events using elegant pattern matching."""
handlers = self._get_event_handlers()
@ -808,8 +765,6 @@ class AdvancedChatAppGenerateTaskPipeline:
event,
(
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeExceptionEvent,
),
):
@ -822,31 +777,20 @@ class AdvancedChatAppGenerateTaskPipeline:
)
return
# Handle parallel branch finished events with isinstance check
if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
yield from self._handle_parallel_branch_finished_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
return
# For unhandled events, we continue (original behavior)
return
def _process_stream_response(
self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
) -> Generator[StreamResponse, None, None]:
"""
Process stream response using elegant Fluent Python patterns.
Maintains exact same functionality as original 57-if-statement version.
"""
# Initialize graph runtime state
graph_runtime_state: Optional[GraphRuntimeState] = None
graph_runtime_state: GraphRuntimeState | None = None
for queue_message in self._base_task_pipeline.queue_manager.listen():
event = queue_message.event
@ -856,11 +800,6 @@ class AdvancedChatAppGenerateTaskPipeline:
graph_runtime_state = event.graph_runtime_state
yield from self._handle_workflow_started_event(event)
case QueueTextChunkEvent():
yield from self._handle_text_chunk_event(
event, tts_publisher=tts_publisher, queue_message=queue_message
)
case QueueErrorEvent():
yield from self._handle_error_event(event)
break
@ -896,7 +835,7 @@ class AdvancedChatAppGenerateTaskPipeline:
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
message = self._get_message(session=session)
# If there are assistant files, remove markdown image links from answer
@ -907,7 +846,7 @@ class AdvancedChatAppGenerateTaskPipeline:
message.answer = answer_text
message.updated_at = naive_utc_now()
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline.start_at
message.message_metadata = self._task_state.metadata.model_dump_json()
message_files = [
MessageFile(
@ -939,10 +878,6 @@ class AdvancedChatAppGenerateTaskPipeline:
self._task_state.metadata.usage = usage
else:
self._task_state.metadata.usage = LLMUsage.empty_usage()
message_was_created.send(
message,
application_generate_entity=self._application_generate_entity,
)
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
"""
@ -967,9 +902,9 @@ class AdvancedChatAppGenerateTaskPipeline:
:param text: text
:return: True if output moderation should direct output, otherwise False
"""
if self._base_task_pipeline._output_moderation_handler:
if self._base_task_pipeline._output_moderation_handler.should_direct_output():
self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
if self._base_task_pipeline.output_moderation_handler:
if self._base_task_pipeline.output_moderation_handler.should_direct_output():
self._task_state.answer = self._base_task_pipeline.output_moderation_handler.get_final_output()
self._base_task_pipeline.queue_manager.publish(
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
)
@ -979,7 +914,7 @@ class AdvancedChatAppGenerateTaskPipeline:
)
return True
else:
self._base_task_pipeline._output_moderation_handler.append_new_token(text)
self._base_task_pipeline.output_moderation_handler.append_new_token(text)
return False

View File

@ -1,6 +1,6 @@
import uuid
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any, cast
from core.agent.entities import AgentEntity
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
@ -30,7 +30,7 @@ class AgentChatAppConfig(EasyUIBasedAppConfig):
Agent Chatbot App Config Entity.
"""
agent: Optional[AgentEntity] = None
agent: AgentEntity | None = None
class AgentChatAppConfigManager(BaseAppConfigManager):
@ -39,8 +39,8 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
cls,
app_model: App,
app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None,
conversation: Conversation | None = None,
override_config_dict: dict | None = None,
) -> AgentChatAppConfig:
"""
Convert app model config to agent chat app config
@ -86,7 +86,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> dict:
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]):
"""
Validate for agent chat app model config
@ -160,7 +160,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
return filtered_config
@classmethod
def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
def validate_agent_mode_and_set_defaults(
cls, tenant_id: str, config: dict[str, Any]
) -> tuple[dict[str, Any], list[str]]:
"""
Validate agent_mode and set defaults for agent feature
@ -170,30 +172,32 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
if not config.get("agent_mode"):
config["agent_mode"] = {"enabled": False, "tools": []}
if not isinstance(config["agent_mode"], dict):
agent_mode = config["agent_mode"]
if not isinstance(agent_mode, dict):
raise ValueError("agent_mode must be of object type")
if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]:
config["agent_mode"]["enabled"] = False
# FIXME(-LAN-): Cast needed due to basedpyright limitation with dict type narrowing
agent_mode = cast(dict[str, Any], agent_mode)
if not isinstance(config["agent_mode"]["enabled"], bool):
if "enabled" not in agent_mode or not agent_mode["enabled"]:
agent_mode["enabled"] = False
if not isinstance(agent_mode["enabled"], bool):
raise ValueError("enabled in agent_mode must be of boolean type")
if not config["agent_mode"].get("strategy"):
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
if not agent_mode.get("strategy"):
agent_mode["strategy"] = PlanningStrategy.ROUTER.value
if config["agent_mode"]["strategy"] not in [
member.value for member in list(PlanningStrategy.__members__.values())
]:
if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
raise ValueError("strategy in agent_mode must be in the specified strategy list")
if not config["agent_mode"].get("tools"):
config["agent_mode"]["tools"] = []
if not agent_mode.get("tools"):
agent_mode["tools"] = []
if not isinstance(config["agent_mode"]["tools"], list):
if not isinstance(agent_mode["tools"], list):
raise ValueError("tools in agent_mode must be a list of objects")
for tool in config["agent_mode"]["tools"]:
for tool in agent_mode["tools"]:
key = list(tool.keys())[0]
if key in OLD_TOOLS:
# old style, use tool name as key

View File

@ -222,7 +222,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
) -> None:
):
"""
Generate worker in a new thread.
:param flask_app: Flask app

View File

@ -1,6 +1,8 @@
import logging
from typing import cast
from sqlalchemy import select
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from core.agent.entities import AgentEntity
@ -33,7 +35,7 @@ class AgentChatAppRunner(AppRunner):
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
) -> None:
):
"""
Run assistant application
:param application_generate_entity: application generate entity
@ -44,8 +46,8 @@ class AgentChatAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(AgentChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
app_stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(app_stmt)
if not app_record:
raise ValueError("App not found")
@ -182,11 +184,12 @@ class AgentChatAppRunner(AppRunner):
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first()
conversation_stmt = select(Conversation).where(Conversation.id == conversation.id)
conversation_result = db.session.scalar(conversation_stmt)
if conversation_result is None:
raise ValueError("Conversation not found")
message_result = db.session.query(Message).where(Message.id == message.id).first()
msg_stmt = select(Message).where(Message.id == message.id)
message_result = db.session.scalar(msg_stmt)
if message_result is None:
raise ValueError("Message not found")
db.session.close()

View File

@ -16,7 +16,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -46,7 +46,10 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
else:
response["metadata"] = {}
return response
@ -78,7 +81,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@ -106,7 +109,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
@ -114,6 +117,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -94,7 +94,7 @@ class AppGenerateResponseConverter(ABC):
return metadata
@classmethod
def _error_to_stream_response(cls, e: Exception) -> dict:
def _error_to_stream_response(cls, e: Exception):
"""
Error to stream response.
:param e: exception

View File

@ -1,12 +1,12 @@
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union, final
from typing import TYPE_CHECKING, Any, Union, final
from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileUploadConfig
from core.workflow.nodes.enums import NodeType
from core.workflow.enums import NodeType
from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaver,
DraftVariableSaverFactory,
@ -14,6 +14,7 @@ from core.workflow.repositories.draft_variable_repository import (
)
from factories import file_factory
from libs.orjson import orjson_dumps
from models import Account, EndUser
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
if TYPE_CHECKING:
@ -24,7 +25,7 @@ class BaseAppGenerator:
def _prepare_user_inputs(
self,
*,
user_inputs: Optional[Mapping[str, Any]],
user_inputs: Mapping[str, Any] | None,
variables: Sequence["VariableEntity"],
tenant_id: str,
strict_type_validation: bool = False,
@ -44,9 +45,9 @@ class BaseAppGenerator:
mapping=v,
tenant_id=tenant_id,
config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
allowed_file_types=entity_dictionary[k].allowed_file_types or [],
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [],
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
),
strict_type_validation=strict_type_validation,
)
@ -59,9 +60,9 @@ class BaseAppGenerator:
mappings=v,
tenant_id=tenant_id,
config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
allowed_file_types=entity_dictionary[k].allowed_file_types or [],
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [],
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
),
)
for k, v in user_inputs.items()
@ -157,7 +158,7 @@ class BaseAppGenerator:
return value
def _sanitize_value(self, value: Any) -> Any:
def _sanitize_value(self, value: Any):
if isinstance(value, str):
return value.replace("\x00", "")
return value
@ -182,8 +183,9 @@ class BaseAppGenerator:
@final
@staticmethod
def _get_draft_var_saver_factory(invoke_from: InvokeFrom) -> DraftVariableSaverFactory:
def _get_draft_var_saver_factory(invoke_from: InvokeFrom, account: Account | EndUser) -> DraftVariableSaverFactory:
if invoke_from == InvokeFrom.DEBUGGER:
assert isinstance(account, Account)
def draft_var_saver_factory(
session: Session,
@ -200,6 +202,7 @@ class BaseAppGenerator:
node_type=node_type,
node_execution_id=node_execution_id,
enclosing_node_id=enclosing_node_id,
user=account,
)
else:

View File

@ -2,7 +2,7 @@ import queue
import time
from abc import abstractmethod
from enum import IntEnum, auto
from typing import Any, Optional
from typing import Any
from sqlalchemy.orm import DeclarativeMeta
@ -25,13 +25,14 @@ class PublishFrom(IntEnum):
class AppQueueManager:
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None:
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom):
if not user_id:
raise ValueError("user is required")
self._task_id = task_id
self._user_id = user_id
self._invoke_from = invoke_from
self.invoke_from = invoke_from # Public accessor for invoke_from
user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
redis_client.setex(
@ -73,14 +74,14 @@ class AppQueueManager:
self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
last_ping_time = elapsed_time // 10
def stop_listen(self) -> None:
def stop_listen(self):
"""
Stop listen to queue
:return:
"""
self._q.put(None)
def publish_error(self, e, pub_from: PublishFrom) -> None:
def publish_error(self, e, pub_from: PublishFrom):
"""
Publish error
:param e: error
@ -89,7 +90,7 @@ class AppQueueManager:
"""
self.publish(QueueErrorEvent(error=e), pub_from)
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
def publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue
:param event:
@ -100,7 +101,7 @@ class AppQueueManager:
self._publish(event, pub_from)
@abstractmethod
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue
:param event:
@ -110,12 +111,12 @@ class AppQueueManager:
raise NotImplementedError
@classmethod
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str):
"""
Set task stop flag
:return:
"""
result: Optional[Any] = redis_client.get(cls._generate_task_belong_cache_key(task_id))
result: Any | None = redis_client.get(cls._generate_task_belong_cache_key(task_id))
if result is None:
return
@ -126,6 +127,21 @@ class AppQueueManager:
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
redis_client.setex(stopped_cache_key, 600, 1)
@classmethod
def set_stop_flag_no_user_check(cls, task_id: str) -> None:
"""
Set task stop flag without user permission check.
This method allows stopping workflows without user context.
:param task_id: The task ID to stop
:return:
"""
if not task_id:
return
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
redis_client.setex(stopped_cache_key, 600, 1)
def _is_stopped(self) -> bool:
"""
Check if task is stopped
@ -159,7 +175,7 @@ class AppQueueManager:
def _check_for_sqlalchemy_models(self, data: Any):
# from entity to dict or list
if isinstance(data, dict):
for key, value in data.items():
for value in data.values():
self._check_for_sqlalchemy_models(value)
elif isinstance(data, list):
for item in data:

View File

@ -1,7 +1,7 @@
import logging
import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -82,11 +82,11 @@ class AppRunner:
prompt_template_entity: PromptTemplateEntity,
inputs: Mapping[str, str],
files: Sequence["File"],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None,
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
query: str | None = None,
context: str | None = None,
memory: TokenBufferMemory | None = None,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> tuple[list[PromptMessage], list[str] | None]:
"""
Organize prompt messages
:param context:
@ -161,8 +161,8 @@ class AppRunner:
prompt_messages: list,
text: str,
stream: bool,
usage: Optional[LLMUsage] = None,
) -> None:
usage: LLMUsage | None = None,
):
"""
Direct output
:param queue_manager: application queue manager
@ -204,7 +204,7 @@ class AppRunner:
queue_manager: AppQueueManager,
stream: bool,
agent: bool = False,
) -> None:
):
"""
Handle invoke result
:param invoke_result: invoke result
@ -220,9 +220,7 @@ class AppRunner:
else:
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
def _handle_invoke_result_direct(
self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
) -> None:
def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool):
"""
Handle invoke result direct
:param invoke_result: invoke result
@ -239,7 +237,7 @@ class AppRunner:
def _handle_invoke_result_stream(
self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool
) -> None:
):
"""
Handle invoke result
:param invoke_result: invoke result
@ -377,7 +375,7 @@ class AppRunner:
def query_app_annotations_to_reply(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
) -> MessageAnnotation | None:
"""
Query app annotations to reply
:param app_record: app record

View File

@ -1,5 +1,3 @@
from typing import Optional
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
@ -32,8 +30,8 @@ class ChatAppConfigManager(BaseAppConfigManager):
cls,
app_model: App,
app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None,
conversation: Conversation | None = None,
override_config_dict: dict | None = None,
) -> ChatAppConfig:
"""
Convert app model config to chat app config
@ -81,7 +79,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict) -> dict:
def config_validate(cls, tenant_id: str, config: dict):
"""
Validate for chat app model config

View File

@ -211,7 +211,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
) -> None:
):
"""
Generate worker in a new thread.
:param flask_app: Flask app

View File

@ -1,6 +1,8 @@
import logging
from typing import cast
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.chat.app_config_manager import ChatAppConfig
@ -31,7 +33,7 @@ class ChatAppRunner(AppRunner):
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
) -> None:
):
"""
Run application
:param application_generate_entity: application generate entity
@ -42,8 +44,8 @@ class ChatAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(ChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
if not app_record:
raise ValueError("App not found")
@ -162,7 +164,9 @@ class ChatAppRunner(AppRunner):
config=app_config.dataset,
query=query,
invoke_from=application_generate_entity.invoke_from,
show_retrieve_source=app_config.additional_features.show_retrieve_source,
show_retrieve_source=(
app_config.additional_features.show_retrieve_source if app_config.additional_features else False
),
hit_callback=hit_callback,
memory=memory,
message_id=message.id,

View File

@ -16,7 +16,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -46,7 +46,10 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
else:
response["metadata"] = {}
return response
@ -78,7 +81,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@ -106,7 +109,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
@ -114,6 +117,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -1,7 +1,7 @@
import time
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, Union, cast
from typing import Any, Union
from sqlalchemy.orm import Session
@ -16,14 +16,9 @@ from core.app.entities.queue_entities import (
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
)
from core.app.entities.task_entities import (
AgentLogStreamResponse,
@ -36,24 +31,23 @@ from core.app.entities.task_entities import (
NodeFinishStreamResponse,
NodeRetryStreamResponse,
NodeStartStreamResponse,
ParallelBranchFinishedStreamResponse,
ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
)
from core.file import FILE_MODEL_IDENTITY, File
from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.datetime_utils import naive_utc_now
from models import (
Account,
EndUser,
)
from services.variable_truncator import VariableTruncator
class WorkflowResponseConverter:
@ -62,9 +56,10 @@ class WorkflowResponseConverter:
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
user: Union[Account, EndUser],
) -> None:
):
self._application_generate_entity = application_generate_entity
self._user = user
self._truncator = VariableTruncator.default()
def workflow_start_to_stream_response(
self,
@ -140,7 +135,7 @@ class WorkflowResponseConverter:
event: QueueNodeStartedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeStartStreamResponse]:
) -> NodeStartStreamResponse | None:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_execution_id:
@ -156,7 +151,8 @@ class WorkflowResponseConverter:
title=workflow_node_execution.title,
index=workflow_node_execution.index,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs,
inputs=workflow_node_execution.get_response_inputs(),
inputs_truncated=workflow_node_execution.inputs_truncated,
created_at=int(workflow_node_execution.created_at.timestamp()),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
@ -171,11 +167,19 @@ class WorkflowResponseConverter:
# extras logic
if event.node_type == NodeType.TOOL:
node_data = cast(ToolNodeData, event.node_data)
response.data.extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
provider_type=node_data.provider_type,
provider_id=node_data.provider_id,
provider_type=ToolProviderType(event.provider_type),
provider_id=event.provider_id,
)
elif event.node_type == NodeType.DATASOURCE:
manager = PluginDatasourceManager()
provider_entity = manager.fetch_datasource_provider(
self._application_generate_entity.app_config.tenant_id,
event.provider_id,
)
response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
self._application_generate_entity.app_config.tenant_id
)
return response
@ -183,14 +187,10 @@ class WorkflowResponseConverter:
def workflow_node_finish_to_stream_response(
self,
*,
event: QueueNodeSucceededEvent
| QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
) -> NodeFinishStreamResponse | None:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_execution_id:
@ -210,9 +210,12 @@ class WorkflowResponseConverter:
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs,
process_data=workflow_node_execution.process_data,
outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
inputs=workflow_node_execution.get_response_inputs(),
inputs_truncated=workflow_node_execution.inputs_truncated,
process_data=workflow_node_execution.get_response_process_data(),
process_data_truncated=workflow_node_execution.process_data_truncated,
outputs=json_converter.to_json_encodable(workflow_node_execution.get_response_outputs()),
outputs_truncated=workflow_node_execution.outputs_truncated,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
@ -221,9 +224,6 @@ class WorkflowResponseConverter:
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
),
@ -235,7 +235,7 @@ class WorkflowResponseConverter:
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
) -> Union[NodeRetryStreamResponse, NodeFinishStreamResponse] | None:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_execution_id:
@ -255,9 +255,12 @@ class WorkflowResponseConverter:
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs,
process_data=workflow_node_execution.process_data,
outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
inputs=workflow_node_execution.get_response_inputs(),
inputs_truncated=workflow_node_execution.inputs_truncated,
process_data=workflow_node_execution.get_response_process_data(),
process_data_truncated=workflow_node_execution.process_data_truncated,
outputs=json_converter.to_json_encodable(workflow_node_execution.get_response_outputs()),
outputs_truncated=workflow_node_execution.outputs_truncated,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
@ -275,50 +278,6 @@ class WorkflowResponseConverter:
),
)
def workflow_parallel_branch_start_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueParallelBranchRunStartedEvent,
) -> ParallelBranchStartStreamResponse:
return ParallelBranchStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=ParallelBranchStartStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
created_at=int(time.time()),
),
)
def workflow_parallel_branch_finished_to_stream_response(
self,
*,
task_id: str,
workflow_execution_id: str,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
) -> ParallelBranchFinishedStreamResponse:
return ParallelBranchFinishedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
data=ParallelBranchFinishedStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
created_at=int(time.time()),
),
)
def workflow_iteration_start_to_stream_response(
self,
*,
@ -326,6 +285,7 @@ class WorkflowResponseConverter:
workflow_execution_id: str,
event: QueueIterationStartEvent,
) -> IterationNodeStartStreamResponse:
new_inputs, truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
return IterationNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
@ -333,13 +293,12 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
title=event.node_title,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
inputs=new_inputs,
inputs_truncated=truncated,
metadata=event.metadata or {},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)
@ -357,15 +316,10 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
title=event.node_title,
index=event.index,
pre_iteration_output=event.output,
created_at=int(time.time()),
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
),
)
@ -377,6 +331,11 @@ class WorkflowResponseConverter:
event: QueueIterationCompletedEvent,
) -> IterationNodeCompletedStreamResponse:
json_converter = WorkflowRuntimeTypeConverter()
new_inputs, inputs_truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
new_outputs, outputs_truncated = self._truncator.truncate_variable_mapping(
json_converter.to_json_encodable(event.outputs) or {}
)
return IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
@ -384,28 +343,29 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
outputs=json_converter.to_json_encodable(event.outputs),
title=event.node_title,
outputs=new_outputs,
outputs_truncated=outputs_truncated,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
inputs=new_inputs,
inputs_truncated=inputs_truncated,
status=WorkflowNodeExecutionStatus.SUCCEEDED
if event.error is None
else WorkflowNodeExecutionStatus.FAILED,
error=None,
elapsed_time=(naive_utc_now() - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
total_tokens=(lambda x: x if isinstance(x, int) else 0)(event.metadata.get("total_tokens", 0)),
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)
def workflow_loop_start_to_stream_response(
self, *, task_id: str, workflow_execution_id: str, event: QueueLoopStartEvent
) -> LoopNodeStartStreamResponse:
new_inputs, truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
return LoopNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
@ -413,10 +373,11 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
title=event.node_title,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
inputs=new_inputs,
inputs_truncated=truncated,
metadata=event.metadata or {},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
@ -437,15 +398,16 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
title=event.node_title,
index=event.index,
pre_loop_output=event.output,
# The `pre_loop_output` field is not utilized by the frontend.
# Previously, it was assigned the value of `event.output`.
pre_loop_output={},
created_at=int(time.time()),
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
),
)
@ -456,6 +418,11 @@ class WorkflowResponseConverter:
workflow_execution_id: str,
event: QueueLoopCompletedEvent,
) -> LoopNodeCompletedStreamResponse:
json_converter = WorkflowRuntimeTypeConverter()
new_inputs, inputs_truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
new_outputs, outputs_truncated = self._truncator.truncate_variable_mapping(
json_converter.to_json_encodable(event.outputs) or {}
)
return LoopNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
@ -463,17 +430,19 @@ class WorkflowResponseConverter:
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs),
title=event.node_title,
outputs=new_outputs,
outputs_truncated=outputs_truncated,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
inputs=new_inputs,
inputs_truncated=inputs_truncated,
status=WorkflowNodeExecutionStatus.SUCCEEDED
if event.error is None
else WorkflowNodeExecutionStatus.FAILED,
error=None,
elapsed_time=(naive_utc_now() - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
total_tokens=(lambda x: x if isinstance(x, int) else 0)(event.metadata.get("total_tokens", 0)),
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,

View File

@ -1,5 +1,3 @@
from typing import Optional
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
@ -24,7 +22,7 @@ class CompletionAppConfig(EasyUIBasedAppConfig):
class CompletionAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: Optional[dict] = None
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: dict | None = None
) -> CompletionAppConfig:
"""
Convert app model config to completion app config
@ -66,7 +64,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict) -> dict:
def config_validate(cls, tenant_id: str, config: dict):
"""
Validate for completion app model config

View File

@ -6,6 +6,7 @@ from typing import Any, Literal, Union, overload
from flask import Flask, copy_current_request_context, current_app
from pydantic import ValidationError
from sqlalchemy import select
from configs import dify_config
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
@ -191,7 +192,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
application_generate_entity: CompletionAppGenerateEntity,
queue_manager: AppQueueManager,
message_id: str,
) -> None:
):
"""
Generate worker in a new thread.
:param flask_app: Flask app
@ -248,28 +249,30 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
message = (
db.session.query(Message)
.where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
.first()
stmt = select(Message).where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
message = db.session.scalar(stmt)
if not message:
raise MessageNotExistsError()
current_app_model_config = app_model.app_model_config
if not current_app_model_config:
raise MoreLikeThisDisabledError()
more_like_this = current_app_model_config.more_like_this_dict
if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
raise MoreLikeThisDisabledError()
app_model_config = message.app_model_config
if not app_model_config:
raise ValueError("Message app_model_config is None")
override_model_config_dict = app_model_config.to_dict()
model_dict = override_model_config_dict["model"]
completion_params = model_dict.get("completion_params")

View File

@ -1,6 +1,8 @@
import logging
from typing import cast
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.completion.app_config_manager import CompletionAppConfig
@ -25,7 +27,7 @@ class CompletionAppRunner(AppRunner):
def run(
self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message
) -> None:
):
"""
Run application
:param application_generate_entity: application generate entity
@ -35,8 +37,8 @@ class CompletionAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(CompletionAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
if not app_record:
raise ValueError("App not found")
@ -122,7 +124,9 @@ class CompletionAppRunner(AppRunner):
config=dataset_config,
query=query or "",
invoke_from=application_generate_entity.invoke_from,
show_retrieve_source=app_config.additional_features.show_retrieve_source,
show_retrieve_source=app_config.additional_features.show_retrieve_source
if app_config.additional_features
else False,
hit_callback=hit_callback,
message_id=message.id,
inputs=inputs,

View File

@ -16,7 +16,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = CompletionAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -45,7 +45,10 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
else:
response["metadata"] = {}
return response
@ -76,7 +79,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@ -103,14 +106,16 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
if not isinstance(metadata, dict):
metadata = {}
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -1,7 +1,10 @@
import json
import logging
from collections.abc import Generator
from typing import Optional, Union, cast
from typing import Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
@ -81,13 +84,12 @@ class MessageBasedAppGenerator(BaseAppGenerator):
logger.exception("Failed to handle response, conversation_id: %s", conversation.id)
raise e
def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
def _get_app_model_config(self, app_model: App, conversation: Conversation | None = None) -> AppModelConfig:
if conversation:
app_model_config = (
db.session.query(AppModelConfig)
.where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.first()
stmt = select(AppModelConfig).where(
AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
)
app_model_config = db.session.scalar(stmt)
if not app_model_config:
raise AppModelConfigBrokenError()
@ -110,7 +112,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
],
conversation: Optional[Conversation] = None,
conversation: Conversation | None = None,
) -> tuple[Conversation, Message]:
"""
Initialize generate records
@ -253,7 +255,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param conversation_id: conversation id
:return: conversation
"""
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
with Session(db.engine, expire_on_commit=False) as session:
conversation = session.scalar(select(Conversation).where(Conversation.id == conversation_id))
if not conversation:
raise ConversationNotExistsError("Conversation not exists")
@ -266,7 +269,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param message_id: message id
:return: message
"""
message = db.session.query(Message).where(Message.id == message_id).first()
with Session(db.engine, expire_on_commit=False) as session:
message = session.scalar(select(Message).where(Message.id == message_id))
if message is None:
raise MessageNotExistsError("Message not exists")

View File

@ -14,14 +14,14 @@ from core.app.entities.queue_entities import (
class MessageBasedAppQueueManager(AppQueueManager):
def __init__(
self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str
) -> None:
):
super().__init__(task_id, user_id, invoke_from)
self._conversation_id = str(conversation_id)
self._app_mode = app_mode
self._message_id = str(message_id)
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue
:param event:

View File

View File

@ -0,0 +1,95 @@
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ErrorStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
)
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
return dict(blocking_response.model_dump())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
:return:
"""
return cls.convert_blocking_full_response(blocking_response)
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(WorkflowAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(cast(dict, data))
else:
response_chunk.update(sub_stream_response.model_dump())
yield response_chunk
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(WorkflowAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(cast(dict, data))
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
else:
response_chunk.update(sub_stream_response.model_dump())
yield response_chunk

View File

@ -0,0 +1,66 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
from models.dataset import Pipeline
from models.model import AppMode
from models.workflow import Workflow
class PipelineConfig(WorkflowUIBasedAppConfig):
"""
Pipeline Config Entity.
"""
rag_pipeline_variables: list[RagPipelineVariableEntity] = []
pass
class PipelineConfigManager(BaseAppConfigManager):
@classmethod
def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow, start_node_id: str) -> PipelineConfig:
pipeline_config = PipelineConfig(
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
app_mode=AppMode.RAG_PIPELINE,
workflow_id=workflow.id,
rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(
workflow=workflow, start_node_id=start_node_id
),
)
return pipeline_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
"""
Validate for pipeline config
:param tenant_id: tenant id
:param config: app model config args
:param only_structure_validate: only validate the structure of the config
"""
related_config_keys = []
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
related_config_keys.extend(current_related_config_keys)
# text_to_speech
config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config

View File

@ -0,0 +1,856 @@
import contextvars
import datetime
import json
import logging
import secrets
import threading
import time
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, cast, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
import contexts
from configs import dify_config
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.datasource.entities.datasource_entities import (
DatasourceProviderType,
OnlineDriveBrowseFilesRequest,
)
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.repositories.factory import DifyCoreRepositoryFactory
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.flask_utils import preserve_flask_contexts
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.datasource_provider_service import DatasourceProviderService
from services.feature_service import FeatureService
from services.file_service import FileService
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
logger = logging.getLogger(__name__)
class PipelineGenerator(BaseAppGenerator):
@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: str | None,
is_retry: bool = False,
) -> Generator[Mapping | str, None, None]: ...
@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: str | None,
is_retry: bool = False,
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_thread_pool_id: str | None,
is_retry: bool = False,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: str | None = None,
is_retry: bool = False,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
# Add null check for dataset
with Session(db.engine, expire_on_commit=False) as session:
dataset = pipeline.retrieve_dataset(session)
if not dataset:
raise ValueError("Pipeline dataset is required")
inputs: Mapping[str, Any] = args["inputs"]
start_node_id: str = args["start_node_id"]
datasource_type: str = args["datasource_type"]
datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
)
batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
)
documents: list[Document] = []
if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"):
from services.dataset_service import DocumentService
for datasource_info in datasource_info_list:
position = DocumentService.get_documents_position(dataset.id)
document = self._build_document(
tenant_id=pipeline.tenant_id,
dataset_id=dataset.id,
built_in_field_enabled=dataset.built_in_field_enabled,
datasource_type=datasource_type,
datasource_info=datasource_info,
created_from="rag-pipeline",
position=position,
account=user,
batch=batch,
document_form=dataset.chunk_structure,
)
db.session.add(document)
documents.append(document)
db.session.commit()
# run in child thread
rag_pipeline_invoke_entities = []
for i, datasource_info in enumerate(datasource_info_list):
workflow_run_id = str(uuid.uuid4())
document_id = args.get("original_document_id") or None
if invoke_from == InvokeFrom.PUBLISHED and not is_retry:
document_id = document_id or documents[i].id
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document_id,
datasource_type=datasource_type,
datasource_info=json.dumps(datasource_info),
datasource_node_id=start_node_id,
input_data=inputs,
pipeline_id=pipeline.id,
created_by=user.id,
)
db.session.add(document_pipeline_execution_log)
db.session.commit()
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=datasource_type,
datasource_info=datasource_info,
dataset_id=dataset.id,
original_document_id=args.get("original_document_id"),
start_node_id=start_node_id,
batch=batch,
document_id=document_id,
inputs=self._prepare_user_inputs(
user_inputs=inputs,
variables=pipeline_config.rag_pipeline_variables,
tenant_id=pipeline.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
),
files=[],
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
call_depth=call_depth,
workflow_execution_id=workflow_run_id,
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
if invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
if invoke_from == InvokeFrom.DEBUGGER or is_retry:
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
context=contextvars.copy_context(),
pipeline=pipeline,
workflow_id=workflow.id,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
else:
rag_pipeline_invoke_entities.append(
RagPipelineInvokeEntity(
pipeline_id=pipeline.id,
user_id=user.id,
tenant_id=pipeline.tenant_id,
workflow_id=workflow.id,
streaming=streaming,
workflow_execution_id=workflow_run_id,
workflow_thread_pool_id=workflow_thread_pool_id,
application_generate_entity=application_generate_entity.model_dump(),
)
)
if rag_pipeline_invoke_entities:
# store the rag_pipeline_invoke_entities to object storage
text = [item.model_dump() for item in rag_pipeline_invoke_entities]
name = "rag_pipeline_invoke_entities.json"
# Convert list to proper JSON string
json_text = json.dumps(text)
upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
features = FeatureService.get_features(dataset.tenant_id)
if features.billing.subscription.plan == "sandbox":
tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"
if redis_client.get(tenant_pipeline_task_key):
# Add to waiting queue using List operations (lpush)
redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
else:
# Set flag and execute task
redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60)
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file.id,
tenant_id=dataset.tenant_id,
)
else:
priority_rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file.id,
tenant_id=dataset.tenant_id,
)
# return batch, dataset, documents
return {
"batch": batch,
"dataset": PipelineDataset(
id=dataset.id,
name=dataset.name,
description=dataset.description,
chunk_structure=dataset.chunk_structure,
).model_dump(),
"documents": [
PipelineDocument(
id=document.id,
position=document.position,
data_source_type=document.data_source_type,
data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
name=document.name,
indexing_status=document.indexing_status,
error=document.error,
enabled=document.enabled,
).model_dump()
for document in documents
],
}
def _generate(
self,
*,
flask_app: Flask,
context: contextvars.Context,
pipeline: Pipeline,
workflow_id: str,
user: Union[Account, EndUser],
application_generate_entity: RagPipelineGenerateEntity,
invoke_from: InvokeFrom,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
workflow_thread_pool_id: str | None = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
:param pipeline: Pipeline
:param workflow: Workflow
:param user: account or end user
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
with preserve_flask_contexts(flask_app, context_vars=context):
# init queue manager
workflow = db.session.query(Workflow).where(Workflow.id == workflow_id).first()
if not workflow:
raise ValueError(f"Workflow not found: {workflow_id}")
queue_manager = PipelineQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
app_mode=AppMode.RAG_PIPELINE,
)
context = contextvars.copy_context()
# new thread
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"context": context,
"queue_manager": queue_manager,
"application_generate_entity": application_generate_entity,
"workflow_thread_pool_id": workflow_thread_pool_id,
"variable_loader": variable_loader,
},
)
worker_thread.start()
draft_var_saver_factory = self._get_draft_var_saver_factory(
invoke_from,
user,
)
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=streaming,
draft_var_saver_factory=draft_var_saver_factory,
)
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def single_iteration_generate(
self,
pipeline: Pipeline,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping[str, Any],
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param node_id: the node id
:param user: account or end user
:param args: request args
:param streaming: is streamed
"""
if not node_id:
raise ValueError("node_id is required")
if args.get("inputs") is None:
raise ValueError("inputs is required")
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
)
with Session(db.engine) as session:
dataset = pipeline.retrieve_dataset(session)
if not dataset:
raise ValueError("Pipeline dataset is required")
# init application generate entity - use RagPipelineGenerateEntity instead
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=args.get("datasource_type", ""),
datasource_info=args.get("datasource_info", {}),
dataset_id=dataset.id,
batch=args.get("batch", ""),
document_id=args.get("document_id"),
inputs={},
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
workflow_execution_id=str(uuid.uuid4()),
single_iteration_run=RagPipelineGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
pipeline=pipeline,
workflow_id=workflow.id,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
context=contextvars.copy_context(),
)
def single_loop_generate(
self,
pipeline: Pipeline,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping[str, Any],
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param node_id: the node id
:param user: account or end user
:param args: request args
:param streaming: is streamed
"""
if not node_id:
raise ValueError("node_id is required")
if args.get("inputs") is None:
raise ValueError("inputs is required")
with Session(db.engine) as session:
dataset = pipeline.retrieve_dataset(session)
if not dataset:
raise ValueError("Pipeline dataset is required")
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
)
# init application generate entity
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=args.get("datasource_type", ""),
datasource_info=args.get("datasource_info", {}),
batch=args.get("batch", ""),
document_id=args.get("document_id"),
dataset_id=dataset.id,
inputs={},
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
workflow_execution_id=str(uuid.uuid4()),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
pipeline=pipeline,
workflow_id=workflow.id,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
context=contextvars.copy_context(),
)
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: RagPipelineGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_thread_pool_id: str | None = None,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
with preserve_flask_contexts(flask_app, context_vars=context):
try:
with Session(db.engine, expire_on_commit=False) as session:
workflow = session.scalar(
select(Workflow).where(
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
Workflow.app_id == application_generate_entity.app_config.app_id,
Workflow.id == application_generate_entity.app_config.workflow_id,
)
)
if workflow is None:
raise ValueError("Workflow not found")
# Determine system_user_id based on invocation source
is_external_api_call = application_generate_entity.invoke_from in {
InvokeFrom.WEB_APP,
InvokeFrom.SERVICE_API,
}
if is_external_api_call:
# For external API calls, use end user's session ID
end_user = session.scalar(
select(EndUser).where(EndUser.id == application_generate_entity.user_id)
)
system_user_id = end_user.session_id if end_user else ""
else:
# For internal calls, use the original user ID
system_user_id = application_generate_entity.user_id
# workflow app
runner = PipelineRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
)
runner.run()
except GenerateTaskStoppedError:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.close()
def _handle_response(
self,
application_generate_entity: RagPipelineGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
Handle response.
:param application_generate_entity: application generate entity
:param workflow: workflow
:param queue_manager: queue manager
:param user: account or end user
:param stream: is stream
:param workflow_node_execution_repository: optional repository for workflow node execution
:return:
"""
# init generate task pipeline
generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
stream=stream,
workflow_node_execution_repository=workflow_node_execution_repository,
workflow_execution_repository=workflow_execution_repository,
draft_var_saver_factory=draft_var_saver_factory,
)
try:
return generate_task_pipeline.process()
except ValueError as e:
if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError()
else:
logger.exception(
"Fails to process generate task pipeline, task_id: %r",
application_generate_entity.task_id,
)
raise e
def _build_document(
self,
tenant_id: str,
dataset_id: str,
built_in_field_enabled: bool,
datasource_type: str,
datasource_info: Mapping[str, Any],
created_from: str,
position: int,
account: Union[Account, EndUser],
batch: str,
document_form: str,
):
if datasource_type == "local_file":
name = datasource_info.get("name", "untitled")
elif datasource_type == "online_document":
name = datasource_info.get("page", {}).get("page_name", "untitled")
elif datasource_type == "website_crawl":
name = datasource_info.get("title", "untitled")
elif datasource_type == "online_drive":
name = datasource_info.get("name", "untitled")
else:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
document = Document(
tenant_id=tenant_id,
dataset_id=dataset_id,
position=position,
data_source_type=datasource_type,
data_source_info=json.dumps(datasource_info),
batch=batch,
name=name,
created_from=created_from,
created_by=account.id,
doc_form=document_form,
)
doc_metadata = {}
if built_in_field_enabled:
doc_metadata = {
BuiltInField.document_name: name,
BuiltInField.uploader: account.name,
BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
BuiltInField.source: datasource_type,
}
if doc_metadata:
document.doc_metadata = doc_metadata
return document
def _format_datasource_info_list(
self,
datasource_type: str,
datasource_info_list: list[Mapping[str, Any]],
pipeline: Pipeline,
workflow: Workflow,
start_node_id: str,
user: Union[Account, EndUser],
) -> list[Mapping[str, Any]]:
"""
Format datasource info list.
"""
if datasource_type == "online_drive":
all_files: list[Mapping[str, Any]] = []
datasource_node_data = None
datasource_nodes = workflow.graph_dict.get("nodes", [])
for datasource_node in datasource_nodes:
if datasource_node.get("id") == start_node_id:
datasource_node_data = datasource_node.get("data", {})
break
if not datasource_node_data:
raise ValueError("Datasource node data not found")
from core.datasource.datasource_manager import DatasourceManager
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
datasource_name=datasource_node_data.get("datasource_name"),
tenant_id=pipeline.tenant_id,
datasource_type=DatasourceProviderType(datasource_type),
)
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_datasource_credentials(
tenant_id=pipeline.tenant_id,
provider=datasource_node_data.get("provider_name"),
plugin_id=datasource_node_data.get("plugin_id"),
credential_id=datasource_node_data.get("credential_id"),
)
if credentials:
datasource_runtime.runtime.credentials = credentials
datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
for datasource_info in datasource_info_list:
if datasource_info.get("id") and datasource_info.get("type") == "folder":
# get all files in the folder
self._get_files_in_folder(
datasource_runtime,
datasource_info.get("id", ""),
datasource_info.get("bucket", None),
user.id,
all_files,
datasource_info,
None,
)
else:
all_files.append(
{
"id": datasource_info.get("id", ""),
"name": datasource_info.get("name", "untitled"),
"bucket": datasource_info.get("bucket", None),
}
)
return all_files
else:
return datasource_info_list
def _get_files_in_folder(
self,
datasource_runtime: OnlineDriveDatasourcePlugin,
prefix: str,
bucket: str | None,
user_id: str,
all_files: list,
datasource_info: Mapping[str, Any],
next_page_parameters: dict | None = None,
):
"""
Get files in a folder.
"""
result_generator = datasource_runtime.online_drive_browse_files(
user_id=user_id,
request=OnlineDriveBrowseFilesRequest(
bucket=bucket,
prefix=prefix,
max_keys=20,
next_page_parameters=next_page_parameters,
),
provider_type=datasource_runtime.datasource_provider_type(),
)
is_truncated = False
for result in result_generator:
for files in result.result:
for file in files.files:
if file.type == "folder":
self._get_files_in_folder(
datasource_runtime,
file.id,
bucket,
user_id,
all_files,
datasource_info,
None,
)
else:
all_files.append(
{
"id": file.id,
"name": file.name,
"bucket": bucket,
}
)
is_truncated = files.is_truncated
next_page_parameters = files.next_page_parameters
if is_truncated:
self._get_files_in_folder(
datasource_runtime, prefix, bucket, user_id, all_files, datasource_info, next_page_parameters
)

View File

@ -0,0 +1,45 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueErrorEvent,
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
)
class PipelineQueueManager(AppQueueManager):
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
super().__init__(task_id, user_id, invoke_from)
self._app_mode = app_mode
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
Publish event to queue
:param event:
:param pub_from:
:return:
"""
message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)
self._q.put(message)
if isinstance(
event,
QueueStopEvent
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent
| QueueWorkflowPartialSuccessEvent,
):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedError()

View File

@ -0,0 +1,263 @@
import logging
import time
from typing import cast
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import (
InvokeFrom,
RagPipelineGenerateEntity,
)
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.dataset import Document, Pipeline
from models.enums import UserFrom
from models.model import EndUser
from models.workflow import Workflow
logger = logging.getLogger(__name__)
class PipelineRunner(WorkflowBasedAppRunner):
"""
Pipeline Application Runner
"""
def __init__(
self,
application_generate_entity: RagPipelineGenerateEntity,
queue_manager: AppQueueManager,
variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
workflow_thread_pool_id: str | None = None,
) -> None:
"""
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id
"""
super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
app_id=application_generate_entity.app_config.app_id,
)
self.application_generate_entity = application_generate_entity
self.workflow_thread_pool_id = workflow_thread_pool_id
self._workflow = workflow
self._sys_user_id = system_user_id
def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id
def run(self) -> None:
"""
Run application
"""
app_config = self.application_generate_entity.app_config
app_config = cast(PipelineConfig, app_config)
user_id = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
pipeline = db.session.query(Pipeline).where(Pipeline.id == app_config.app_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
db.session.close()
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=workflow,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
)
else:
inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files
# Create a variable pool.
system_inputs = SystemVariable(
files=files,
user_id=user_id,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
document_id=self.application_generate_entity.document_id,
original_document_id=self.application_generate_entity.original_document_id,
batch=self.application_generate_entity.batch,
dataset_id=self.application_generate_entity.dataset_id,
datasource_type=self.application_generate_entity.datasource_type,
datasource_info=self.application_generate_entity.datasource_info,
invoke_from=self.application_generate_entity.invoke_from.value,
)
rag_pipeline_variables = []
if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable(**v)
if (
rag_pipeline_variable.belong_to_node_id
in (self.application_generate_entity.start_node_id, "shared")
) and rag_pipeline_variable.variable in inputs:
rag_pipeline_variables.append(
RAGPipelineVariableInput(
variable=rag_pipeline_variable,
value=inputs[rag_pipeline_variable.variable],
)
)
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
rag_pipeline_variables=rag_pipeline_variables,
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init graph
graph = self._init_rag_pipeline_graph(
graph_runtime_state=graph_runtime_state,
start_node_id=self.application_generate_entity.start_node_id,
workflow=workflow,
)
# RUN WORKFLOW
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
graph_runtime_state=graph_runtime_state,
variable_pool=variable_pool,
)
generator = workflow_entry.run()
for event in generator:
self._update_document_status(
event, self.application_generate_entity.document_id, self.application_generate_entity.dataset_id
)
self._handle_event(workflow_entry, event)
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id)
.first()
)
# return workflow
return workflow
def _init_rag_pipeline_graph(
self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: str | None = None
) -> Graph:
"""
Init pipeline graph
"""
graph_config = workflow.graph_dict
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# nodes = graph_config.get("nodes", [])
# edges = graph_config.get("edges", [])
# real_run_nodes = []
# real_edges = []
# exclude_node_ids = []
# for node in nodes:
# node_id = node.get("id")
# node_type = node.get("data", {}).get("type", "")
# if node_type == "datasource":
# if start_node_id != node_id:
# exclude_node_ids.append(node_id)
# continue
# real_run_nodes.append(node)
# for edge in edges:
# if edge.get("source") in exclude_node_ids:
# continue
# real_edges.append(edge)
# graph_config = dict(graph_config)
# graph_config["nodes"] = real_run_nodes
# graph_config["edges"] = real_edges
# init graph
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=graph_config,
user_id=self.application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
call_depth=0,
)
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=start_node_id)
if not graph:
raise ValueError("graph not found in workflow")
return graph
def _update_document_status(self, event: GraphEngineEvent, document_id: str | None, dataset_id: str | None) -> None:
"""
Update document status
"""
if isinstance(event, GraphRunFailedEvent):
if document_id and dataset_id:
document = (
db.session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)
if document:
document.indexing_status = "error"
document.error = event.error or "Unknown error"
db.session.add(document)
db.session.commit()

View File

@ -35,7 +35,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False):
"""
Validate for workflow app model config

View File

@ -53,7 +53,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: Optional[str],
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
root_node_id: Optional[str] = None,
) -> Generator[Mapping | str, None, None]: ...
@ -69,7 +68,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: Optional[str],
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
root_node_id: Optional[str] = None,
) -> Mapping[str, Any]: ...
@ -85,7 +83,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_thread_pool_id: Optional[str],
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
root_node_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
@ -100,7 +97,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
root_node_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
@ -200,7 +196,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
root_node_id=root_node_id,
)
@ -215,7 +210,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
root_node_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
@ -230,7 +224,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
# init queue manager
queue_manager = WorkflowAppQueueManager(
@ -253,7 +246,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"context": context,
"workflow_thread_pool_id": workflow_thread_pool_id,
"variable_loader": variable_loader,
"root_node_id": root_node_id,
},
@ -261,9 +253,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
worker_thread.start()
draft_var_saver_factory = self._get_draft_var_saver_factory(
invoke_from,
)
draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user)
# return response or stream generator
response = self._handle_response(
@ -451,7 +441,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
root_node_id: Optional[str] = None,
) -> None:
"""
@ -462,7 +451,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
with preserve_flask_contexts(flask_app, context_vars=context):
with Session(db.engine, expire_on_commit=False) as session:
workflow = session.scalar(
@ -492,7 +480,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,

View File

@ -14,12 +14,12 @@ from core.app.entities.queue_entities import (
class WorkflowAppQueueManager(AppQueueManager):
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str):
super().__init__(task_id, user_id, invoke_from)
self._app_mode = app_mode
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue
:param event:

View File

@ -1,7 +1,7 @@
import logging
from typing import Optional, cast
import time
from typing import cast
from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@ -9,13 +9,14 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_redis import redis_client
from models.enums import UserFrom
from models.workflow import Workflow, WorkflowType
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@ -31,47 +32,33 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
workflow: Workflow,
system_user_id: str,
root_node_id: Optional[str] = None,
) -> None:
):
super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
app_id=application_generate_entity.app_config.app_id,
)
self.application_generate_entity = application_generate_entity
self.workflow_thread_pool_id = workflow_thread_pool_id
self._workflow = workflow
self._sys_user_id = system_user_id
self._root_node_id = root_node_id
def run(self) -> None:
def run(self):
"""
Run application
"""
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
# if only single iteration or single loop run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
)
else:
inputs = self.application_generate_entity.inputs
@ -94,15 +81,28 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
conversation_variables=[],
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init graph
graph = self._init_graph(graph_config=self._workflow.graph_dict, root_node_id=self._root_node_id)
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
root_node_id=self._root_node_id
)
# RUN WORKFLOW
# Create Redis command channel for this workflow execution
task_id = self.application_generate_entity.task_id
channel_key = f"workflow:{task_id}:commands"
command_channel = RedisChannel(redis_client, channel_key)
workflow_entry = WorkflowEntry(
tenant_id=self._workflow.tenant_id,
app_id=self._workflow.app_id,
workflow_id=self._workflow.id,
workflow_type=WorkflowType.value_of(self._workflow.type),
graph=graph,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
@ -114,10 +114,11 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
thread_pool_id=self.workflow_thread_pool_id,
graph_runtime_state=graph_runtime_state,
command_channel=command_channel,
)
generator = workflow_entry.run(callbacks=workflow_callbacks)
generator = workflow_entry.run()
for event in generator:
self._handle_event(workflow_entry, event)

View File

@ -17,16 +17,16 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
return dict(blocking_response.to_dict())
return blocking_response.model_dump()
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -51,7 +51,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, object] = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@ -80,7 +80,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
yield "ping"
continue
response_chunk = {
response_chunk: dict[str, object] = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -2,7 +2,7 @@ import logging
import time
from collections.abc import Callable, Generator
from contextlib import contextmanager
from typing import Any, Optional, Union
from typing import Union
from sqlalchemy.orm import Session
@ -14,6 +14,7 @@ from core.app.entities.app_invoke_entities import (
WorkflowAppGenerateEntity,
)
from core.app.entities.queue_entities import (
AppQueueEvent,
MessageQueueMessage,
QueueAgentLogEvent,
QueueErrorEvent,
@ -25,14 +26,9 @@ from core.app.entities.queue_entities import (
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueStopEvent,
QueueTextChunkEvent,
@ -57,8 +53,8 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities import GraphRuntimeState, WorkflowExecution
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@ -92,7 +88,7 @@ class WorkflowAppGenerateTaskPipeline:
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
) -> None:
):
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
@ -137,7 +133,7 @@ class WorkflowAppGenerateTaskPipeline:
self._application_generate_entity = application_generate_entity
self._workflow_features_dict = workflow.features_dict
self._workflow_run_id = ""
self._invoke_from = queue_manager._invoke_from
self._invoke_from = queue_manager.invoke_from
self._draft_var_saver_factory = draft_var_saver_factory
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@ -146,7 +142,7 @@ class WorkflowAppGenerateTaskPipeline:
:return:
"""
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._base_task_pipeline._stream:
if self._base_task_pipeline.stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
@ -206,7 +202,7 @@ class WorkflowAppGenerateTaskPipeline:
return None
def _wrapper_process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
self, trace_manager: TraceQueueManager | None = None
) -> Generator[StreamResponse, None, None]:
tts_publisher = None
task_id = self._application_generate_entity.task_id
@ -263,12 +259,12 @@ class WorkflowAppGenerateTaskPipeline:
session.rollback()
raise
def _ensure_workflow_initialized(self) -> None:
def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state."""
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState:
def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState:
"""Fluent validation for graph runtime state."""
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
@ -276,12 +272,12 @@ class WorkflowAppGenerateTaskPipeline:
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
"""Handle ping events."""
yield self._base_task_pipeline._ping_stream_response()
yield self._base_task_pipeline.ping_stream_response()
def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
"""Handle error events."""
err = self._base_task_pipeline._handle_error(event=event)
yield self._base_task_pipeline._error_to_stream_response(err)
err = self._base_task_pipeline.handle_error(event=event)
yield self._base_task_pipeline.error_to_stream_response(err)
def _handle_workflow_started_event(
self, event: QueueWorkflowStartedEvent, **kwargs
@ -300,16 +296,15 @@ class WorkflowAppGenerateTaskPipeline:
"""Handle node retry events."""
self._ensure_workflow_initialized()
with self._database_session() as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
@ -350,9 +345,7 @@ class WorkflowAppGenerateTaskPipeline:
def _handle_node_failed_events(
self,
event: Union[
QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
],
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle various node failure events."""
@ -371,32 +364,6 @@ class WorkflowAppGenerateTaskPipeline:
if node_failed_response:
yield node_failed_response
def _handle_parallel_branch_started_event(
self, event: QueueParallelBranchRunStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch started events."""
self._ensure_workflow_initialized()
parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_start_resp
def _handle_parallel_branch_finished_events(
self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch finished events."""
self._ensure_workflow_initialized()
parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_finish_resp
def _handle_iteration_start_event(
self, event: QueueIterationStartEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
@ -475,8 +442,8 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: QueueWorkflowSucceededEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow succeeded events."""
@ -509,8 +476,8 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: QueueWorkflowPartialSuccessEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow partial success events."""
@ -544,8 +511,8 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow failed and stop events."""
@ -582,8 +549,8 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: QueueTextChunkEvent,
*,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle text chunk events."""
@ -618,8 +585,6 @@ class WorkflowAppGenerateTaskPipeline:
QueueNodeRetryEvent: self._handle_node_retry_event,
QueueNodeStartedEvent: self._handle_node_started_event,
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
# Parallel branch events
QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
# Iteration events
QueueIterationStartEvent: self._handle_iteration_start_event,
QueueIterationNextEvent: self._handle_iteration_next_event,
@ -634,12 +599,12 @@ class WorkflowAppGenerateTaskPipeline:
def _dispatch_event(
self,
event: Any,
event: AppQueueEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
graph_runtime_state: GraphRuntimeState | None = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
) -> Generator[StreamResponse, None, None]:
"""Dispatch events using elegant pattern matching."""
handlers = self._get_event_handlers()
@ -661,8 +626,6 @@ class WorkflowAppGenerateTaskPipeline:
event,
(
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeExceptionEvent,
),
):
@ -675,17 +638,6 @@ class WorkflowAppGenerateTaskPipeline:
)
return
# Handle parallel branch finished events with isinstance check
if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
yield from self._handle_parallel_branch_finished_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
return
# Handle workflow failed and stop events with isinstance check
if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
yield from self._handle_workflow_failed_and_stop_events(
@ -702,8 +654,8 @@ class WorkflowAppGenerateTaskPipeline:
def _process_stream_response(
self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
) -> Generator[StreamResponse, None, None]:
"""
Process stream response using elegant Fluent Python patterns.
@ -745,7 +697,7 @@ class WorkflowAppGenerateTaskPipeline:
if tts_publisher:
tts_publisher.publish(None)
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None:
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution):
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
@ -770,7 +722,7 @@ class WorkflowAppGenerateTaskPipeline:
session.commit()
def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: Optional[list[str]] = None
self, text: str, from_variable_selector: list[str] | None = None
) -> TextChunkStreamResponse:
"""
Handle completed event.

View File

@ -1,7 +1,9 @@
import time
from collections.abc import Mapping
from typing import Any, Optional, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
@ -13,14 +15,9 @@ from core.app.entities.queue_entities import (
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueueRetrieverResourcesEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
@ -28,42 +25,39 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.event import (
AgentLogEvent,
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
LoopRunFailedEvent,
LoopRunNextEvent,
LoopRunStartedEvent,
LoopRunSucceededEvent,
NodeInIterationFailedEvent,
NodeInLoopFailedEvent,
NodeRunAgentLogEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
NodeRunIterationSucceededEvent,
NodeRunLoopFailedEvent,
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_events.graph import GraphRunAbortedEvent
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry
from models.enums import UserFrom
from models.workflow import Workflow
@ -74,12 +68,20 @@ class WorkflowBasedAppRunner:
queue_manager: AppQueueManager,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
app_id: str,
) -> None:
):
self._queue_manager = queue_manager
self._variable_loader = variable_loader
self._app_id = app_id
def _init_graph(self, graph_config: Mapping[str, Any], root_node_id: Optional[str] = None) -> Graph:
def _init_graph(
self,
graph_config: Mapping[str, Any],
graph_runtime_state: GraphRuntimeState,
workflow_id: str = "",
tenant_id: str = "",
user_id: str = "",
root_node_id: Optional[str] = None
) -> Graph:
"""
Init graph
"""
@ -91,22 +93,109 @@ class WorkflowBasedAppRunner:
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=tenant_id or "",
app_id=self._app_id,
workflow_id=workflow_id,
graph_config=graph_config,
user_id=user_id,
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
call_depth=0,
)
# Use the provided graph_runtime_state for consistent state management
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
# init graph
graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id)
if not graph:
raise ValueError("graph not found in workflow")
return graph
def _get_graph_and_variable_pool_of_single_iteration(
def _prepare_single_node_execution(
self,
workflow: Workflow,
single_iteration_run: Any | None = None,
single_loop_run: Any | None = None,
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
"""
Prepare graph, variable pool, and runtime state for single node execution
(either single iteration or single loop).
Args:
workflow: The workflow instance
single_iteration_run: SingleIterationRunEntity if running single iteration, None otherwise
single_loop_run: SingleLoopRunEntity if running single loop, None otherwise
Returns:
A tuple containing (graph, variable_pool, graph_runtime_state)
Raises:
ValueError: If neither single_iteration_run nor single_loop_run is specified
"""
# Create initial runtime state with variable pool containing environment variables
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
),
start_at=time.time(),
)
# Determine which type of single node execution and get graph/variable_pool
if single_iteration_run:
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=single_iteration_run.node_id,
user_inputs=dict(single_iteration_run.inputs),
graph_runtime_state=graph_runtime_state,
)
elif single_loop_run:
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
node_id=single_loop_run.node_id,
user_inputs=dict(single_loop_run.inputs),
graph_runtime_state=graph_runtime_state,
)
else:
raise ValueError("Neither single_iteration_run nor single_loop_run is specified")
# Return the graph, variable_pool, and the same graph_runtime_state used during graph creation
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
return graph, variable_pool, graph_runtime_state
def _get_graph_and_variable_pool_for_single_node_run(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single iteration
Get graph and variable pool for single node execution (iteration or loop).
Args:
workflow: The workflow instance
node_id: The node ID to execute
user_inputs: User inputs for the node
graph_runtime_state: The graph runtime state
node_type_filter_key: The key to filter nodes ('iteration_id' or 'loop_id')
node_type_label: Label for error messages ('iteration' or 'loop')
Returns:
A tuple containing (graph, variable_pool)
"""
# fetch workflow graph
graph_config = workflow.graph_dict
@ -124,18 +213,22 @@ class WorkflowBasedAppRunner:
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# filter nodes only in iteration
# filter nodes only in the specified node type (iteration or loop)
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
if node.get("id") == node_id
or node.get("data", {}).get(node_type_filter_key, "") == node_id
or (start_node_id and node.get("id") == start_node_id)
]
graph_config["nodes"] = node_configs
node_ids = [node.get("id") for node in node_configs]
# filter edges only in iteration
# filter edges only in the specified node type
edge_configs = [
edge
for edge in graph_config.get("edges", [])
@ -145,37 +238,50 @@ class WorkflowBasedAppRunner:
graph_config["edges"] = edge_configs
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
call_depth=0,
)
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
# init graph
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
if not graph:
raise ValueError("graph not found in workflow")
# fetch node config from node id
iteration_node_config = None
target_node_config = None
for node in node_configs:
if node.get("id") == node_id:
iteration_node_config = node
target_node_config = node
break
if not iteration_node_config:
raise ValueError("iteration node id not found in workflow graph")
if not target_node_config:
raise ValueError(f"{node_type_label} node id not found in workflow graph")
# Get node class
node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
node_version = iteration_node_config.get("data", {}).get("version", "1")
node_type = NodeType(target_node_config.get("data", {}).get("type"))
node_version = target_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init variable pool
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)
# Use the variable pool from graph_runtime_state instead of creating a new one
variable_pool = graph_runtime_state.variable_pool
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=iteration_node_config
graph_config=workflow.graph_dict, config=target_node_config
)
except NotImplementedError:
variable_mapping = {}
@ -196,103 +302,45 @@ class WorkflowBasedAppRunner:
return graph, variable_pool
def _get_graph_and_variable_pool_of_single_iteration(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single iteration
"""
return self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=node_id,
user_inputs=user_inputs,
graph_runtime_state=graph_runtime_state,
node_type_filter_key="iteration_id",
node_type_label="iteration",
)
def _get_graph_and_variable_pool_of_single_loop(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single loop
"""
# fetch workflow graph
graph_config = workflow.graph_dict
if not graph_config:
raise ValueError("workflow graph not found")
graph_config = cast(dict[str, Any], graph_config)
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# filter nodes only in loop
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("loop_id", "") == node_id
]
graph_config["nodes"] = node_configs
node_ids = [node.get("id") for node in node_configs]
# filter edges only in loop
edge_configs = [
edge
for edge in graph_config.get("edges", [])
if (edge.get("source") is None or edge.get("source") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
]
graph_config["edges"] = edge_configs
# init graph
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
if not graph:
raise ValueError("graph not found in workflow")
# fetch node config from node id
loop_node_config = None
for node in node_configs:
if node.get("id") == node_id:
loop_node_config = node
break
if not loop_node_config:
raise ValueError("loop node id not found in workflow graph")
# Get node class
node_type = NodeType(loop_node_config.get("data", {}).get("type"))
node_version = loop_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init variable pool
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=loop_node_config
)
except NotImplementedError:
variable_mapping = {}
load_into_variable_pool(
self._variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
return self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=node_id,
user_inputs=user_inputs,
graph_runtime_state=graph_runtime_state,
node_type_filter_key="loop_id",
node_type_label="loop",
)
WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
)
return graph, variable_pool
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None:
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
"""
Handle event
:param workflow_entry: workflow entry
@ -310,39 +358,32 @@ class WorkflowBasedAppRunner:
)
elif isinstance(event, GraphRunFailedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, GraphRunAbortedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
elif isinstance(event, NodeRunRetryEvent):
node_run_result = event.route_node_state.node_run_result
inputs: Mapping[str, Any] | None = {}
process_data: Mapping[str, Any] | None = {}
outputs: Mapping[str, Any] | None = {}
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = {}
if node_run_result:
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
node_run_result = event.node_run_result
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
self._publish_event(
QueueNodeRetryEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_title=event.node_title,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
parallel_mode_run_id=event.parallel_mode_run_id,
inputs=inputs,
process_data=process_data,
outputs=outputs,
error=event.error,
execution_metadata=execution_metadata,
retry_index=event.retry_index,
provider_type=event.provider_type,
provider_id=event.provider_id,
)
)
elif isinstance(event, NodeRunStartedEvent):
@ -350,44 +391,29 @@ class WorkflowBasedAppRunner:
QueueNodeStartedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_title=event.node_title,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
node_run_index=event.route_node_state.index,
start_at=event.start_at,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
parallel_mode_run_id=event.parallel_mode_run_id,
agent_strategy=event.agent_strategy,
provider_type=event.provider_type,
provider_id=event.provider_id,
)
)
elif isinstance(event, NodeRunSucceededEvent):
node_run_result = event.route_node_state.node_run_result
if node_run_result:
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
else:
inputs = {}
process_data = {}
outputs = {}
execution_metadata = {}
node_run_result = event.node_run_result
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
self._publish_event(
QueueNodeSucceededEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
start_at=event.start_at,
inputs=inputs,
process_data=process_data,
outputs=outputs,
@ -396,34 +422,18 @@ class WorkflowBasedAppRunner:
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, NodeRunFailedEvent):
self._publish_event(
QueueNodeFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
start_at=event.start_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=event.node_run_result.outputs,
error=event.node_run_result.error or "Unknown error",
execution_metadata=event.node_run_result.metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
@ -434,93 +444,21 @@ class WorkflowBasedAppRunner:
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
start_at=event.start_at,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=event.node_run_result.outputs,
error=event.node_run_result.error or "Unknown error",
execution_metadata=event.node_run_result.metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
QueueNodeInIterationFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
error=event.error,
)
)
elif isinstance(event, NodeInLoopFailedEvent):
self._publish_event(
QueueNodeInLoopFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs or {}
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_loop_id=event.in_loop_id,
error=event.error,
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
self._publish_event(
QueueTextChunkEvent(
text=event.chunk_content,
from_variable_selector=event.from_variable_selector,
text=event.chunk,
from_variable_selector=list(event.selector),
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
@ -533,10 +471,10 @@ class WorkflowBasedAppRunner:
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, AgentLogEvent):
elif isinstance(event, NodeRunAgentLogEvent):
self._publish_event(
QueueAgentLogEvent(
id=event.id,
id=event.message_id,
label=event.label,
node_execution_id=event.node_execution_id,
parent_id=event.parent_id,
@ -547,51 +485,13 @@ class WorkflowBasedAppRunner:
node_id=event.node_id,
)
)
elif isinstance(event, ParallelBranchRunStartedEvent):
self._publish_event(
QueueParallelBranchRunStartedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, ParallelBranchRunSucceededEvent):
self._publish_event(
QueueParallelBranchRunSucceededEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, ParallelBranchRunFailedEvent):
self._publish_event(
QueueParallelBranchRunFailedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
error=event.error,
)
)
elif isinstance(event, IterationRunStartedEvent):
elif isinstance(event, NodeRunIterationStartedEvent):
self._publish_event(
QueueIterationStartEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
@ -599,55 +499,41 @@ class WorkflowBasedAppRunner:
metadata=event.metadata,
)
)
elif isinstance(event, IterationRunNextEvent):
elif isinstance(event, NodeRunIterationNextEvent):
self._publish_event(
QueueIterationNextEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_iteration_output,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
)
)
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
elif isinstance(event, (NodeRunIterationSucceededEvent | NodeRunIterationFailedEvent)):
self._publish_event(
QueueIterationCompletedEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
error=event.error if isinstance(event, NodeRunIterationFailedEvent) else None,
)
)
elif isinstance(event, LoopRunStartedEvent):
elif isinstance(event, NodeRunLoopStartedEvent):
self._publish_event(
QueueLoopStartEvent(
node_execution_id=event.loop_id,
node_id=event.loop_node_id,
node_type=event.loop_node_type,
node_data=event.loop_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
@ -655,44 +541,34 @@ class WorkflowBasedAppRunner:
metadata=event.metadata,
)
)
elif isinstance(event, LoopRunNextEvent):
elif isinstance(event, NodeRunLoopNextEvent):
self._publish_event(
QueueLoopNextEvent(
node_execution_id=event.loop_id,
node_id=event.loop_node_id,
node_type=event.loop_node_type,
node_data=event.loop_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_loop_output,
parallel_mode_run_id=event.parallel_mode_run_id,
duration=event.duration,
)
)
elif isinstance(event, (LoopRunSucceededEvent | LoopRunFailedEvent)):
elif isinstance(event, (NodeRunLoopSucceededEvent | NodeRunLoopFailedEvent)):
self._publish_event(
QueueLoopCompletedEvent(
node_execution_id=event.loop_id,
node_id=event.loop_node_id,
node_type=event.loop_node_type,
node_data=event.loop_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, LoopRunFailedEvent) else None,
error=event.error if isinstance(event, NodeRunLoopFailedEvent) else None,
)
)
def _publish_event(self, event: AppQueueEvent) -> None:
def _publish_event(self, event: AppQueueEvent):
self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)

View File

@ -1,9 +1,12 @@
from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Optional
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
from constants import UUID_NIL
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
@ -11,7 +14,7 @@ from core.file import File, FileUploadConfig
from core.model_runtime.entities.model_entities import AIModelEntity
class InvokeFrom(Enum):
class InvokeFrom(StrEnum):
"""
Invoke From.
"""
@ -35,6 +38,7 @@ class InvokeFrom(Enum):
# DEBUGGER indicates that this invocation is from
# the workflow (or chatflow) edit page.
DEBUGGER = "debugger"
PUBLISHED = "published"
@classmethod
def value_of(cls, value: str):
@ -95,8 +99,8 @@ class AppGenerateEntity(BaseModel):
task_id: str
# app config
app_config: Any
file_upload_config: Optional[FileUploadConfig] = None
app_config: Any = None
file_upload_config: FileUploadConfig | None = None
inputs: Mapping[str, Any]
files: Sequence[File]
@ -113,8 +117,7 @@ class AppGenerateEntity(BaseModel):
extras: dict[str, Any] = Field(default_factory=dict)
# tracing instance
# Using Any to avoid circular import with TraceQueueManager
trace_manager: Optional[Any] = None
trace_manager: Optional["TraceQueueManager"] = None
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
@ -123,10 +126,10 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
"""
# app config
app_config: EasyUIBasedAppConfig
app_config: EasyUIBasedAppConfig = None # type: ignore
model_conf: ModelConfigWithCredentialsEntity
query: Optional[str] = None
query: str | None = None
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@ -137,8 +140,8 @@ class ConversationAppGenerateEntity(AppGenerateEntity):
Base entity for conversation-based app generation.
"""
conversation_id: Optional[str] = None
parent_message_id: Optional[str] = Field(
conversation_id: str | None = None
parent_message_id: str | None = Field(
default=None,
description=(
"Starting from v0.9.0, parent_message_id is used to support message regeneration for internal chat API."
@ -186,9 +189,9 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
"""
# app config
app_config: WorkflowUIBasedAppConfig
app_config: WorkflowUIBasedAppConfig = None # type: ignore
workflow_run_id: Optional[str] = None
workflow_run_id: str | None = None
query: str
class SingleIterationRunEntity(BaseModel):
@ -199,7 +202,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
node_id: str
inputs: Mapping
single_iteration_run: Optional[SingleIterationRunEntity] = None
single_iteration_run: SingleIterationRunEntity | None = None
class SingleLoopRunEntity(BaseModel):
"""
@ -209,7 +212,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
node_id: str
inputs: Mapping
single_loop_run: Optional[SingleLoopRunEntity] = None
single_loop_run: SingleLoopRunEntity | None = None
class WorkflowAppGenerateEntity(AppGenerateEntity):
@ -218,7 +221,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
# app config
app_config: WorkflowUIBasedAppConfig
app_config: WorkflowUIBasedAppConfig = None # type: ignore
workflow_execution_id: str
class SingleIterationRunEntity(BaseModel):
@ -229,7 +232,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
node_id: str
inputs: dict
single_iteration_run: Optional[SingleIterationRunEntity] = None
single_iteration_run: SingleIterationRunEntity | None = None
class SingleLoopRunEntity(BaseModel):
"""
@ -239,4 +242,35 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
node_id: str
inputs: dict
single_loop_run: Optional[SingleLoopRunEntity] = None
single_loop_run: SingleLoopRunEntity | None = None
class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
"""
RAG Pipeline Application Generate Entity.
"""
# pipeline config
pipeline_config: WorkflowUIBasedAppConfig
datasource_type: str
datasource_info: Mapping[str, Any]
dataset_id: str
batch: str
document_id: str | None = None
original_document_id: str | None = None
start_node_id: str | None = None
# Import TraceQueueManager at runtime to resolve forward references
from core.ops.ops_trace_manager import TraceQueueManager
# Rebuild models that use forward references
AppGenerateEntity.model_rebuild()
EasyUIBasedAppGenerateEntity.model_rebuild()
ConversationAppGenerateEntity.model_rebuild()
ChatAppGenerateEntity.model_rebuild()
CompletionAppGenerateEntity.model_rebuild()
AgentChatAppGenerateEntity.model_rebuild()
AdvancedChatAppGenerateEntity.model_rebuild()
WorkflowAppGenerateEntity.model_rebuild()
RagPipelineGenerateEntity.model_rebuild()

View File

@ -1,17 +1,15 @@
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import Enum, StrEnum
from typing import Any, Optional
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities import AgentNodeStrategyInit, GraphRuntimeState
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNodeData
class QueueEvent(StrEnum):
@ -43,9 +41,6 @@ class QueueEvent(StrEnum):
ANNOTATION_REPLY = "annotation_reply"
AGENT_THOUGHT = "agent_thought"
MESSAGE_FILE = "message_file"
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
AGENT_LOG = "agent_log"
ERROR = "error"
PING = "ping"
@ -80,21 +75,13 @@ class QueueIterationStartEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
node_title: str
start_at: datetime
node_run_index: int
inputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, object] = Field(default_factory=dict)
predecessor_node_id: str | None = None
metadata: Mapping[str, object] = Field(default_factory=dict)
class QueueIterationNextEvent(AppQueueEvent):
@ -108,20 +95,9 @@ class QueueIterationNextEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""iteration run in parallel mode run id"""
node_title: str
node_run_index: int
output: Optional[Any] = None # output for the current iteration
duration: Optional[float] = None
output: Any = None # output for the current iteration
class QueueIterationCompletedEvent(AppQueueEvent):
@ -134,24 +110,16 @@ class QueueIterationCompletedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
node_title: str
start_at: datetime
node_run_index: int
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
steps: int = 0
error: Optional[str] = None
error: str | None = None
class QueueLoopStartEvent(AppQueueEvent):
@ -163,21 +131,21 @@ class QueueLoopStartEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
inputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, object] = Field(default_factory=dict)
predecessor_node_id: str | None = None
metadata: Mapping[str, object] = Field(default_factory=dict)
class QueueLoopNextEvent(AppQueueEvent):
@ -191,20 +159,19 @@ class QueueLoopNextEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
parallel_mode_run_id: str | None = None
"""iteration run in parallel mode run id"""
node_run_index: int
output: Optional[Any] = None # output for the current loop
duration: Optional[float] = None
output: Any = None # output for the current loop
class QueueLoopCompletedEvent(AppQueueEvent):
@ -217,24 +184,24 @@ class QueueLoopCompletedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
steps: int = 0
error: Optional[str] = None
error: str | None = None
class QueueTextChunkEvent(AppQueueEvent):
@ -244,11 +211,11 @@ class QueueTextChunkEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.TEXT_CHUNK
text: str
from_variable_selector: Optional[list[str]] = None
from_variable_selector: list[str] | None = None
"""from variable selector"""
in_iteration_id: Optional[str] = None
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
in_loop_id: str | None = None
"""loop id if node is in loop"""
@ -285,9 +252,9 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
retriever_resources: Sequence[RetrievalSourceMetadata]
in_iteration_id: Optional[str] = None
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
in_loop_id: str | None = None
"""loop id if node is in loop"""
@ -306,7 +273,7 @@ class QueueMessageEndEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.MESSAGE_END
llm_result: Optional[LLMResult] = None
llm_result: LLMResult | None = None
class QueueAdvancedChatMessageEndEvent(AppQueueEvent):
@ -332,7 +299,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
outputs: Optional[dict[str, Any]] = None
outputs: Mapping[str, object] = Field(default_factory=dict)
class QueueWorkflowFailedEvent(AppQueueEvent):
@ -352,7 +319,7 @@ class QueueWorkflowPartialSuccessEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
exceptions_count: int
outputs: Optional[dict[str, Any]] = None
outputs: Mapping[str, object] = Field(default_factory=dict)
class QueueNodeStartedEvent(AppQueueEvent):
@ -364,26 +331,23 @@ class QueueNodeStartedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_title: str
node_type: NodeType
node_data: BaseNodeData
node_run_index: int = 1
predecessor_node_id: Optional[str] = None
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
node_run_index: int = 1 # FIXME(-LAN-): may not used
predecessor_node_id: str | None = None
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
in_iteration_id: str | None = None
in_loop_id: str | None = None
start_at: datetime
parallel_mode_run_id: Optional[str] = None
"""iteration run in parallel mode run id"""
agent_strategy: Optional[AgentNodeStrategyInit] = None
parallel_mode_run_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None
# FIXME(-LAN-): only for ToolNode, need to refactor
provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType
provider_id: str
class QueueNodeSucceededEvent(AppQueueEvent):
@ -396,31 +360,26 @@ class QueueNodeSucceededEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
error: Optional[str] = None
"""single iteration duration map"""
iteration_duration_map: Optional[dict[str, float]] = None
"""single loop duration map"""
loop_duration_map: Optional[dict[str, float]] = None
error: str | None = None
class QueueAgentLogEvent(AppQueueEvent):
@ -432,11 +391,11 @@ class QueueAgentLogEvent(AppQueueEvent):
id: str
label: str
node_execution_id: str
parent_id: str | None
error: str | None
parent_id: str | None = None
error: str | None = None
status: str
data: Mapping[str, Any]
metadata: Optional[Mapping[str, Any]] = None
metadata: Mapping[str, object] = Field(default_factory=dict)
node_id: str
@ -445,81 +404,15 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent):
event: QueueEvent = QueueEvent.RETRY
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
error: str
retry_index: int # retry index
class QueueNodeInIterationFailedEvent(AppQueueEvent):
"""
QueueNodeInIterationFailedEvent entity
"""
event: QueueEvent = QueueEvent.NODE_FAILED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
start_at: datetime
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
error: str
class QueueNodeInLoopFailedEvent(AppQueueEvent):
"""
QueueNodeInLoopFailedEvent entity
"""
event: QueueEvent = QueueEvent.NODE_FAILED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
start_at: datetime
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
error: str
class QueueNodeExceptionEvent(AppQueueEvent):
"""
QueueNodeExceptionEvent entity
@ -530,25 +423,24 @@ class QueueNodeExceptionEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
error: str
@ -563,25 +455,17 @@ class QueueNodeFailedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
parallel_id: str | None = None
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
in_loop_id: str | None = None
"""loop id if node is in loop"""
start_at: datetime
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
inputs: Mapping[str, object] = Field(default_factory=dict)
process_data: Mapping[str, object] = Field(default_factory=dict)
outputs: Mapping[str, object] = Field(default_factory=dict)
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
error: str
@ -610,7 +494,7 @@ class QueueErrorEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.ERROR
error: Optional[Any] = None
error: Any = None
class QueuePingEvent(AppQueueEvent):
@ -626,15 +510,15 @@ class QueueStopEvent(AppQueueEvent):
QueueStopEvent entity
"""
class StopBy(Enum):
class StopBy(StrEnum):
"""
Stop by enum
"""
USER_MANUAL = "user-manual"
ANNOTATION_REPLY = "annotation-reply"
OUTPUT_MODERATION = "output-moderation"
INPUT_MODERATION = "input-moderation"
USER_MANUAL = auto()
ANNOTATION_REPLY = auto()
OUTPUT_MODERATION = auto()
INPUT_MODERATION = auto()
event: QueueEvent = QueueEvent.STOP
stopped_by: StopBy
@ -678,61 +562,3 @@ class WorkflowQueueMessage(QueueMessage):
"""
pass
class QueueParallelBranchRunStartedEvent(AppQueueEvent):
"""
QueueParallelBranchRunStartedEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
"""
QueueParallelBranchRunSucceededEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
"""
QueueParallelBranchRunFailedEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
error: str

View File

@ -0,0 +1,14 @@
from typing import Any
from pydantic import BaseModel
class RagPipelineInvokeEntity(BaseModel):
pipeline_id: str
application_generate_entity: dict[str, Any]
user_id: str
tenant_id: str
workflow_id: str
streaming: bool
workflow_execution_id: str | None = None
workflow_thread_pool_id: str | None = None

View File

@ -1,14 +1,13 @@
from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Optional
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
class AnnotationReplyAccount(BaseModel):
@ -51,7 +50,7 @@ class WorkflowTaskState(TaskState):
answer: str = ""
class StreamEvent(Enum):
class StreamEvent(StrEnum):
"""
Stream event
"""
@ -71,8 +70,6 @@ class StreamEvent(Enum):
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
NODE_RETRY = "node_retry"
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
ITERATION_STARTED = "iteration_started"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
@ -92,9 +89,6 @@ class StreamResponse(BaseModel):
event: StreamEvent
task_id: str
def to_dict(self):
return jsonable_encoder(self)
class ErrorStreamResponse(StreamResponse):
"""
@ -114,7 +108,7 @@ class MessageStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.MESSAGE
id: str
answer: str
from_variable_selector: Optional[list[str]] = None
from_variable_selector: list[str] | None = None
class MessageAudioStreamResponse(StreamResponse):
@ -142,8 +136,8 @@ class MessageEndStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.MESSAGE_END
id: str
metadata: dict = Field(default_factory=dict)
files: Optional[Sequence[Mapping[str, Any]]] = None
metadata: Mapping[str, object] = Field(default_factory=dict)
files: Sequence[Mapping[str, Any]] | None = None
class MessageFileStreamResponse(StreamResponse):
@ -176,12 +170,12 @@ class AgentThoughtStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.AGENT_THOUGHT
id: str
position: int
thought: Optional[str] = None
observation: Optional[str] = None
tool: Optional[str] = None
tool_labels: Optional[dict] = None
tool_input: Optional[str] = None
message_files: Optional[list[str]] = None
thought: str | None = None
observation: str | None = None
tool: str | None = None
tool_labels: Mapping[str, object] = Field(default_factory=dict)
tool_input: str | None = None
message_files: list[str] | None = None
class AgentMessageStreamResponse(StreamResponse):
@ -227,16 +221,16 @@ class WorkflowFinishStreamResponse(StreamResponse):
id: str
workflow_id: str
status: str
outputs: Optional[Mapping[str, Any]] = None
error: Optional[str] = None
outputs: Mapping[str, Any] | None = None
error: str | None = None
elapsed_time: float
total_tokens: int
total_steps: int
created_by: Optional[dict] = None
created_by: Mapping[str, object] = Field(default_factory=dict)
created_at: int
finished_at: int
exceptions_count: Optional[int] = 0
files: Optional[Sequence[Mapping[str, Any]]] = []
exceptions_count: int | None = 0
files: Sequence[Mapping[str, Any]] | None = []
event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
workflow_run_id: str
@ -258,18 +252,19 @@ class NodeStartStreamResponse(StreamResponse):
node_type: str
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: str | None = None
inputs: Mapping[str, Any] | None = None
inputs_truncated: bool = False
created_at: int
extras: dict = Field(default_factory=dict)
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
loop_id: Optional[str] = None
parallel_run_id: Optional[str] = None
agent_strategy: Optional[AgentNodeStrategyInit] = None
extras: dict[str, object] = Field(default_factory=dict)
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
parallel_run_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None
event: StreamEvent = StreamEvent.NODE_STARTED
workflow_run_id: str
@ -315,23 +310,26 @@ class NodeFinishStreamResponse(StreamResponse):
node_type: str
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: str | None = None
inputs: Mapping[str, Any] | None = None
inputs_truncated: bool = False
process_data: Mapping[str, Any] | None = None
process_data_truncated: bool = False
outputs: Mapping[str, Any] | None = None
outputs_truncated: bool = True
status: str
error: Optional[str] = None
error: str | None = None
elapsed_time: float
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
created_at: int
finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
loop_id: Optional[str] = None
files: Sequence[Mapping[str, Any]] | None = []
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
event: StreamEvent = StreamEvent.NODE_FINISHED
workflow_run_id: str
@ -384,23 +382,26 @@ class NodeRetryStreamResponse(StreamResponse):
node_type: str
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: str | None = None
inputs: Mapping[str, Any] | None = None
inputs_truncated: bool = False
process_data: Mapping[str, Any] | None = None
process_data_truncated: bool = False
outputs: Mapping[str, Any] | None = None
outputs_truncated: bool = False
status: str
error: Optional[str] = None
error: str | None = None
elapsed_time: float
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
created_at: int
finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
loop_id: Optional[str] = None
files: Sequence[Mapping[str, Any]] | None = []
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
retry_index: int = 0
event: StreamEvent = StreamEvent.NODE_RETRY
@ -440,54 +441,6 @@ class NodeRetryStreamResponse(StreamResponse):
}
class ParallelBranchStartStreamResponse(StreamResponse):
"""
ParallelBranchStartStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
parallel_id: str
parallel_branch_id: str
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
loop_id: Optional[str] = None
created_at: int
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
workflow_run_id: str
data: Data
class ParallelBranchFinishedStreamResponse(StreamResponse):
"""
ParallelBranchFinishedStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
parallel_id: str
parallel_branch_id: str
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
loop_id: Optional[str] = None
status: str
error: Optional[str] = None
created_at: int
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
workflow_run_id: str
data: Data
class IterationNodeStartStreamResponse(StreamResponse):
"""
NodeStartStreamResponse entity
@ -506,8 +459,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
extras: dict = Field(default_factory=dict)
metadata: Mapping = {}
inputs: Mapping = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
inputs_truncated: bool = False
event: StreamEvent = StreamEvent.ITERATION_STARTED
workflow_run_id: str
@ -530,12 +482,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
title: str
index: int
created_at: int
pre_iteration_output: Optional[Any] = None
extras: dict = Field(default_factory=dict)
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parallel_mode_run_id: Optional[str] = None
duration: Optional[float] = None
event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str
@ -556,19 +503,19 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
node_id: str
node_type: str
title: str
outputs: Optional[Mapping] = None
outputs: Mapping | None = None
outputs_truncated: bool = False
created_at: int
extras: Optional[dict] = None
inputs: Optional[Mapping] = None
extras: dict | None = None
inputs: Mapping | None = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus
error: Optional[str] = None
error: str | None = None
elapsed_time: float
total_tokens: int
execution_metadata: Optional[Mapping] = None
execution_metadata: Mapping[str, object] = Field(default_factory=dict)
finished_at: int
steps: int
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
event: StreamEvent = StreamEvent.ITERATION_COMPLETED
workflow_run_id: str
@ -593,8 +540,9 @@ class LoopNodeStartStreamResponse(StreamResponse):
extras: dict = Field(default_factory=dict)
metadata: Mapping = {}
inputs: Mapping = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
inputs_truncated: bool = False
parallel_id: str | None = None
parallel_start_node_id: str | None = None
event: StreamEvent = StreamEvent.LOOP_STARTED
workflow_run_id: str
@ -617,12 +565,11 @@ class LoopNodeNextStreamResponse(StreamResponse):
title: str
index: int
created_at: int
pre_loop_output: Optional[Any] = None
extras: dict = Field(default_factory=dict)
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parallel_mode_run_id: Optional[str] = None
duration: Optional[float] = None
pre_loop_output: Any = None
extras: Mapping[str, object] = Field(default_factory=dict)
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parallel_mode_run_id: str | None = None
event: StreamEvent = StreamEvent.LOOP_NEXT
workflow_run_id: str
@ -643,19 +590,21 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
node_id: str
node_type: str
title: str
outputs: Optional[Mapping] = None
outputs: Mapping | None = None
outputs_truncated: bool = False
created_at: int
extras: Optional[dict] = None
inputs: Optional[Mapping] = None
extras: dict | None = None
inputs: Mapping | None = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus
error: Optional[str] = None
error: str | None = None
elapsed_time: float
total_tokens: int
execution_metadata: Optional[Mapping] = None
execution_metadata: Mapping[str, object] = Field(default_factory=dict)
finished_at: int
steps: int
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parallel_id: str | None = None
parallel_start_node_id: str | None = None
event: StreamEvent = StreamEvent.LOOP_COMPLETED
workflow_run_id: str
@ -673,7 +622,7 @@ class TextChunkStreamResponse(StreamResponse):
"""
text: str
from_variable_selector: Optional[list[str]] = None
from_variable_selector: list[str] | None = None
event: StreamEvent = StreamEvent.TEXT_CHUNK
data: Data
@ -735,7 +684,7 @@ class WorkflowAppStreamResponse(AppStreamResponse):
WorkflowAppStreamResponse entity
"""
workflow_run_id: Optional[str] = None
workflow_run_id: str | None = None
class AppBlockingResponse(BaseModel):
@ -745,9 +694,6 @@ class AppBlockingResponse(BaseModel):
task_id: str
def to_dict(self):
return jsonable_encoder(self)
class ChatbotAppBlockingResponse(AppBlockingResponse):
"""
@ -764,7 +710,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
conversation_id: str
message_id: str
answer: str
metadata: dict = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
created_at: int
data: Data
@ -784,7 +730,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse):
mode: str
message_id: str
answer: str
metadata: dict = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
created_at: int
data: Data
@ -803,8 +749,8 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
id: str
workflow_id: str
status: str
outputs: Optional[Mapping[str, Any]] = None
error: Optional[str] = None
outputs: Mapping[str, Any] | None = None
error: str | None = None
elapsed_time: float
total_tokens: int
total_steps: int
@ -828,11 +774,11 @@ class AgentLogStreamResponse(StreamResponse):
node_execution_id: str
id: str
label: str
parent_id: str | None
error: str | None
parent_id: str | None = None
error: str | None = None
status: str
data: Mapping[str, Any]
metadata: Optional[Mapping[str, Any]] = None
metadata: Mapping[str, object] = Field(default_factory=dict)
node_id: str
event: StreamEvent = StreamEvent.AGENT_LOG

View File

@ -1,5 +1,6 @@
import logging
from typing import Optional
from sqlalchemy import select
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector
@ -15,7 +16,7 @@ logger = logging.getLogger(__name__)
class AnnotationReplyFeature:
def query(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
) -> MessageAnnotation | None:
"""
Query app annotations to reply
:param app_record: app record
@ -25,15 +26,17 @@ class AnnotationReplyFeature:
:param invoke_from: invoke from
:return:
"""
annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first()
)
stmt = select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id)
annotation_setting = db.session.scalar(stmt)
if not annotation_setting:
return None
collection_binding_detail = annotation_setting.collection_binding_detail
if not collection_binding_detail:
return None
try:
score_threshold = annotation_setting.score_threshold or 1
embedding_provider_name = collection_binding_detail.provider_name

View File

@ -1 +1,3 @@
from .rate_limit import RateLimit
__all__ = ["RateLimit"]

View File

@ -3,7 +3,7 @@ import time
import uuid
from collections.abc import Generator, Mapping
from datetime import timedelta
from typing import Any, Optional, Union
from typing import Any, Union
from core.errors.error import AppInvokeQuotaExceededError
from extensions.ext_redis import redis_client
@ -19,7 +19,7 @@ class RateLimit:
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
_instance_dict: dict[str, "RateLimit"] = {}
def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
def __new__(cls, client_id: str, max_active_requests: int):
if client_id not in cls._instance_dict:
instance = super().__new__(cls)
cls._instance_dict[client_id] = instance
@ -63,7 +63,7 @@ class RateLimit:
if timeout_requests:
redis_client.hdel(self.active_requests_key, *timeout_requests)
def enter(self, request_id: Optional[str] = None) -> str:
def enter(self, request_id: str | None = None) -> str:
if self.disabled():
return RateLimit._UNLIMITED_REQUEST_ID
if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL:
@ -96,7 +96,11 @@ class RateLimit:
if isinstance(generator, Mapping):
return generator
else:
return RateLimitGenerator(rate_limit=self, generator=generator, request_id=request_id)
return RateLimitGenerator(
rate_limit=self,
generator=generator, # ty: ignore [invalid-argument-type]
request_id=request_id,
)
class RateLimitGenerator:

View File

@ -1,6 +1,5 @@
import logging
import time
from typing import Optional
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -35,14 +34,14 @@ class BasedGenerateTaskPipeline:
application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager,
stream: bool,
) -> None:
):
self._application_generate_entity = application_generate_entity
self.queue_manager = queue_manager
self._start_at = time.perf_counter()
self._output_moderation_handler = self._init_output_moderation()
self._stream = stream
self.start_at = time.perf_counter()
self.output_moderation_handler = self._init_output_moderation()
self.stream = stream
def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
def handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
logger.debug("error: %s", event.error)
e = event.error
err: Exception
@ -50,7 +49,7 @@ class BasedGenerateTaskPipeline:
if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided")
elif isinstance(e, InvokeError | ValueError):
err = e
err = e # ty: ignore [invalid-assignment]
else:
description = getattr(e, "description", None)
err = Exception(description if description is not None else str(e))
@ -86,7 +85,7 @@ class BasedGenerateTaskPipeline:
return message
def _error_to_stream_response(self, e: Exception):
def error_to_stream_response(self, e: Exception):
"""
Error to stream response.
:param e: exception
@ -94,14 +93,14 @@ class BasedGenerateTaskPipeline:
"""
return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e)
def _ping_stream_response(self) -> PingStreamResponse:
def ping_stream_response(self) -> PingStreamResponse:
"""
Ping stream response.
:return:
"""
return PingStreamResponse(task_id=self._application_generate_entity.task_id)
def _init_output_moderation(self) -> Optional[OutputModeration]:
def _init_output_moderation(self) -> OutputModeration | None:
"""
Init output moderation.
:return:
@ -118,21 +117,21 @@ class BasedGenerateTaskPipeline:
)
return None
def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
def handle_output_moderation_when_task_finished(self, completion: str) -> str | None:
"""
Handle output moderation when task finished.
:param completion: completion
:return:
"""
# response moderation
if self._output_moderation_handler:
self._output_moderation_handler.stop_thread()
if self.output_moderation_handler:
self.output_moderation_handler.stop_thread()
completion, flagged = self._output_moderation_handler.moderation_completion(
completion, flagged = self.output_moderation_handler.moderation_completion(
completion=completion, public_event=False
)
self._output_moderation_handler = None
self.output_moderation_handler = None
if flagged:
return completion

View File

@ -2,7 +2,7 @@ import logging
import time
from collections.abc import Generator
from threading import Thread
from typing import Optional, Union, cast
from typing import Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -80,7 +80,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
conversation: Conversation,
message: Message,
stream: bool,
) -> None:
):
super().__init__(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
@ -109,7 +109,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
task_state=self._task_state,
)
self._conversation_name_generate_thread: Optional[Thread] = None
self._conversation_name_generate_thread: Thread | None = None
def process(
self,
@ -125,7 +125,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
if self.stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
@ -145,7 +145,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata.model_dump()
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation_mode == AppMode.COMPLETION.value:
if self._conversation_mode == AppMode.COMPLETION:
response = CompletionAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
data=CompletionAppBlockingResponse.Data(
@ -209,7 +209,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
return None
def _wrapper_process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
self, trace_manager: TraceQueueManager | None = None
) -> Generator[StreamResponse, None, None]:
tenant_id = self._application_generate_entity.app_config.tenant_id
task_id = self._application_generate_entity.task_id
@ -252,7 +252,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self, publisher: Optional[AppGeneratorTTSPublisher], trace_manager: Optional[TraceQueueManager] = None
self, publisher: AppGeneratorTTSPublisher | None, trace_manager: TraceQueueManager | None = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
@ -265,9 +265,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if isinstance(event, QueueErrorEvent):
with Session(db.engine) as session:
err = self._handle_error(event=event, session=session, message_id=self._message_id)
err = self.handle_error(event=event, session=session, message_id=self._message_id)
session.commit()
yield self._error_to_stream_response(err)
yield self.error_to_stream_response(err)
break
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
if isinstance(event, QueueMessageEndEvent):
@ -277,7 +277,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
self._handle_stop(event)
# handle output moderation
output_moderation_answer = self._handle_output_moderation_when_task_finished(
output_moderation_answer = self.handle_output_moderation_when_task_finished(
cast(str, self._task_state.llm_result.message.content)
)
if output_moderation_answer:
@ -354,7 +354,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
yield self.ping_stream_response()
else:
continue
if publisher:
@ -362,7 +362,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None:
def _save_message(self, *, session: Session, trace_manager: TraceQueueManager | None = None):
"""
Save message.
:return:
@ -394,7 +394,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.answer_tokens = usage.completion_tokens
message.answer_unit_price = usage.completion_unit_price
message.answer_price_unit = usage.completion_price_unit
message.provider_response_latency = time.perf_counter() - self._start_at
message.provider_response_latency = time.perf_counter() - self.start_at
message.total_price = usage.total_price
message.currency = usage.currency
self._task_state.llm_result.usage.latency = message.provider_response_latency
@ -412,7 +412,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
application_generate_entity=self._application_generate_entity,
)
def _handle_stop(self, event: QueueStopEvent) -> None:
def _handle_stop(self, event: QueueStopEvent):
"""
Handle stop.
:return:
@ -438,7 +438,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
# transform usage
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
self._task_state.llm_result.usage = model_type_instance.calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
@ -466,14 +466,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
task_id=self._application_generate_entity.task_id, id=message_id, answer=answer
)
def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]:
def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> AgentThoughtStreamResponse | None:
"""
Agent thought to stream response.
:param event: agent thought event
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
agent_thought: Optional[MessageAgentThought] = (
agent_thought: MessageAgentThought | None = (
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)
@ -498,10 +498,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:param text: text
:return: True if output moderation should direct output, otherwise False
"""
if self._output_moderation_handler:
if self._output_moderation_handler.should_direct_output():
if self.output_moderation_handler:
if self.output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
self._task_state.llm_result.message.content = self.output_moderation_handler.get_final_output()
self.queue_manager.publish(
QueueLLMChunkEvent(
chunk=LLMResultChunk(
@ -521,6 +521,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
)
return True
else:
self._output_moderation_handler.append_new_token(text)
self.output_moderation_handler.append_new_token(text)
return False

View File

@ -1,8 +1,10 @@
import logging
from threading import Thread
from typing import Optional, Union
from typing import Union
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import (
@ -46,11 +48,11 @@ class MessageCycleManager:
AdvancedChatAppGenerateEntity,
],
task_state: Union[EasyUITaskState, WorkflowTaskState],
) -> None:
):
self._application_generate_entity = application_generate_entity
self._task_state = task_state
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None:
"""
Generate conversation name.
:param conversation_id: conversation id
@ -84,30 +86,32 @@ class MessageCycleManager:
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context():
# get conversation and message
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
stmt = select(Conversation).where(Conversation.id == conversation_id)
conversation = db.session.scalar(stmt)
if not conversation:
return
if conversation.mode != AppMode.COMPLETION.value:
if conversation.mode != AppMode.COMPLETION:
app_model = conversation.app
if not app_model:
return
# generate conversation name
try:
name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query)
name = LLMGenerator.generate_conversation_name(
app_model.tenant_id, query, conversation_id, conversation.app_id
)
conversation.name = name
except Exception as e:
except Exception:
if dify_config.DEBUG:
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
pass
db.session.merge(conversation)
db.session.commit()
db.session.close()
def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> MessageAnnotation | None:
"""
Handle annotation reply.
:param event: event
@ -128,22 +132,25 @@ class MessageCycleManager:
return None
def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent):
"""
Handle retriever resources.
:param event: event
:return:
"""
if not self._application_generate_entity.app_config.additional_features:
raise ValueError("Additional features not found")
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
self._task_state.metadata.retriever_resources = event.retriever_resources
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> MessageFileStreamResponse | None:
"""
Message file to stream response.
:param event: event
:return:
"""
message_file = db.session.query(MessageFile).where(MessageFile.id == event.message_file_id).first()
with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))
if message_file and message_file.url is not None:
# get tool file id
@ -175,7 +182,7 @@ class MessageCycleManager:
return None
def message_to_stream_response(
self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None
self, answer: str, message_id: str, from_variable_selector: list[str] | None = None
) -> MessageStreamResponse:
"""
Message to stream response.
@ -183,7 +190,8 @@ class MessageCycleManager:
:param message_id: message id
:return:
"""
message_file = db.session.query(MessageFile).where(MessageFile.id == message_id).first()
with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
return MessageStreamResponse(