Merge branch 'refs/heads/main' into feat/workflow-parallel-support

# Conflicts:
#	api/core/app/apps/advanced_chat/app_generator.py
#	api/core/app/apps/advanced_chat/app_runner.py
#	api/core/app/apps/advanced_chat/generate_task_pipeline.py
#	api/core/app/apps/base_app_runner.py
#	api/core/app/apps/workflow/app_runner.py
#	api/core/app/apps/workflow/generate_task_pipeline.py
#	api/core/app/task_pipeline/workflow_cycle_state_manager.py
#	api/core/workflow/entities/node_entities.py
#	api/core/workflow/nodes/llm/llm_node.py
#	api/core/workflow/workflow_engine_manager.py
#	api/tests/integration_tests/workflow/nodes/test_llm.py
#	api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
#	api/tests/unit_tests/core/workflow/nodes/test_answer.py
#	api/tests/unit_tests/core/workflow/nodes/test_if_else.py
#	api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
This commit is contained in:
takatost
2024-08-16 01:19:29 +08:00
196 changed files with 4070 additions and 2133 deletions

View File

@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
def generate(
self,
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
@ -121,7 +121,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
stream=stream
)
def single_iteration_generate(self, app_model: App,
workflow: Workflow,
node_id: str,
@ -141,10 +141,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
"""
if not node_id:
raise ValueError('node_id is required')
if args.get('inputs') is None:
raise ValueError('inputs is required')
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(
app_model=app_model,
@ -191,7 +191,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
-> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
:param workflow: Workflow
:param user: account or end user
:param invoke_from: invoke from source
@ -232,8 +232,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
'user': user,
'context': contextvars.copy_context()
'context': contextvars.copy_context(),
})
worker_thread.start()
@ -246,7 +245,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
stream=stream
stream=stream,
)
return AdvancedChatAppGenerateResponseConverter.convert(
@ -259,7 +258,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
user: Account,
context: contextvars.Context) -> None:
"""
Generate worker in a new thread.
@ -307,14 +305,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
finally:
db.session.close()
def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False) \
-> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
def _handle_advanced_chat_response(
self,
*,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False,
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
Handle response.
:param application_generate_entity: application generate entity
@ -334,7 +335,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
stream=stream
stream=stream,
)
try:

View File

@ -3,9 +3,6 @@ import os
from collections.abc import Mapping
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@ -94,7 +91,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
inputs = self.application_generate_entity.inputs
query = self.application_generate_entity.query
files = self.application_generate_entity.files
# moderation
if self.handle_input_moderation(
app_record=app_record,

View File

@ -44,7 +44,7 @@ from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.enums import SystemVariable
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from events.message_event import message_was_created
from extensions.ext_database import db
@ -69,14 +69,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_workflow_system_variables: dict[SystemVariable, Any]
def __init__(
self,
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool
stream: bool,
) -> None:
"""
Initialize AdvancedChatAppGenerateTaskPipeline.
@ -102,7 +102,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
SystemVariable.QUERY: message.query,
SystemVariable.FILES: application_generate_entity.files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id
SystemVariable.USER_ID: user_id,
}
self._task_state = WorkflowTaskState()
@ -127,7 +127,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
if self._stream:
return self._to_stream_response(generator)
else:
@ -239,7 +239,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# init fake graph runtime state
graph_runtime_state = None
workflow_run = None
for queue_message in self._queue_manager.listen():
event = queue_message.event
@ -270,9 +270,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueNodeStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
workflow_node_execution = self._handle_node_execution_start(
workflow_run=workflow_run,
workflow_run=workflow_run,
event=event
)
@ -307,7 +307,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -316,7 +316,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueIterationNextEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -325,7 +325,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -334,10 +334,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
@ -360,10 +360,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
@ -400,7 +400,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
# Save message
self._save_message(graph_runtime_state=graph_runtime_state)
@ -413,7 +413,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
db.session.commit()
db.session.close()
elif isinstance(event, QueueAnnotationReplyEvent):
@ -421,7 +421,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
db.session.commit()
db.session.close()
elif isinstance(event, QueueTextChunkEvent):
@ -446,7 +446,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
@ -458,11 +458,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
yield self._message_end_to_stream_response()
else:
continue
# publish None when task finished
if tts_publisher:
tts_publisher.publish(None)
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
@ -507,7 +507,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
extras = {}
if self._task_state.metadata:
extras['metadata'] = self._task_state.metadata.copy()
if 'annotation_reply' in extras['metadata']:
del extras['metadata']['annotation_reply']

View File

@ -1,6 +1,6 @@
import time
from collections.abc import Generator, Mapping
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -14,7 +14,6 @@ from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChu
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.file.file_obj import FileVar
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@ -27,13 +26,16 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
from models.model import App, AppMode, Message, MessageAnnotation
if TYPE_CHECKING:
from core.file.file_obj import FileVar
class AppRunner:
def get_pre_calculate_rest_tokens(self, app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list[FileVar],
files: list["FileVar"],
query: Optional[str] = None) -> int:
"""
Get pre calculate rest tokens
@ -126,7 +128,7 @@ class AppRunner:
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list[FileVar],
files: list["FileVar"],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None) \
@ -366,7 +368,7 @@ class AppRunner:
message_id=message_id,
trace_manager=app_generate_entity.trace_manager
)
def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
queue_manager: AppQueueManager,
prompt_messages: list[PromptMessage]) -> bool:
@ -418,7 +420,7 @@ class AppRunner:
inputs=inputs,
query=query
)
def query_app_annotations_to_reply(self, app_record: App,
message: Message,
query: str,

View File

@ -258,7 +258,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
return introduction
def _get_conversation(self, conversation_id: str) -> Conversation:
def _get_conversation(self, conversation_id: str):
"""
Get conversation by conversation id
:param conversation_id: conversation id
@ -270,6 +270,9 @@ class MessageBasedAppGenerator(BaseAppGenerator):
.first()
)
if not conversation:
raise ConversationNotExistsError()
return conversation
def _get_message(self, message_id: str) -> Message:

View File

@ -11,7 +11,8 @@ from core.app.entities.app_invoke_entities import (
WorkflowAppGenerateEntity,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import SystemVariable, UserFrom
from core.workflow.enums import SystemVariable
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db

View File

@ -41,7 +41,9 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariable
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser
@ -179,7 +181,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
@ -246,7 +248,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
raise Exception('Workflow run not initialized.')
workflow_node_execution = self._handle_node_execution_start(
workflow_run=workflow_run,
workflow_run=workflow_run,
event=event
)
@ -281,7 +283,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -290,7 +292,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
elif isinstance(event, QueueIterationNextEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -299,7 +301,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
@ -308,10 +310,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
@ -332,10 +334,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,