mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 09:58:04 +08:00
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts: # api/core/app/apps/advanced_chat/app_generator.py # api/core/app/apps/advanced_chat/app_runner.py # api/core/app/apps/advanced_chat/generate_task_pipeline.py # api/core/app/apps/base_app_runner.py # api/core/app/apps/workflow/app_runner.py # api/core/app/apps/workflow/generate_task_pipeline.py # api/core/app/task_pipeline/workflow_cycle_state_manager.py # api/core/workflow/entities/node_entities.py # api/core/workflow/nodes/llm/llm_node.py # api/core/workflow/workflow_engine_manager.py # api/tests/integration_tests/workflow/nodes/test_llm.py # api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py # api/tests/unit_tests/core/workflow/nodes/test_answer.py # api/tests/unit_tests/core/workflow/nodes/test_if_else.py # api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
This commit is contained in:
@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
def generate(
|
||||
self,
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
@ -121,7 +121,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation=conversation,
|
||||
stream=stream
|
||||
)
|
||||
|
||||
|
||||
def single_iteration_generate(self, app_model: App,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
@ -141,10 +141,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
"""
|
||||
if not node_id:
|
||||
raise ValueError('node_id is required')
|
||||
|
||||
|
||||
if args.get('inputs') is None:
|
||||
raise ValueError('inputs is required')
|
||||
|
||||
|
||||
# convert to app config
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
@ -191,7 +191,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
-> dict[str, Any] | Generator[str, Any, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
:param workflow: Workflow
|
||||
:param user: account or end user
|
||||
:param invoke_from: invoke from source
|
||||
@ -232,8 +232,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
'queue_manager': queue_manager,
|
||||
'conversation_id': conversation.id,
|
||||
'message_id': message.id,
|
||||
'user': user,
|
||||
'context': contextvars.copy_context()
|
||||
'context': contextvars.copy_context(),
|
||||
})
|
||||
|
||||
worker_thread.start()
|
||||
@ -246,7 +245,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=stream
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
return AdvancedChatAppGenerateResponseConverter.convert(
|
||||
@ -259,7 +258,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
queue_manager: AppQueueManager,
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
user: Account,
|
||||
context: contextvars.Context) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
@ -307,14 +305,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool = False) \
|
||||
-> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
def _handle_advanced_chat_response(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool = False,
|
||||
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
@ -334,7 +335,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=stream
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@ -3,9 +3,6 @@ import os
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
@ -94,7 +91,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
inputs = self.application_generate_entity.inputs
|
||||
query = self.application_generate_entity.query
|
||||
files = self.application_generate_entity.files
|
||||
|
||||
|
||||
# moderation
|
||||
if self.handle_input_moderation(
|
||||
app_record=app_record,
|
||||
|
||||
@ -44,7 +44,7 @@ from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
|
||||
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
@ -69,14 +69,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
_workflow_system_variables: dict[SystemVariable, Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
self,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool
|
||||
stream: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize AdvancedChatAppGenerateTaskPipeline.
|
||||
@ -102,7 +102,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
SystemVariable.QUERY: message.query,
|
||||
SystemVariable.FILES: application_generate_entity.files,
|
||||
SystemVariable.CONVERSATION_ID: conversation.id,
|
||||
SystemVariable.USER_ID: user_id
|
||||
SystemVariable.USER_ID: user_id,
|
||||
}
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
@ -127,7 +127,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
generator = self._wrapper_process_stream_response(
|
||||
trace_manager=self._application_generate_entity.trace_manager
|
||||
)
|
||||
|
||||
|
||||
if self._stream:
|
||||
return self._to_stream_response(generator)
|
||||
else:
|
||||
@ -239,7 +239,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
# init fake graph runtime state
|
||||
graph_runtime_state = None
|
||||
workflow_run = None
|
||||
|
||||
|
||||
for queue_message in self._queue_manager.listen():
|
||||
event = queue_message.event
|
||||
|
||||
@ -270,9 +270,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
elif isinstance(event, QueueNodeStartedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
|
||||
workflow_node_execution = self._handle_node_execution_start(
|
||||
workflow_run=workflow_run,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
|
||||
@ -307,7 +307,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
elif isinstance(event, QueueIterationStartEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
|
||||
yield self._workflow_iteration_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@ -316,7 +316,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
elif isinstance(event, QueueIterationNextEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
|
||||
yield self._workflow_iteration_next_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@ -325,7 +325,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
elif isinstance(event, QueueIterationCompletedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
|
||||
yield self._workflow_iteration_completed_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@ -334,10 +334,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
elif isinstance(event, QueueWorkflowSucceededEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
|
||||
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@ -360,10 +360,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
elif isinstance(event, QueueWorkflowFailedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
|
||||
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@ -400,7 +400,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run
|
||||
)
|
||||
|
||||
|
||||
# Save message
|
||||
self._save_message(graph_runtime_state=graph_runtime_state)
|
||||
|
||||
@ -413,7 +413,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
|
||||
if self._task_state.metadata else None
|
||||
|
||||
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||
@ -421,7 +421,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
|
||||
if self._task_state.metadata else None
|
||||
|
||||
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
elif isinstance(event, QueueTextChunkEvent):
|
||||
@ -446,7 +446,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
|
||||
|
||||
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
|
||||
if output_moderation_answer:
|
||||
self._task_state.answer = output_moderation_answer
|
||||
@ -458,11 +458,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
yield self._message_end_to_stream_response()
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
# publish None when task finished
|
||||
if tts_publisher:
|
||||
tts_publisher.publish(None)
|
||||
|
||||
|
||||
if self._conversation_name_generate_thread:
|
||||
self._conversation_name_generate_thread.join()
|
||||
|
||||
@ -507,7 +507,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
extras = {}
|
||||
if self._task_state.metadata:
|
||||
extras['metadata'] = self._task_state.metadata.copy()
|
||||
|
||||
|
||||
if 'annotation_reply' in extras['metadata']:
|
||||
del extras['metadata']['annotation_reply']
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import time
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
@ -14,7 +14,6 @@ from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChu
|
||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
|
||||
from core.external_data_tool.external_data_fetch import ExternalDataFetch
|
||||
from core.file.file_obj import FileVar
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
@ -27,13 +26,16 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
|
||||
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
|
||||
from models.model import App, AppMode, Message, MessageAnnotation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.file_obj import FileVar
|
||||
|
||||
|
||||
class AppRunner:
|
||||
def get_pre_calculate_rest_tokens(self, app_record: App,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict[str, str],
|
||||
files: list[FileVar],
|
||||
files: list["FileVar"],
|
||||
query: Optional[str] = None) -> int:
|
||||
"""
|
||||
Get pre calculate rest tokens
|
||||
@ -126,7 +128,7 @@ class AppRunner:
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict[str, str],
|
||||
files: list[FileVar],
|
||||
files: list["FileVar"],
|
||||
query: Optional[str] = None,
|
||||
context: Optional[str] = None,
|
||||
memory: Optional[TokenBufferMemory] = None) \
|
||||
@ -366,7 +368,7 @@ class AppRunner:
|
||||
message_id=message_id,
|
||||
trace_manager=app_generate_entity.trace_manager
|
||||
)
|
||||
|
||||
|
||||
def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
prompt_messages: list[PromptMessage]) -> bool:
|
||||
@ -418,7 +420,7 @@ class AppRunner:
|
||||
inputs=inputs,
|
||||
query=query
|
||||
)
|
||||
|
||||
|
||||
def query_app_annotations_to_reply(self, app_record: App,
|
||||
message: Message,
|
||||
query: str,
|
||||
|
||||
@ -258,7 +258,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
|
||||
return introduction
|
||||
|
||||
def _get_conversation(self, conversation_id: str) -> Conversation:
|
||||
def _get_conversation(self, conversation_id: str):
|
||||
"""
|
||||
Get conversation by conversation id
|
||||
:param conversation_id: conversation id
|
||||
@ -270,6 +270,9 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise ConversationNotExistsError()
|
||||
|
||||
return conversation
|
||||
|
||||
def _get_message(self, message_id: str) -> Message:
|
||||
|
||||
@ -11,7 +11,8 @@ from core.app.entities.app_invoke_entities import (
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.node_entities import SystemVariable, UserFrom
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
|
||||
@ -41,7 +41,9 @@ from core.app.entities.task_entities import (
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
@ -179,7 +181,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
|
||||
'text_to_speech'].get('autoPlay') == 'enabled':
|
||||
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
|
||||
|
||||
|
||||
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
|
||||
@ -246,7 +248,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
workflow_node_execution = self._handle_node_execution_start(
|
||||
workflow_run=workflow_run,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
|
||||
@ -281,7 +283,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
elif isinstance(event, QueueIterationStartEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
|
||||
yield self._workflow_iteration_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@ -290,7 +292,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
elif isinstance(event, QueueIterationNextEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
|
||||
yield self._workflow_iteration_next_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@ -299,7 +301,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
elif isinstance(event, QueueIterationCompletedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
|
||||
yield self._workflow_iteration_completed_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@ -308,10 +310,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
elif isinstance(event, QueueWorkflowSucceededEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
|
||||
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@ -332,10 +334,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
|
||||
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
|
||||
Reference in New Issue
Block a user