mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
refactor workflow generate pipeline
This commit is contained in:
164
api/core/app/apps/workflow/app_generator.py
Normal file
164
api/core/app/apps/workflow/app_generator.py
Normal file
@ -0,0 +1,164 @@
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
|
||||
from core.app.apps.workflow.app_runner import WorkflowAppRunner
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAppGenerator(BaseAppGenerator):
|
||||
def generate(self, app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True) \
|
||||
-> Union[dict, Generator]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param app_model: App
|
||||
:param workflow: Workflow
|
||||
:param user: account or end user
|
||||
:param args: request args
|
||||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
"""
|
||||
inputs = args['inputs']
|
||||
|
||||
# parse files
|
||||
files = args['files'] if 'files' in args and args['files'] else []
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_upload_entity = FileUploadConfigManager.convert(workflow.features_dict)
|
||||
if file_upload_entity:
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(
|
||||
files,
|
||||
file_upload_entity,
|
||||
user
|
||||
)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
# convert to app config
|
||||
app_config = WorkflowAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
workflow=workflow
|
||||
)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = WorkflowAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
inputs=self._get_cleaned_inputs(inputs, app_config),
|
||||
files=file_objs,
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from
|
||||
)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = WorkflowAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
app_mode=app_model.mode
|
||||
)
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'application_generate_entity': application_generate_entity,
|
||||
'queue_manager': queue_manager
|
||||
})
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
# return response or stream generator
|
||||
return self._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
stream=stream
|
||||
)
|
||||
|
||||
def _generate_worker(self, flask_app: Flask,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:return:
|
||||
"""
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
# workflow app
|
||||
runner = WorkflowAppRunner()
|
||||
runner.run(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager
|
||||
)
|
||||
except ConversationTaskStoppedException:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
InvokeAuthorizationError('Incorrect API key provided'),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except (ValueError, InvokeError) as e:
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
finally:
|
||||
db.session.remove()
|
||||
|
||||
def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
stream: bool = False) -> Union[dict, Generator]:
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param stream: is stream
|
||||
:return:
|
||||
"""
|
||||
# init generate task pipeline
|
||||
generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
stream=stream
|
||||
)
|
||||
|
||||
try:
|
||||
return generate_task_pipeline.process()
|
||||
except ValueError as e:
|
||||
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||
raise ConversationTaskStoppedException()
|
||||
else:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
finally:
|
||||
db.session.remove()
|
||||
23
api/core/app/apps/workflow/app_queue_manager.py
Normal file
23
api/core/app/apps/workflow/app_queue_manager.py
Normal file
@ -0,0 +1,23 @@
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueMessage,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowAppQueueManager(AppQueueManager):
|
||||
def __init__(self, task_id: str,
|
||||
user_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
app_mode: str) -> None:
|
||||
super().__init__(task_id, user_id, invoke_from)
|
||||
|
||||
self._app_mode = app_mode
|
||||
|
||||
def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage:
|
||||
return QueueMessage(
|
||||
task_id=self._task_id,
|
||||
app_mode=self._app_mode,
|
||||
event=event
|
||||
)
|
||||
156
api/core/app/apps/workflow/app_runner.py
Normal file
156
api/core/app/apps/workflow/app_runner.py
Normal file
@ -0,0 +1,156 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AppGenerateEntity,
|
||||
InvokeFrom,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent
|
||||
from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
||||
from core.moderation.base import ModerationException
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import WorkflowRunTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAppRunner:
|
||||
"""
|
||||
Workflow Application Runner
|
||||
"""
|
||||
|
||||
def run(self, application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager) -> None:
|
||||
"""
|
||||
Run application
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: application queue manager
|
||||
:return:
|
||||
"""
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
workflow = WorkflowEngineManager().get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
inputs = application_generate_entity.inputs
|
||||
files = application_generate_entity.files
|
||||
|
||||
# moderation
|
||||
if self.handle_input_moderation(
|
||||
queue_manager=queue_manager,
|
||||
app_record=app_record,
|
||||
app_generate_entity=application_generate_entity,
|
||||
inputs=inputs
|
||||
):
|
||||
return
|
||||
|
||||
# fetch user
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE]:
|
||||
user = db.session.query(Account).filter(Account.id == application_generate_entity.user_id).first()
|
||||
else:
|
||||
user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
|
||||
|
||||
# RUN WORKFLOW
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.run_workflow(
|
||||
workflow=workflow,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
|
||||
if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN,
|
||||
user=user,
|
||||
user_inputs=inputs,
|
||||
system_inputs={
|
||||
SystemVariable.FILES: files
|
||||
},
|
||||
callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)]
|
||||
)
|
||||
|
||||
def handle_input_moderation(self, queue_manager: AppQueueManager,
|
||||
app_record: App,
|
||||
app_generate_entity: WorkflowAppGenerateEntity,
|
||||
inputs: dict) -> bool:
|
||||
"""
|
||||
Handle input moderation
|
||||
:param queue_manager: application queue manager
|
||||
:param app_record: app record
|
||||
:param app_generate_entity: application generate entity
|
||||
:param inputs: inputs
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# process sensitive_word_avoidance
|
||||
moderation_feature = InputModeration()
|
||||
_, inputs, query = moderation_feature.check(
|
||||
app_id=app_record.id,
|
||||
tenant_id=app_generate_entity.app_config.tenant_id,
|
||||
app_config=app_generate_entity.app_config,
|
||||
inputs=inputs,
|
||||
query=''
|
||||
)
|
||||
except ModerationException as e:
|
||||
if app_generate_entity.stream:
|
||||
self._stream_output(
|
||||
queue_manager=queue_manager,
|
||||
text=str(e),
|
||||
)
|
||||
|
||||
queue_manager.publish(
|
||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _stream_output(self, queue_manager: AppQueueManager,
|
||||
text: str) -> None:
|
||||
"""
|
||||
Direct output
|
||||
:param queue_manager: application queue manager
|
||||
:param text: text
|
||||
:return:
|
||||
"""
|
||||
index = 0
|
||||
for token in text:
|
||||
queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=token
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
index += 1
|
||||
time.sleep(0.01)
|
||||
|
||||
def moderation_for_inputs(self, app_id: str,
|
||||
tenant_id: str,
|
||||
app_generate_entity: AppGenerateEntity,
|
||||
inputs: dict) -> tuple[bool, dict, str]:
|
||||
"""
|
||||
Process sensitive_word_avoidance.
|
||||
:param app_id: app id
|
||||
:param tenant_id: tenant id
|
||||
:param app_generate_entity: app generate entity
|
||||
:param inputs: inputs
|
||||
:return:
|
||||
"""
|
||||
moderation_feature = InputModeration()
|
||||
return moderation_feature.check(
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
app_config=app_generate_entity.app_config,
|
||||
inputs=inputs,
|
||||
query=''
|
||||
)
|
||||
408
api/core/app/apps/workflow/generate_task_pipeline.py
Normal file
408
api/core/app/apps/workflow/generate_task_pipeline.py
Normal file
@ -0,0 +1,408 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueErrorEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueueNodeFinishedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueuePingEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFinishedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
)
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.moderation.output_moderation import ModerationRule, OutputModeration
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import WorkflowNodeExecution, WorkflowRun, WorkflowRunStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskState(BaseModel):
|
||||
"""
|
||||
TaskState entity
|
||||
"""
|
||||
answer: str = ""
|
||||
metadata: dict = {}
|
||||
workflow_run_id: Optional[str] = None
|
||||
|
||||
|
||||
class WorkflowAppGenerateTaskPipeline:
|
||||
"""
|
||||
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
stream: bool) -> None:
|
||||
"""
|
||||
Initialize GenerateTaskPipeline.
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
"""
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._queue_manager = queue_manager
|
||||
self._task_state = TaskState()
|
||||
self._start_at = time.perf_counter()
|
||||
self._output_moderation_handler = self._init_output_moderation()
|
||||
self._stream = stream
|
||||
|
||||
def process(self) -> Union[dict, Generator]:
|
||||
"""
|
||||
Process generate task pipeline.
|
||||
:return:
|
||||
"""
|
||||
if self._stream:
|
||||
return self._process_stream_response()
|
||||
else:
|
||||
return self._process_blocking_response()
|
||||
|
||||
def _process_blocking_response(self) -> dict:
|
||||
"""
|
||||
Process blocking response.
|
||||
:return:
|
||||
"""
|
||||
for queue_message in self._queue_manager.listen():
|
||||
event = queue_message.event
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
raise self._handle_error(event)
|
||||
elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
|
||||
if isinstance(event, QueueStopEvent):
|
||||
workflow_run = self._get_workflow_run(self._task_state.workflow_run_id)
|
||||
else:
|
||||
workflow_run = self._get_workflow_run(event.workflow_run_id)
|
||||
|
||||
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
|
||||
outputs = workflow_run.outputs
|
||||
self._task_state.answer = outputs.get('text', '')
|
||||
else:
|
||||
raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')))
|
||||
|
||||
# response moderation
|
||||
if self._output_moderation_handler:
|
||||
self._output_moderation_handler.stop_thread()
|
||||
|
||||
self._task_state.answer = self._output_moderation_handler.moderation_completion(
|
||||
completion=self._task_state.answer,
|
||||
public_event=False
|
||||
)
|
||||
|
||||
response = {
|
||||
'event': 'workflow_finished',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'workflow_run_id': event.workflow_run_id,
|
||||
'data': {
|
||||
'id': workflow_run.id,
|
||||
'workflow_id': workflow_run.workflow_id,
|
||||
'status': workflow_run.status,
|
||||
'outputs': workflow_run.outputs_dict,
|
||||
'error': workflow_run.error,
|
||||
'elapsed_time': workflow_run.elapsed_time,
|
||||
'total_tokens': workflow_run.total_tokens,
|
||||
'total_steps': workflow_run.total_steps,
|
||||
'created_at': int(workflow_run.created_at.timestamp()),
|
||||
'finished_at': int(workflow_run.finished_at.timestamp())
|
||||
}
|
||||
}
|
||||
|
||||
return response
|
||||
else:
|
||||
continue
|
||||
|
||||
def _process_stream_response(self) -> Generator:
|
||||
"""
|
||||
Process stream response.
|
||||
:return:
|
||||
"""
|
||||
for message in self._queue_manager.listen():
|
||||
event = message.event
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
data = self._error_to_stream_response_data(self._handle_error(event))
|
||||
yield self._yield_response(data)
|
||||
break
|
||||
elif isinstance(event, QueueWorkflowStartedEvent):
|
||||
self._task_state.workflow_run_id = event.workflow_run_id
|
||||
|
||||
workflow_run = self._get_workflow_run(event.workflow_run_id)
|
||||
response = {
|
||||
'event': 'workflow_started',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'workflow_run_id': event.workflow_run_id,
|
||||
'data': {
|
||||
'id': workflow_run.id,
|
||||
'workflow_id': workflow_run.workflow_id,
|
||||
'created_at': int(workflow_run.created_at.timestamp())
|
||||
}
|
||||
}
|
||||
|
||||
yield self._yield_response(response)
|
||||
elif isinstance(event, QueueNodeStartedEvent):
|
||||
workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
|
||||
response = {
|
||||
'event': 'node_started',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'workflow_run_id': workflow_node_execution.workflow_run_id,
|
||||
'data': {
|
||||
'id': workflow_node_execution.id,
|
||||
'node_id': workflow_node_execution.node_id,
|
||||
'index': workflow_node_execution.index,
|
||||
'predecessor_node_id': workflow_node_execution.predecessor_node_id,
|
||||
'inputs': workflow_node_execution.inputs_dict,
|
||||
'created_at': int(workflow_node_execution.created_at.timestamp())
|
||||
}
|
||||
}
|
||||
|
||||
yield self._yield_response(response)
|
||||
elif isinstance(event, QueueNodeFinishedEvent):
|
||||
workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
|
||||
response = {
|
||||
'event': 'node_finished',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'workflow_run_id': workflow_node_execution.workflow_run_id,
|
||||
'data': {
|
||||
'id': workflow_node_execution.id,
|
||||
'node_id': workflow_node_execution.node_id,
|
||||
'index': workflow_node_execution.index,
|
||||
'predecessor_node_id': workflow_node_execution.predecessor_node_id,
|
||||
'inputs': workflow_node_execution.inputs_dict,
|
||||
'process_data': workflow_node_execution.process_data_dict,
|
||||
'outputs': workflow_node_execution.outputs_dict,
|
||||
'status': workflow_node_execution.status,
|
||||
'error': workflow_node_execution.error,
|
||||
'elapsed_time': workflow_node_execution.elapsed_time,
|
||||
'execution_metadata': workflow_node_execution.execution_metadata_dict,
|
||||
'created_at': int(workflow_node_execution.created_at.timestamp()),
|
||||
'finished_at': int(workflow_node_execution.finished_at.timestamp())
|
||||
}
|
||||
}
|
||||
|
||||
yield self._yield_response(response)
|
||||
elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
|
||||
if isinstance(event, QueueStopEvent):
|
||||
workflow_run = self._get_workflow_run(self._task_state.workflow_run_id)
|
||||
else:
|
||||
workflow_run = self._get_workflow_run(event.workflow_run_id)
|
||||
|
||||
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
|
||||
outputs = workflow_run.outputs
|
||||
self._task_state.answer = outputs.get('text', '')
|
||||
else:
|
||||
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
|
||||
data = self._error_to_stream_response_data(self._handle_error(err_event))
|
||||
yield self._yield_response(data)
|
||||
break
|
||||
|
||||
# response moderation
|
||||
if self._output_moderation_handler:
|
||||
self._output_moderation_handler.stop_thread()
|
||||
|
||||
self._task_state.answer = self._output_moderation_handler.moderation_completion(
|
||||
completion=self._task_state.answer,
|
||||
public_event=False
|
||||
)
|
||||
|
||||
self._output_moderation_handler = None
|
||||
|
||||
replace_response = {
|
||||
'event': 'text_replace',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'workflow_run_id': self._task_state.workflow_run_id,
|
||||
'data': {
|
||||
'text': self._task_state.answer
|
||||
}
|
||||
}
|
||||
|
||||
yield self._yield_response(replace_response)
|
||||
|
||||
workflow_run_response = {
|
||||
'event': 'workflow_finished',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'workflow_run_id': event.workflow_run_id,
|
||||
'data': {
|
||||
'id': workflow_run.id,
|
||||
'workflow_id': workflow_run.workflow_id,
|
||||
'status': workflow_run.status,
|
||||
'outputs': workflow_run.outputs_dict,
|
||||
'error': workflow_run.error,
|
||||
'elapsed_time': workflow_run.elapsed_time,
|
||||
'total_tokens': workflow_run.total_tokens,
|
||||
'total_steps': workflow_run.total_steps,
|
||||
'created_at': int(workflow_run.created_at.timestamp()),
|
||||
'finished_at': int(workflow_run.finished_at.timestamp())
|
||||
}
|
||||
}
|
||||
|
||||
yield self._yield_response(workflow_run_response)
|
||||
elif isinstance(event, QueueTextChunkEvent):
|
||||
delta_text = event.chunk_text
|
||||
if delta_text is None:
|
||||
continue
|
||||
|
||||
if self._output_moderation_handler:
|
||||
if self._output_moderation_handler.should_direct_output():
|
||||
# stop subscribe new token when output moderation should direct output
|
||||
self._task_state.answer = self._output_moderation_handler.get_final_output()
|
||||
self._queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=self._task_state.answer
|
||||
), PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
|
||||
self._queue_manager.publish(
|
||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
|
||||
PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
continue
|
||||
else:
|
||||
self._output_moderation_handler.append_new_token(delta_text)
|
||||
|
||||
self._task_state.answer += delta_text
|
||||
response = self._handle_chunk(delta_text)
|
||||
yield self._yield_response(response)
|
||||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
response = {
|
||||
'event': 'text_replace',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'workflow_run_id': self._task_state.workflow_run_id,
|
||||
'data': {
|
||||
'text': event.text
|
||||
}
|
||||
}
|
||||
|
||||
yield self._yield_response(response)
|
||||
elif isinstance(event, QueuePingEvent):
|
||||
yield "event: ping\n\n"
|
||||
else:
|
||||
continue
|
||||
|
||||
def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
|
||||
"""
|
||||
Get workflow run.
|
||||
:param workflow_run_id: workflow run id
|
||||
:return:
|
||||
"""
|
||||
return db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
|
||||
|
||||
def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Get workflow node execution.
|
||||
:param workflow_node_execution_id: workflow node execution id
|
||||
:return:
|
||||
"""
|
||||
return db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).first()
|
||||
|
||||
def _handle_chunk(self, text: str) -> dict:
|
||||
"""
|
||||
Handle completed event.
|
||||
:param text: text
|
||||
:return:
|
||||
"""
|
||||
response = {
|
||||
'event': 'text_chunk',
|
||||
'workflow_run_id': self._task_state.workflow_run_id,
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'data': {
|
||||
'text': text
|
||||
}
|
||||
}
|
||||
|
||||
return response
|
||||
|
||||
def _handle_error(self, event: QueueErrorEvent) -> Exception:
|
||||
"""
|
||||
Handle error event.
|
||||
:param event: event
|
||||
:return:
|
||||
"""
|
||||
logger.debug("error: %s", event.error)
|
||||
e = event.error
|
||||
|
||||
if isinstance(e, InvokeAuthorizationError):
|
||||
return InvokeAuthorizationError('Incorrect API key provided')
|
||||
elif isinstance(e, InvokeError) or isinstance(e, ValueError):
|
||||
return e
|
||||
else:
|
||||
return Exception(e.description if getattr(e, 'description', None) is not None else str(e))
|
||||
|
||||
def _error_to_stream_response_data(self, e: Exception) -> dict:
|
||||
"""
|
||||
Error to stream response.
|
||||
:param e: exception
|
||||
:return:
|
||||
"""
|
||||
error_responses = {
|
||||
ValueError: {'code': 'invalid_param', 'status': 400},
|
||||
ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400},
|
||||
QuotaExceededError: {
|
||||
'code': 'provider_quota_exceeded',
|
||||
'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
|
||||
"Please go to Settings -> Model Provider to complete your own provider credentials.",
|
||||
'status': 400
|
||||
},
|
||||
ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400},
|
||||
InvokeError: {'code': 'completion_request_error', 'status': 400}
|
||||
}
|
||||
|
||||
# Determine the response based on the type of exception
|
||||
data = None
|
||||
for k, v in error_responses.items():
|
||||
if isinstance(e, k):
|
||||
data = v
|
||||
|
||||
if data:
|
||||
data.setdefault('message', getattr(e, 'description', str(e)))
|
||||
else:
|
||||
logging.error(e)
|
||||
data = {
|
||||
'code': 'internal_server_error',
|
||||
'message': 'Internal Server Error, please contact support.',
|
||||
'status': 500
|
||||
}
|
||||
|
||||
return {
|
||||
'event': 'error',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'workflow_run_id': self._task_state.workflow_run_id,
|
||||
**data
|
||||
}
|
||||
|
||||
def _yield_response(self, response: dict) -> str:
|
||||
"""
|
||||
Yield response.
|
||||
:param response: response
|
||||
:return:
|
||||
"""
|
||||
return "data: " + json.dumps(response) + "\n\n"
|
||||
|
||||
def _init_output_moderation(self) -> Optional[OutputModeration]:
|
||||
"""
|
||||
Init output moderation.
|
||||
:return:
|
||||
"""
|
||||
app_config = self._application_generate_entity.app_config
|
||||
sensitive_word_avoidance = app_config.sensitive_word_avoidance
|
||||
|
||||
if sensitive_word_avoidance:
|
||||
return OutputModeration(
|
||||
tenant_id=app_config.tenant_id,
|
||||
app_id=app_config.app_id,
|
||||
rule=ModerationRule(
|
||||
type=sensitive_word_avoidance.type,
|
||||
config=sensitive_word_avoidance.config
|
||||
),
|
||||
queue_manager=self._queue_manager
|
||||
)
|
||||
Reference in New Issue
Block a user