mirror of
https://github.com/langgenius/dify.git
synced 2026-05-01 16:08:04 +08:00
Merge branch 'fix/chore-fix' into dev/plugin-deploy
This commit is contained in:
@ -119,7 +119,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
usage_dict = {}
|
||||
usage_dict: dict[str, Optional[LLMUsage]] = {}
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="",
|
||||
|
||||
@ -21,7 +21,7 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||
from extensions.ext_database import db
|
||||
@ -346,7 +346,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except (ValueError, InvokeError) as e:
|
||||
except ValueError as e:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
@ -77,7 +77,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
||||
@ -68,24 +68,17 @@ from models.account import Account
|
||||
from models.enums import CreatedByRole
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage):
|
||||
class AdvancedChatAppGenerateTaskPipeline:
|
||||
"""
|
||||
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
_task_state: WorkflowTaskState
|
||||
_application_generate_entity: AdvancedChatAppGenerateEntity
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
||||
_conversation_name_generate_thread: Optional[Thread] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
@ -97,7 +90,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
stream: bool,
|
||||
dialogue_count: int,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
stream=stream,
|
||||
@ -114,33 +107,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
else:
|
||||
raise NotImplementedError(f"User type not supported: {type(user)}")
|
||||
|
||||
self._workflow_id = workflow.id
|
||||
self._workflow_features_dict = workflow.features_dict
|
||||
|
||||
self._conversation_id = conversation.id
|
||||
self._conversation_mode = conversation.mode
|
||||
|
||||
self._message_id = message.id
|
||||
self._message_created_at = int(message.created_at.timestamp())
|
||||
|
||||
self._workflow_system_variables = {
|
||||
SystemVariableKey.QUERY: message.query,
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
SystemVariableKey.CONVERSATION_ID: conversation.id,
|
||||
SystemVariableKey.USER_ID: user_session_id,
|
||||
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
|
||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||
}
|
||||
self._workflow_cycle_manager = WorkflowCycleManage(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_system_variables={
|
||||
SystemVariableKey.QUERY: message.query,
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
SystemVariableKey.CONVERSATION_ID: conversation.id,
|
||||
SystemVariableKey.USER_ID: user_session_id,
|
||||
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
|
||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||
},
|
||||
)
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._wip_workflow_node_executions = {}
|
||||
self._wip_workflow_agent_logs = {}
|
||||
self._message_cycle_manager = MessageCycleManage(
|
||||
application_generate_entity=application_generate_entity, task_state=self._task_state
|
||||
)
|
||||
|
||||
self._conversation_name_generate_thread = None
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_id = workflow.id
|
||||
self._workflow_features_dict = workflow.features_dict
|
||||
self._conversation_id = conversation.id
|
||||
self._conversation_mode = conversation.mode
|
||||
self._message_id = message.id
|
||||
self._message_created_at = int(message.created_at.timestamp())
|
||||
self._conversation_name_generate_thread: Thread | None = None
|
||||
self._recorded_files: list[Mapping[str, Any]] = []
|
||||
self._workflow_run_id = ""
|
||||
self._workflow_run_id: str = ""
|
||||
|
||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
"""
|
||||
@ -148,13 +143,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
:return:
|
||||
"""
|
||||
# start generate conversation name thread
|
||||
self._conversation_name_generate_thread = self._generate_conversation_name(
|
||||
self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
|
||||
conversation_id=self._conversation_id, query=self._application_generate_entity.query
|
||||
)
|
||||
|
||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||
|
||||
if self._stream:
|
||||
if self._base_task_pipeline._stream:
|
||||
return self._to_stream_response(generator)
|
||||
else:
|
||||
return self._to_blocking_response(generator)
|
||||
@ -273,24 +268,26 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
# init fake graph runtime state
|
||||
graph_runtime_state: Optional[GraphRuntimeState] = None
|
||||
|
||||
for queue_message in self._queue_manager.listen():
|
||||
for queue_message in self._base_task_pipeline._queue_manager.listen():
|
||||
event = queue_message.event
|
||||
|
||||
if isinstance(event, QueuePingEvent):
|
||||
yield self._ping_stream_response()
|
||||
yield self._base_task_pipeline._ping_stream_response()
|
||||
elif isinstance(event, QueueErrorEvent):
|
||||
with Session(db.engine) as session:
|
||||
err = self._handle_error(event=event, session=session, message_id=self._message_id)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
err = self._base_task_pipeline._handle_error(
|
||||
event=event, session=session, message_id=self._message_id
|
||||
)
|
||||
session.commit()
|
||||
yield self._error_to_stream_response(err)
|
||||
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||
break
|
||||
elif isinstance(event, QueueWorkflowStartedEvent):
|
||||
# override graph runtime state
|
||||
graph_runtime_state = event.graph_runtime_state
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# init workflow run
|
||||
workflow_run = self._handle_workflow_run_start(
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
|
||||
session=session,
|
||||
workflow_id=self._workflow_id,
|
||||
user_id=self._user_id,
|
||||
@ -301,7 +298,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if not message:
|
||||
raise ValueError(f"Message not found: {self._message_id}")
|
||||
message.workflow_run_id = workflow_run.id
|
||||
workflow_start_resp = self._workflow_start_to_stream_response(
|
||||
workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
session.commit()
|
||||
@ -314,12 +311,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
workflow_node_execution = self._handle_workflow_node_execution_retried(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
|
||||
session=session, workflow_run=workflow_run, event=event
|
||||
)
|
||||
node_retry_resp = self._workflow_node_retry_to_stream_response(
|
||||
node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@ -333,13 +332,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
workflow_node_execution = self._handle_node_execution_start(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
|
||||
session=session, workflow_run=workflow_run, event=event
|
||||
)
|
||||
|
||||
node_start_resp = self._workflow_node_start_to_stream_response(
|
||||
node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@ -352,12 +353,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
elif isinstance(event, QueueNodeSucceededEvent):
|
||||
# Record files if it's an answer node or end node
|
||||
if event.node_type in [NodeType.ANSWER, NodeType.END]:
|
||||
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
|
||||
self._recorded_files.extend(
|
||||
self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {})
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
|
||||
session=session, event=event
|
||||
)
|
||||
|
||||
node_finish_resp = self._workflow_node_finish_to_stream_response(
|
||||
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@ -368,32 +373,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if node_finish_resp:
|
||||
yield node_finish_resp
|
||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
||||
with Session(db.engine) as session:
|
||||
workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
|
||||
session=session, event=event
|
||||
)
|
||||
|
||||
response_finish = self._workflow_node_finish_to_stream_response(
|
||||
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response_finish:
|
||||
yield response_finish
|
||||
|
||||
if node_finish_resp:
|
||||
yield node_finish_resp
|
||||
elif isinstance(event, QueueParallelBranchRunStartedEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
parallel_start_resp = (
|
||||
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
)
|
||||
|
||||
yield parallel_start_resp
|
||||
@ -401,13 +409,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
parallel_finish_resp = (
|
||||
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
)
|
||||
|
||||
yield parallel_finish_resp
|
||||
@ -415,9 +427,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
iter_start_resp = self._workflow_iteration_start_to_stream_response(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@ -429,9 +443,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
iter_next_resp = self._workflow_iteration_next_to_stream_response(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@ -443,9 +459,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@ -460,8 +478,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@ -472,21 +490,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
session.commit()
|
||||
|
||||
yield workflow_finish_resp
|
||||
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
self._base_task_pipeline._queue_manager.publish(
|
||||
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_partial_success(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@ -497,21 +517,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
conversation_id=None,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
session.commit()
|
||||
|
||||
yield workflow_finish_resp
|
||||
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
self._base_task_pipeline._queue_manager.publish(
|
||||
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowFailedEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@ -523,20 +545,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
trace_manager=trace_manager,
|
||||
exceptions_count=event.exceptions_count,
|
||||
)
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
|
||||
err = self._handle_error(event=err_event, session=session, message_id=self._message_id)
|
||||
err = self._base_task_pipeline._handle_error(
|
||||
event=err_event, session=session, message_id=self._message_id
|
||||
)
|
||||
session.commit()
|
||||
|
||||
yield workflow_finish_resp
|
||||
yield self._error_to_stream_response(err)
|
||||
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||
break
|
||||
elif isinstance(event, QueueStopEvent):
|
||||
if self._workflow_run_id and graph_runtime_state:
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@ -547,7 +571,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
conversation_id=self._conversation_id,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@ -561,18 +585,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
yield self._message_end_to_stream_response()
|
||||
break
|
||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||
self._handle_retriever_resources(event)
|
||||
self._message_cycle_manager._handle_retriever_resources(event)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message = self._get_message(session=session)
|
||||
message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
)
|
||||
session.commit()
|
||||
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||
self._handle_annotation_reply(event)
|
||||
self._message_cycle_manager._handle_annotation_reply(event)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message = self._get_message(session=session)
|
||||
message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
@ -593,29 +617,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
tts_publisher.publish(queue_message)
|
||||
|
||||
self._task_state.answer += delta_text
|
||||
yield self._message_to_stream_response(
|
||||
yield self._message_cycle_manager._message_to_stream_response(
|
||||
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
# published by moderation
|
||||
yield self._message_replace_to_stream_response(answer=event.text)
|
||||
yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text)
|
||||
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
|
||||
output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
|
||||
self._task_state.answer
|
||||
)
|
||||
if output_moderation_answer:
|
||||
self._task_state.answer = output_moderation_answer
|
||||
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
|
||||
yield self._message_cycle_manager._message_replace_to_stream_response(
|
||||
answer=output_moderation_answer
|
||||
)
|
||||
|
||||
# Save message
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
|
||||
session.commit()
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
elif isinstance(event, QueueAgentLogEvent):
|
||||
yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event)
|
||||
yield self._workflow_cycle_manager._handle_agent_log(
|
||||
task_id=self._application_generate_entity.task_id, event=event
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
@ -629,7 +659,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
|
||||
message = self._get_message(session=session)
|
||||
message.answer = self._task_state.answer
|
||||
message.provider_response_latency = time.perf_counter() - self._start_at
|
||||
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
|
||||
message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
)
|
||||
@ -693,20 +723,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
:param text: text
|
||||
:return: True if output moderation should direct output, otherwise False
|
||||
"""
|
||||
if self._output_moderation_handler:
|
||||
if self._output_moderation_handler.should_direct_output():
|
||||
if self._base_task_pipeline._output_moderation_handler:
|
||||
if self._base_task_pipeline._output_moderation_handler.should_direct_output():
|
||||
# stop subscribe new token when output moderation should direct output
|
||||
self._task_state.answer = self._output_moderation_handler.get_final_output()
|
||||
self._queue_manager.publish(
|
||||
self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
|
||||
self._base_task_pipeline._queue_manager.publish(
|
||||
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
|
||||
self._queue_manager.publish(
|
||||
self._base_task_pipeline._queue_manager.publish(
|
||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
return True
|
||||
else:
|
||||
self._output_moderation_handler.append_new_token(text)
|
||||
self._base_task_pipeline._output_moderation_handler.append_new_token(text)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
@ -251,7 +251,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except (ValueError, InvokeError) as e:
|
||||
except ValueError as e:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -15,7 +15,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
@classmethod
|
||||
def convert(
|
||||
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
|
||||
) -> dict[str, Any] | Generator[str | dict[str, Any], Any, None]:
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
|
||||
if isinstance(response, AppBlockingResponse):
|
||||
return cls.convert_blocking_full_response(response)
|
||||
|
||||
@ -18,7 +18,7 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
@ -237,7 +237,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except (ValueError, InvokeError) as e:
|
||||
except ValueError as e:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
@ -17,7 +17,7 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
@ -37,7 +37,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
) -> Generator[str, None, None]: ...
|
||||
) -> Generator[str | Mapping[str, Any], None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
@ -57,7 +57,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = False,
|
||||
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
@ -66,7 +66,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Union[Mapping[str, Any], Generator[str, None, None]]:
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -214,7 +214,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except (ValueError, InvokeError) as e:
|
||||
except ValueError as e:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
@ -231,7 +231,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
user: Union[Account, EndUser],
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
) -> Union[Mapping, Generator[str, None, None]]:
|
||||
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
@ -20,7 +20,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
@ -235,6 +235,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
|
||||
node_id=node_id, inputs=args["inputs"]
|
||||
),
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
)
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
@ -286,7 +287,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except (ValueError, InvokeError) as e:
|
||||
except ValueError as e:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -59,7 +59,6 @@ from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowAppLog,
|
||||
WorkflowAppLogCreatedFrom,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
@ -67,16 +66,11 @@ from models.workflow import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
|
||||
class WorkflowAppGenerateTaskPipeline:
|
||||
"""
|
||||
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
_task_state: WorkflowTaskState
|
||||
_application_generate_entity: WorkflowAppGenerateEntity
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
@ -85,7 +79,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
stream=stream,
|
||||
@ -102,17 +96,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
else:
|
||||
raise ValueError(f"Invalid user type: {type(user)}")
|
||||
|
||||
self._workflow_cycle_manager = WorkflowCycleManage(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_system_variables={
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
SystemVariableKey.USER_ID: user_session_id,
|
||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||
},
|
||||
)
|
||||
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_id = workflow.id
|
||||
self._workflow_features_dict = workflow.features_dict
|
||||
|
||||
self._workflow_system_variables = {
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
SystemVariableKey.USER_ID: user_session_id,
|
||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||
}
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._workflow_run_id = ""
|
||||
|
||||
@ -122,7 +119,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
:return:
|
||||
"""
|
||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||
if self._stream:
|
||||
if self._base_task_pipeline._stream:
|
||||
return self._to_stream_response(generator)
|
||||
else:
|
||||
return self._to_blocking_response(generator)
|
||||
@ -239,29 +236,29 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
"""
|
||||
graph_runtime_state = None
|
||||
|
||||
for queue_message in self._queue_manager.listen():
|
||||
for queue_message in self._base_task_pipeline._queue_manager.listen():
|
||||
event = queue_message.event
|
||||
|
||||
if isinstance(event, QueuePingEvent):
|
||||
yield self._ping_stream_response()
|
||||
yield self._base_task_pipeline._ping_stream_response()
|
||||
elif isinstance(event, QueueErrorEvent):
|
||||
err = self._handle_error(event=event)
|
||||
yield self._error_to_stream_response(err)
|
||||
err = self._base_task_pipeline._handle_error(event=event)
|
||||
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||
break
|
||||
elif isinstance(event, QueueWorkflowStartedEvent):
|
||||
# override graph runtime state
|
||||
graph_runtime_state = event.graph_runtime_state
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# init workflow run
|
||||
workflow_run = self._handle_workflow_run_start(
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
|
||||
session=session,
|
||||
workflow_id=self._workflow_id,
|
||||
user_id=self._user_id,
|
||||
created_by_role=self._created_by_role,
|
||||
)
|
||||
self._workflow_run_id = workflow_run.id
|
||||
start_resp = self._workflow_start_to_stream_response(
|
||||
start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
session.commit()
|
||||
@ -273,12 +270,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
workflow_node_execution = self._handle_workflow_node_execution_retried(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
|
||||
session=session, workflow_run=workflow_run, event=event
|
||||
)
|
||||
response = self._workflow_node_retry_to_stream_response(
|
||||
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@ -292,12 +291,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
workflow_node_execution = self._handle_node_execution_start(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
|
||||
session=session, workflow_run=workflow_run, event=event
|
||||
)
|
||||
node_start_response = self._workflow_node_start_to_stream_response(
|
||||
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@ -308,9 +309,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if node_start_response:
|
||||
yield node_start_response
|
||||
elif isinstance(event, QueueNodeSucceededEvent):
|
||||
with Session(db.engine) as session:
|
||||
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
|
||||
node_success_response = self._workflow_node_finish_to_stream_response(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
|
||||
session=session, event=event
|
||||
)
|
||||
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@ -321,12 +324,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if node_success_response:
|
||||
yield node_success_response
|
||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
||||
with Session(db.engine) as session:
|
||||
workflow_node_execution = self._handle_workflow_node_execution_failed(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
|
||||
session=session,
|
||||
event=event,
|
||||
)
|
||||
node_failed_response = self._workflow_node_finish_to_stream_response(
|
||||
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@ -341,13 +344,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
parallel_start_resp = (
|
||||
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
)
|
||||
|
||||
yield parallel_start_resp
|
||||
@ -356,13 +363,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
parallel_finish_resp = (
|
||||
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
)
|
||||
|
||||
yield parallel_finish_resp
|
||||
@ -371,9 +382,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
iter_start_resp = self._workflow_iteration_start_to_stream_response(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@ -386,9 +399,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
iter_next_resp = self._workflow_iteration_next_to_stream_response(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@ -401,9 +416,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@ -418,8 +435,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@ -433,7 +450,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@ -447,8 +464,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_partial_success(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@ -463,7 +480,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
session.commit()
|
||||
@ -475,8 +492,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@ -494,7 +511,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
session.commit()
|
||||
@ -514,7 +531,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
delta_text, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
elif isinstance(event, QueueAgentLogEvent):
|
||||
yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event)
|
||||
yield self._workflow_cycle_manager._handle_agent_log(
|
||||
task_id=self._application_generate_entity.task_id, event=event
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
@ -195,7 +195,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
|
||||
# app config
|
||||
app_config: WorkflowUIBasedAppConfig
|
||||
workflow_run_id: Optional[str] = None
|
||||
workflow_run_id: str
|
||||
|
||||
class SingleIterationRunEntity(BaseModel):
|
||||
"""
|
||||
|
||||
@ -329,6 +329,7 @@ class QueueAgentLogEvent(AppQueueEvent):
|
||||
error: str | None
|
||||
status: str
|
||||
data: Mapping[str, Any]
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
|
||||
|
||||
class QueueNodeRetryEvent(QueueNodeStartedEvent):
|
||||
|
||||
@ -716,6 +716,7 @@ class AgentLogStreamResponse(StreamResponse):
|
||||
error: str | None
|
||||
status: str
|
||||
data: Mapping[str, Any]
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.AGENT_LOG
|
||||
data: Data
|
||||
|
||||
@ -15,7 +15,6 @@ from core.app.entities.queue_entities import (
|
||||
from core.app.entities.task_entities import (
|
||||
ErrorStreamResponse,
|
||||
PingStreamResponse,
|
||||
TaskState,
|
||||
)
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
@ -30,22 +29,12 @@ class BasedGenerateTaskPipeline:
|
||||
BasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
_task_state: TaskState
|
||||
_application_generate_entity: AppGenerateEntity
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: AppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
stream: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize GenerateTaskPipeline.
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param user: user
|
||||
:param stream: stream
|
||||
"""
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._queue_manager = queue_manager
|
||||
self._start_at = time.perf_counter()
|
||||
|
||||
@ -31,10 +31,19 @@ from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
||||
class MessageCycleManage:
|
||||
_application_generate_entity: Union[
|
||||
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
|
||||
]
|
||||
_task_state: Union[EasyUITaskState, WorkflowTaskState]
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: Union[
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AdvancedChatAppGenerateEntity,
|
||||
],
|
||||
task_state: Union[EasyUITaskState, WorkflowTaskState],
|
||||
) -> None:
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._task_state = task_state
|
||||
|
||||
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
|
||||
"""
|
||||
|
||||
@ -36,7 +36,6 @@ from core.app.entities.task_entities import (
|
||||
ParallelBranchStartStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
@ -60,13 +59,20 @@ from models.workflow import (
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError
|
||||
from .exc import WorkflowRunNotFoundError
|
||||
|
||||
|
||||
class WorkflowCycleManage:
|
||||
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
|
||||
_task_state: WorkflowTaskState
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||
workflow_system_variables: dict[SystemVariableKey, Any],
|
||||
) -> None:
|
||||
self._workflow_run: WorkflowRun | None = None
|
||||
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_system_variables = workflow_system_variables
|
||||
|
||||
def _handle_workflow_run_start(
|
||||
self,
|
||||
@ -104,7 +110,8 @@ class WorkflowCycleManage:
|
||||
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
|
||||
|
||||
# init workflow run
|
||||
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4()))
|
||||
# TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
|
||||
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())
|
||||
|
||||
workflow_run = WorkflowRun()
|
||||
workflow_run.id = workflow_run_id
|
||||
@ -241,7 +248,7 @@ class WorkflowCycleManage:
|
||||
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_run.exceptions_count = exceptions_count
|
||||
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
stmt = select(WorkflowNodeExecution.node_execution_id).where(
|
||||
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
|
||||
WorkflowNodeExecution.app_id == workflow_run.app_id,
|
||||
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
||||
@ -249,15 +256,18 @@ class WorkflowCycleManage:
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
|
||||
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
|
||||
)
|
||||
|
||||
running_workflow_node_executions = session.scalars(stmt).all()
|
||||
ids = session.scalars(stmt).all()
|
||||
# Use self._get_workflow_node_execution here to make sure the cache is updated
|
||||
running_workflow_node_executions = [
|
||||
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
|
||||
]
|
||||
|
||||
for workflow_node_execution in running_workflow_node_executions:
|
||||
now = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
finish_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_node_execution.finished_at = finish_at
|
||||
workflow_node_execution.elapsed_time = (finish_at - workflow_node_execution.created_at).total_seconds()
|
||||
workflow_node_execution.finished_at = now
|
||||
workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
@ -299,6 +309,8 @@ class WorkflowCycleManage:
|
||||
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
session.add(workflow_node_execution)
|
||||
|
||||
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_success(
|
||||
@ -325,6 +337,7 @@ class WorkflowCycleManage:
|
||||
workflow_node_execution.finished_at = finished_at
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
|
||||
workflow_node_execution = session.merge(workflow_node_execution)
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_failed(
|
||||
@ -364,6 +377,7 @@ class WorkflowCycleManage:
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
workflow_node_execution.execution_metadata = execution_metadata
|
||||
|
||||
workflow_node_execution = session.merge(workflow_node_execution)
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_retried(
|
||||
@ -415,6 +429,8 @@ class WorkflowCycleManage:
|
||||
workflow_node_execution.index = event.node_run_index
|
||||
|
||||
session.add(workflow_node_execution)
|
||||
|
||||
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
||||
return workflow_node_execution
|
||||
|
||||
#################################################
|
||||
@ -811,25 +827,23 @@ class WorkflowCycleManage:
|
||||
return None
|
||||
|
||||
def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
|
||||
"""
|
||||
Refetch workflow run
|
||||
:param workflow_run_id: workflow run id
|
||||
:return:
|
||||
"""
|
||||
if self._workflow_run and self._workflow_run.id == workflow_run_id:
|
||||
cached_workflow_run = self._workflow_run
|
||||
cached_workflow_run = session.merge(cached_workflow_run)
|
||||
return cached_workflow_run
|
||||
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
|
||||
workflow_run = session.scalar(stmt)
|
||||
if not workflow_run:
|
||||
raise WorkflowRunNotFoundError(workflow_run_id)
|
||||
self._workflow_run = workflow_run
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
|
||||
stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.node_execution_id == node_execution_id)
|
||||
workflow_node_execution = session.scalar(stmt)
|
||||
if not workflow_node_execution:
|
||||
raise WorkflowNodeExecutionNotFoundError(node_execution_id)
|
||||
|
||||
return workflow_node_execution
|
||||
if node_execution_id not in self._workflow_node_executions:
|
||||
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
|
||||
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
|
||||
return cached_workflow_node_execution
|
||||
|
||||
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
|
||||
"""
|
||||
@ -848,5 +862,6 @@ class WorkflowCycleManage:
|
||||
error=event.error,
|
||||
status=event.status,
|
||||
data=event.data,
|
||||
metadata=event.metadata,
|
||||
),
|
||||
)
|
||||
|
||||
@ -178,7 +178,7 @@ class ModelInstance:
|
||||
|
||||
def get_llm_num_tokens(
|
||||
self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
|
||||
) -> list[int]:
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for llm
|
||||
|
||||
@ -191,7 +191,7 @@ class ModelInstance:
|
||||
|
||||
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
|
||||
return cast(
|
||||
list[int],
|
||||
int,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model,
|
||||
|
||||
@ -1,9 +1,47 @@
|
||||
import tiktoken
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
_tokenizer: Any = None
|
||||
_lock = Lock()
|
||||
|
||||
|
||||
class GPT2Tokenizer:
|
||||
@staticmethod
|
||||
def _get_num_tokens_by_gpt2(text: str) -> int:
|
||||
"""
|
||||
use gpt2 tokenizer to get num tokens
|
||||
"""
|
||||
_tokenizer = GPT2Tokenizer.get_encoder()
|
||||
tokens = _tokenizer.encode(text)
|
||||
return len(tokens)
|
||||
|
||||
@staticmethod
|
||||
def get_num_tokens(text: str) -> int:
|
||||
encoding = tiktoken.encoding_for_model("gpt2")
|
||||
tiktoken_vec = encoding.encode(text)
|
||||
return len(tiktoken_vec)
|
||||
# Because this process needs more cpu resource, we turn this back before we find a better way to handle it.
|
||||
#
|
||||
# future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
|
||||
# result = future.result()
|
||||
# return cast(int, result)
|
||||
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
|
||||
|
||||
@staticmethod
|
||||
def get_encoder() -> Any:
|
||||
global _tokenizer, _lock
|
||||
with _lock:
|
||||
if _tokenizer is None:
|
||||
# Try to use tiktoken to get the tokenizer because it is faster
|
||||
#
|
||||
try:
|
||||
import tiktoken
|
||||
|
||||
_tokenizer = tiktoken.get_encoding("gpt2")
|
||||
except Exception:
|
||||
from os.path import abspath, dirname, join
|
||||
|
||||
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
|
||||
|
||||
base_path = abspath(__file__)
|
||||
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
|
||||
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
|
||||
|
||||
return _tokenizer
|
||||
|
||||
@ -119,7 +119,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
stream: bool,
|
||||
inputs: Mapping,
|
||||
files: list[dict],
|
||||
):
|
||||
) -> Generator[Mapping | str, None, None] | Mapping:
|
||||
"""
|
||||
invoke workflow app
|
||||
"""
|
||||
@ -146,7 +146,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
stream: bool,
|
||||
inputs: Mapping,
|
||||
files: list[dict],
|
||||
):
|
||||
) -> Generator[Mapping | str, None, None] | Mapping:
|
||||
"""
|
||||
invoke completion app
|
||||
"""
|
||||
|
||||
@ -268,7 +268,7 @@ Here is the extra instruction you need to follow:
|
||||
return summary.message.content
|
||||
|
||||
lines = content.split("\n")
|
||||
new_lines = []
|
||||
new_lines: list[str] = []
|
||||
# split long line into multiple lines
|
||||
for i in range(len(lines)):
|
||||
line = lines[i]
|
||||
@ -286,16 +286,16 @@ Here is the extra instruction you need to follow:
|
||||
|
||||
# merge lines into messages with max tokens
|
||||
messages: list[str] = []
|
||||
for i in new_lines:
|
||||
for i in new_lines: # type: ignore
|
||||
if len(messages) == 0:
|
||||
messages.append(i)
|
||||
messages.append(i) # type: ignore
|
||||
else:
|
||||
if len(messages[-1]) + len(i) < max_tokens * 0.5:
|
||||
messages[-1] += i
|
||||
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
|
||||
messages.append(i)
|
||||
if len(messages[-1]) + len(i) < max_tokens * 0.5: # type: ignore
|
||||
messages[-1] += i # type: ignore
|
||||
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7: # type: ignore
|
||||
messages.append(i) # type: ignore
|
||||
else:
|
||||
messages[-1] += i
|
||||
messages[-1] += i # type: ignore
|
||||
|
||||
summaries = []
|
||||
for i in range(len(messages)):
|
||||
|
||||
@ -103,7 +103,7 @@ class BasePluginManager:
|
||||
Make a stream request to the plugin daemon inner API and yield the response as a model.
|
||||
"""
|
||||
for line in self._stream_request(method, path, params, headers, data, files):
|
||||
yield type(**json.loads(line))
|
||||
yield type(**json.loads(line)) # type: ignore
|
||||
|
||||
def _request_with_model(
|
||||
self,
|
||||
|
||||
@ -113,6 +113,8 @@ class BaiduVector(BaseVector):
|
||||
return False
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
quoted_ids = [f"'{id}'" for id in ids]
|
||||
self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")
|
||||
|
||||
|
||||
@ -83,6 +83,8 @@ class ChromaVector(BaseVector):
|
||||
self._client.delete_collection(self._collection_name)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
collection.delete(ids=ids)
|
||||
|
||||
|
||||
@ -0,0 +1,104 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
|
||||
ElasticSearchConfig,
|
||||
ElasticSearchVector,
|
||||
ElasticSearchVectorFactory,
|
||||
)
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ElasticSearchJaVector(ElasticSearchVector):
|
||||
def create_collection(
|
||||
self,
|
||||
embeddings: list[list[float]],
|
||||
metadatas: Optional[list[dict[Any, Any]]] = None,
|
||||
index_params: Optional[dict] = None,
|
||||
):
|
||||
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
logger.info(f"Collection {self._collection_name} already exists.")
|
||||
return
|
||||
|
||||
if not self._client.indices.exists(index=self._collection_name):
|
||||
dim = len(embeddings[0])
|
||||
settings = {
|
||||
"analysis": {
|
||||
"analyzer": {
|
||||
"ja_analyzer": {
|
||||
"type": "custom",
|
||||
"char_filter": [
|
||||
"icu_normalizer",
|
||||
"kuromoji_iteration_mark",
|
||||
],
|
||||
"tokenizer": "kuromoji_tokenizer",
|
||||
"filter": [
|
||||
"kuromoji_baseform",
|
||||
"kuromoji_part_of_speech",
|
||||
"ja_stop",
|
||||
"kuromoji_number",
|
||||
"kuromoji_stemmer",
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
mappings = {
|
||||
"properties": {
|
||||
Field.CONTENT_KEY.value: {
|
||||
"type": "text",
|
||||
"analyzer": "ja_analyzer",
|
||||
"search_analyzer": "ja_analyzer",
|
||||
},
|
||||
Field.VECTOR.value: { # Make sure the dimension is correct here
|
||||
"type": "dense_vector",
|
||||
"dims": dim,
|
||||
"index": True,
|
||||
"similarity": "cosine",
|
||||
},
|
||||
Field.METADATA_KEY.value: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
self._client.indices.create(index=self._collection_name, settings=settings, mappings=mappings)
|
||||
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
|
||||
class ElasticSearchJaVectorFactory(ElasticSearchVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchJaVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
return ElasticSearchJaVector(
|
||||
index_name=collection_name,
|
||||
config=ElasticSearchConfig(
|
||||
host=config.get("ELASTICSEARCH_HOST", "localhost"),
|
||||
port=config.get("ELASTICSEARCH_PORT", 9200),
|
||||
username=config.get("ELASTICSEARCH_USERNAME", ""),
|
||||
password=config.get("ELASTICSEARCH_PASSWORD", ""),
|
||||
),
|
||||
attributes=[],
|
||||
)
|
||||
@ -98,6 +98,8 @@ class ElasticSearchVector(BaseVector):
|
||||
return bool(self._client.exists(index=self._collection_name, id=id))
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
for id in ids:
|
||||
self._client.delete(index=self._collection_name, id=id)
|
||||
|
||||
|
||||
@ -6,6 +6,8 @@ class Field(Enum):
|
||||
METADATA_KEY = "metadata"
|
||||
GROUP_KEY = "group_id"
|
||||
VECTOR = "vector"
|
||||
# Sparse Vector aims to support full text search
|
||||
SPARSE_VECTOR = "sparse_vector"
|
||||
TEXT_KEY = "text"
|
||||
PRIMARY_KEY = "id"
|
||||
DOC_ID = "metadata.doc_id"
|
||||
|
||||
@ -2,6 +2,7 @@ import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from packaging import version
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pymilvus import MilvusClient, MilvusException # type: ignore
|
||||
from pymilvus.milvus_client import IndexParams # type: ignore
|
||||
@ -20,16 +21,25 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MilvusConfig(BaseModel):
|
||||
uri: str
|
||||
token: Optional[str] = None
|
||||
user: str
|
||||
password: str
|
||||
batch_size: int = 100
|
||||
database: str = "default"
|
||||
"""
|
||||
Configuration class for Milvus connection.
|
||||
"""
|
||||
|
||||
uri: str # Milvus server URI
|
||||
token: Optional[str] = None # Optional token for authentication
|
||||
user: str # Username for authentication
|
||||
password: str # Password for authentication
|
||||
batch_size: int = 100 # Batch size for operations
|
||||
database: str = "default" # Database name
|
||||
enable_hybrid_search: bool = False # Flag to enable hybrid search
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
"""
|
||||
Validate the configuration values.
|
||||
Raises ValueError if required fields are missing.
|
||||
"""
|
||||
if not values.get("uri"):
|
||||
raise ValueError("config MILVUS_URI is required")
|
||||
if not values.get("user"):
|
||||
@ -39,6 +49,9 @@ class MilvusConfig(BaseModel):
|
||||
return values
|
||||
|
||||
def to_milvus_params(self):
|
||||
"""
|
||||
Convert the configuration to a dictionary of Milvus connection parameters.
|
||||
"""
|
||||
return {
|
||||
"uri": self.uri,
|
||||
"token": self.token,
|
||||
@ -49,26 +62,57 @@ class MilvusConfig(BaseModel):
|
||||
|
||||
|
||||
class MilvusVector(BaseVector):
|
||||
"""
|
||||
Milvus vector storage implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, collection_name: str, config: MilvusConfig):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
self._client = self._init_client(config)
|
||||
self._consistency_level = "Session"
|
||||
self._fields: list[str] = []
|
||||
self._consistency_level = "Session" # Consistency level for Milvus operations
|
||||
self._fields: list[str] = [] # List of fields in the collection
|
||||
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported
|
||||
|
||||
def _check_hybrid_search_support(self) -> bool:
|
||||
"""
|
||||
Check if the current Milvus version supports hybrid search.
|
||||
Returns True if the version is >= 2.5.0, otherwise False.
|
||||
"""
|
||||
if not self._client_config.enable_hybrid_search:
|
||||
return False
|
||||
|
||||
try:
|
||||
milvus_version = self._client.get_server_version()
|
||||
return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.")
|
||||
return False
|
||||
|
||||
def get_type(self) -> str:
|
||||
"""
|
||||
Get the type of vector storage (Milvus).
|
||||
"""
|
||||
return VectorType.MILVUS
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
"""
|
||||
Create a collection and add texts with embeddings.
|
||||
"""
|
||||
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
|
||||
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
|
||||
self.create_collection(embeddings, metadatas, index_params)
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
"""
|
||||
Add texts and their embeddings to the collection.
|
||||
"""
|
||||
insert_dict_list = []
|
||||
for i in range(len(documents)):
|
||||
insert_dict = {
|
||||
# Do not need to insert the sparse_vector field separately, as the text_bm25_emb
|
||||
# function will automatically convert the native text into a sparse vector for us.
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i],
|
||||
Field.METADATA_KEY.value: documents[i].metadata,
|
||||
@ -76,12 +120,11 @@ class MilvusVector(BaseVector):
|
||||
insert_dict_list.append(insert_dict)
|
||||
# Total insert count
|
||||
total_count = len(insert_dict_list)
|
||||
|
||||
pks: list[str] = []
|
||||
|
||||
for i in range(0, total_count, 1000):
|
||||
batch_insert_list = insert_dict_list[i : i + 1000]
|
||||
# Insert into the collection.
|
||||
batch_insert_list = insert_dict_list[i : i + 1000]
|
||||
try:
|
||||
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
|
||||
pks.extend(ids)
|
||||
@ -91,6 +134,9 @@ class MilvusVector(BaseVector):
|
||||
return pks
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
"""
|
||||
Get document IDs by metadata field key and value.
|
||||
"""
|
||||
result = self._client.query(
|
||||
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
|
||||
)
|
||||
@ -100,12 +146,18 @@ class MilvusVector(BaseVector):
|
||||
return None
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
"""
|
||||
Delete documents by metadata field key and value.
|
||||
"""
|
||||
if self._client.has_collection(self._collection_name):
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
if ids:
|
||||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
"""
|
||||
Delete documents by their IDs.
|
||||
"""
|
||||
if self._client.has_collection(self._collection_name):
|
||||
result = self._client.query(
|
||||
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
|
||||
@ -115,10 +167,16 @@ class MilvusVector(BaseVector):
|
||||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||
|
||||
def delete(self) -> None:
|
||||
"""
|
||||
Delete the entire collection.
|
||||
"""
|
||||
if self._client.has_collection(self._collection_name):
|
||||
self._client.drop_collection(self._collection_name, None)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
"""
|
||||
Check if a text with the given ID exists in the collection.
|
||||
"""
|
||||
if not self._client.has_collection(self._collection_name):
|
||||
return False
|
||||
|
||||
@ -128,32 +186,80 @@ class MilvusVector(BaseVector):
|
||||
|
||||
return len(result) > 0
|
||||
|
||||
def field_exists(self, field: str) -> bool:
|
||||
"""
|
||||
Check if a field exists in the collection.
|
||||
"""
|
||||
return field in self._fields
|
||||
|
||||
def _process_search_results(
|
||||
self, results: list[Any], output_fields: list[str], score_threshold: float = 0.0
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Common method to process search results
|
||||
|
||||
:param results: Search results
|
||||
:param output_fields: Fields to be output
|
||||
:param score_threshold: Score threshold for filtering
|
||||
:return: List of documents
|
||||
"""
|
||||
docs = []
|
||||
for result in results[0]:
|
||||
metadata = result["entity"].get(output_fields[1], {})
|
||||
metadata["score"] = result["distance"]
|
||||
|
||||
if result["distance"] > score_threshold:
|
||||
doc = Document(page_content=result["entity"].get(output_fields[0], ""), metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
# Set search parameters.
|
||||
"""
|
||||
Search for documents by vector similarity.
|
||||
"""
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
data=[query_vector],
|
||||
anns_field=Field.VECTOR.value,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
)
|
||||
# Organize results.
|
||||
docs = []
|
||||
for result in results[0]:
|
||||
metadata = result["entity"].get(Field.METADATA_KEY.value)
|
||||
metadata["score"] = result["distance"]
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if result["distance"] > score_threshold:
|
||||
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
return self._process_search_results(
|
||||
results,
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
score_threshold=float(kwargs.get("score_threshold") or 0.0),
|
||||
)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
# milvus/zilliz doesn't support bm25 search
|
||||
return []
|
||||
"""
|
||||
Search for documents by full-text search (if hybrid search is enabled).
|
||||
"""
|
||||
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
|
||||
logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
|
||||
return []
|
||||
|
||||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
data=[query],
|
||||
anns_field=Field.SPARSE_VECTOR.value,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
)
|
||||
|
||||
return self._process_search_results(
|
||||
results,
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
score_threshold=float(kwargs.get("score_threshold") or 0.0),
|
||||
)
|
||||
|
||||
def create_collection(
|
||||
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
|
||||
):
|
||||
"""
|
||||
Create a new collection in Milvus with the specified schema and index parameters.
|
||||
"""
|
||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
@ -161,7 +267,7 @@ class MilvusVector(BaseVector):
|
||||
return
|
||||
# Grab the existing collection if it exists
|
||||
if not self._client.has_collection(self._collection_name):
|
||||
from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore
|
||||
from pymilvus import CollectionSchema, DataType, FieldSchema, Function, FunctionType # type: ignore
|
||||
from pymilvus.orm.types import infer_dtype_bydata # type: ignore
|
||||
|
||||
# Determine embedding dim
|
||||
@ -170,16 +276,36 @@ class MilvusVector(BaseVector):
|
||||
if metadatas:
|
||||
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
|
||||
|
||||
# Create the text field
|
||||
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535))
|
||||
# Create the text field, enable_analyzer will be set True to support milvus automatically
|
||||
# transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
|
||||
fields.append(
|
||||
FieldSchema(
|
||||
Field.CONTENT_KEY.value,
|
||||
DataType.VARCHAR,
|
||||
max_length=65_535,
|
||||
enable_analyzer=self._hybrid_search_enabled,
|
||||
)
|
||||
)
|
||||
# Create the primary key field
|
||||
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
|
||||
# Create the vector field, supports binary or float vectors
|
||||
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
|
||||
# Create Sparse Vector Index for the collection
|
||||
if self._hybrid_search_enabled:
|
||||
fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR))
|
||||
|
||||
# Create the schema for the collection
|
||||
schema = CollectionSchema(fields)
|
||||
|
||||
# Create custom function to support text to sparse vector by BM25
|
||||
if self._hybrid_search_enabled:
|
||||
bm25_function = Function(
|
||||
name="text_bm25_emb",
|
||||
input_field_names=[Field.CONTENT_KEY.value],
|
||||
output_field_names=[Field.SPARSE_VECTOR.value],
|
||||
function_type=FunctionType.BM25,
|
||||
)
|
||||
schema.add_function(bm25_function)
|
||||
|
||||
for x in schema.fields:
|
||||
self._fields.append(x.name)
|
||||
# Since primary field is auto-id, no need to track it
|
||||
@ -189,10 +315,15 @@ class MilvusVector(BaseVector):
|
||||
index_params_obj = IndexParams()
|
||||
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
|
||||
|
||||
# Create Sparse Vector Index for the collection
|
||||
if self._hybrid_search_enabled:
|
||||
index_params_obj.add_index(
|
||||
field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25"
|
||||
)
|
||||
|
||||
# Create the collection
|
||||
collection_name = self._collection_name
|
||||
self._client.create_collection(
|
||||
collection_name=collection_name,
|
||||
collection_name=self._collection_name,
|
||||
schema=schema,
|
||||
index_params=index_params_obj,
|
||||
consistency_level=self._consistency_level,
|
||||
@ -200,12 +331,22 @@ class MilvusVector(BaseVector):
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def _init_client(self, config) -> MilvusClient:
|
||||
"""
|
||||
Initialize and return a Milvus client.
|
||||
"""
|
||||
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
|
||||
return client
|
||||
|
||||
|
||||
class MilvusVectorFactory(AbstractVectorFactory):
|
||||
"""
|
||||
Factory class for creating MilvusVector instances.
|
||||
"""
|
||||
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
|
||||
"""
|
||||
Initialize a MilvusVector instance for the given dataset.
|
||||
"""
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
@ -222,5 +363,6 @@ class MilvusVectorFactory(AbstractVectorFactory):
|
||||
user=dify_config.MILVUS_USER or "",
|
||||
password=dify_config.MILVUS_PASSWORD or "",
|
||||
database=dify_config.MILVUS_DATABASE or "",
|
||||
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
|
||||
),
|
||||
)
|
||||
|
||||
@ -100,6 +100,8 @@ class MyScaleVector(BaseVector):
|
||||
return results.row_count > 0
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
self._client.command(
|
||||
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
|
||||
)
|
||||
|
||||
@ -134,6 +134,8 @@ class OceanBaseVector(BaseVector):
|
||||
return bool(cur.rowcount != 0)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
self._client.delete(table_name=self._collection_name, ids=ids)
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
|
||||
|
||||
@ -167,6 +167,8 @@ class OracleVector(BaseVector):
|
||||
return docs
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
|
||||
|
||||
|
||||
@ -129,6 +129,11 @@ class PGVector(BaseVector):
|
||||
return docs
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
# Avoiding crashes caused by performing delete operations on empty lists in certain scenarios
|
||||
# Scenario 1: extract a document fails, resulting in a table not being created.
|
||||
# Then clicking the retry button triggers a delete operation on an empty list.
|
||||
if not ids:
|
||||
return
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
||||
|
||||
|
||||
@ -140,6 +140,8 @@ class TencentVector(BaseVector):
|
||||
return False
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
self._db.collection(self._collection_name).delete(document_ids=ids)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
|
||||
@ -409,27 +409,27 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
|
||||
)
|
||||
if not tidb_auth_binding:
|
||||
idle_tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
|
||||
.limit(1)
|
||||
.one_or_none()
|
||||
)
|
||||
if idle_tidb_auth_binding:
|
||||
idle_tidb_auth_binding.active = True
|
||||
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
|
||||
db.session.commit()
|
||||
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
|
||||
else:
|
||||
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
|
||||
tidb_auth_binding = (
|
||||
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
|
||||
tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if tidb_auth_binding:
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
else:
|
||||
idle_tidb_auth_binding = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
|
||||
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
|
||||
.limit(1)
|
||||
.one_or_none()
|
||||
)
|
||||
if tidb_auth_binding:
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
if idle_tidb_auth_binding:
|
||||
idle_tidb_auth_binding.active = True
|
||||
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
|
||||
db.session.commit()
|
||||
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
|
||||
else:
|
||||
new_cluster = TidbService.create_tidb_serverless_cluster(
|
||||
dify_config.TIDB_PROJECT_ID or "",
|
||||
@ -451,7 +451,6 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
||||
db.session.add(new_tidb_auth_binding)
|
||||
db.session.commit()
|
||||
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
|
||||
|
||||
else:
|
||||
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
|
||||
|
||||
|
||||
@ -90,6 +90,12 @@ class Vector:
|
||||
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
|
||||
return ElasticSearchVectorFactory
|
||||
case VectorType.ELASTICSEARCH_JA:
|
||||
from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import (
|
||||
ElasticSearchJaVectorFactory,
|
||||
)
|
||||
|
||||
return ElasticSearchJaVectorFactory
|
||||
case VectorType.TIDB_VECTOR:
|
||||
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@ class VectorType(StrEnum):
|
||||
TENCENT = "tencent"
|
||||
ORACLE = "oracle"
|
||||
ELASTICSEARCH = "elasticsearch"
|
||||
ELASTICSEARCH_JA = "elasticsearch-ja"
|
||||
LINDORM = "lindorm"
|
||||
COUCHBASE = "couchbase"
|
||||
BAIDU = "baidu"
|
||||
|
||||
@ -23,7 +23,6 @@ class PdfExtractor(BaseExtractor):
|
||||
self._file_cache_key = file_cache_key
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
plaintext_file_key = ""
|
||||
plaintext_file_exists = False
|
||||
if self._file_cache_key:
|
||||
try:
|
||||
@ -39,8 +38,8 @@ class PdfExtractor(BaseExtractor):
|
||||
text = "\n\n".join(text_list)
|
||||
|
||||
# save plaintext file for caching
|
||||
if not plaintext_file_exists and plaintext_file_key:
|
||||
storage.save(plaintext_file_key, text.encode("utf-8"))
|
||||
if not plaintext_file_exists and self._file_cache_key:
|
||||
storage.save(self._file_cache_key, text.encode("utf-8"))
|
||||
|
||||
return documents
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
@ -80,6 +81,10 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
child_nodes = self._split_child_nodes(
|
||||
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
|
||||
)
|
||||
if kwargs.get("preview"):
|
||||
if len(child_nodes) > dify_config.CHILD_CHUNKS_PREVIEW_NUMBER:
|
||||
child_nodes = child_nodes[: dify_config.CHILD_CHUNKS_PREVIEW_NUMBER]
|
||||
|
||||
document.children = child_nodes
|
||||
doc_id = str(uuid.uuid4())
|
||||
hash = helper.generate_text_hash(document.page_content)
|
||||
|
||||
@ -54,7 +54,12 @@ class ASRTool(BuiltinTool):
|
||||
items.append((provider, model.model))
|
||||
return items
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> list[ToolParameter]:
|
||||
parameters = []
|
||||
|
||||
options = []
|
||||
|
||||
@ -62,7 +62,12 @@ class TTSTool(BuiltinTool):
|
||||
items.append((provider, model.model, voices))
|
||||
return items
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> list[ToolParameter]:
|
||||
parameters = []
|
||||
|
||||
options = []
|
||||
|
||||
@ -212,8 +212,23 @@ class ApiTool(Tool):
|
||||
else:
|
||||
body = body
|
||||
|
||||
if method in {"get", "head", "post", "put", "delete", "patch"}:
|
||||
response: httpx.Response = getattr(ssrf_proxy, method)(
|
||||
if method in {
|
||||
"get",
|
||||
"head",
|
||||
"post",
|
||||
"put",
|
||||
"delete",
|
||||
"patch",
|
||||
"options",
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"PATCH",
|
||||
"DELETE",
|
||||
"HEAD",
|
||||
"OPTIONS",
|
||||
}:
|
||||
response: httpx.Response = getattr(ssrf_proxy, method.lower())(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
|
||||
@ -147,7 +147,7 @@ class ToolInvokeMessage(BaseModel):
|
||||
|
||||
@field_validator("variable_name", mode="before")
|
||||
@classmethod
|
||||
def transform_variable_name(cls, value) -> str:
|
||||
def transform_variable_name(cls, value: str) -> str:
|
||||
"""
|
||||
The variable name must be a string.
|
||||
"""
|
||||
@ -167,6 +167,7 @@ class ToolInvokeMessage(BaseModel):
|
||||
error: Optional[str] = Field(default=None, description="The error message")
|
||||
status: LogStatus = Field(..., description="The status of the log")
|
||||
data: Mapping[str, Any] = Field(..., description="Detailed log data")
|
||||
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log")
|
||||
|
||||
class MessageType(Enum):
|
||||
TEXT = "text"
|
||||
|
||||
@ -203,6 +203,7 @@ class AgentLogEvent(BaseAgentEvent):
|
||||
error: str | None = Field(..., description="error")
|
||||
status: str = Field(..., description="status")
|
||||
data: Mapping[str, Any] = Field(..., description="data")
|
||||
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="metadata")
|
||||
|
||||
|
||||
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent
|
||||
|
||||
@ -89,7 +89,11 @@ class AgentNode(ToolNode):
|
||||
|
||||
try:
|
||||
# convert tool messages
|
||||
yield from self._transform_message(message_stream, {}, parameters_for_log)
|
||||
yield from self._transform_message(
|
||||
message_stream,
|
||||
{"provider": (cast(AgentNodeData, self.node_data)).agent_strategy_provider_name},
|
||||
parameters_for_log,
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
@ -170,7 +174,12 @@ class AgentNode(ToolNode):
|
||||
extra.get("descrption", "") or tool_runtime.entity.description.llm
|
||||
)
|
||||
|
||||
tool_value.append(tool_runtime.entity.model_dump(mode="json"))
|
||||
tool_value.append(
|
||||
{
|
||||
**tool_runtime.entity.model_dump(mode="json"),
|
||||
"runtime_parameters": tool_runtime.runtime.runtime_parameters,
|
||||
}
|
||||
)
|
||||
value = tool_value
|
||||
if parameter.type == "model-selector":
|
||||
value = cast(dict[str, Any], value)
|
||||
|
||||
@ -2,14 +2,18 @@ import csv
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import operator
|
||||
import os
|
||||
import tempfile
|
||||
from typing import cast
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
import docx
|
||||
import pandas as pd
|
||||
import pypdfium2 # type: ignore
|
||||
import yaml # type: ignore
|
||||
from docx.table import Table
|
||||
from docx.text.paragraph import Paragraph
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod, file_manager
|
||||
@ -78,6 +82,23 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
|
||||
process_data=process_data,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: DocumentExtractorNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {node_id + ".files": node_data.variable_selector}
|
||||
|
||||
|
||||
def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
|
||||
"""Extract text from a file based on its MIME type."""
|
||||
@ -189,35 +210,56 @@ def _extract_text_from_doc(file_content: bytes) -> str:
|
||||
doc_file = io.BytesIO(file_content)
|
||||
doc = docx.Document(doc_file)
|
||||
text = []
|
||||
# Process paragraphs
|
||||
for paragraph in doc.paragraphs:
|
||||
if paragraph.text.strip():
|
||||
text.append(paragraph.text)
|
||||
|
||||
# Process tables
|
||||
for table in doc.tables:
|
||||
# Table header
|
||||
try:
|
||||
# table maybe cause errors so ignore it.
|
||||
if len(table.rows) > 0 and table.rows[0].cells is not None:
|
||||
# Keep track of paragraph and table positions
|
||||
content_items: list[tuple[int, str, Table | Paragraph]] = []
|
||||
|
||||
# Process paragraphs and tables
|
||||
for i, paragraph in enumerate(doc.paragraphs):
|
||||
if paragraph.text.strip():
|
||||
content_items.append((i, "paragraph", paragraph))
|
||||
|
||||
for i, table in enumerate(doc.tables):
|
||||
content_items.append((i, "table", table))
|
||||
|
||||
# Sort content items based on their original position
|
||||
content_items.sort(key=operator.itemgetter(0))
|
||||
|
||||
# Process sorted content
|
||||
for _, item_type, item in content_items:
|
||||
if item_type == "paragraph":
|
||||
if isinstance(item, Table):
|
||||
continue
|
||||
text.append(item.text)
|
||||
elif item_type == "table":
|
||||
# Process tables
|
||||
if not isinstance(item, Table):
|
||||
continue
|
||||
try:
|
||||
# Check if any cell in the table has text
|
||||
has_content = False
|
||||
for row in table.rows:
|
||||
for row in item.rows:
|
||||
if any(cell.text.strip() for cell in row.cells):
|
||||
has_content = True
|
||||
break
|
||||
|
||||
if has_content:
|
||||
markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n"
|
||||
markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n"
|
||||
for row in table.rows[1:]:
|
||||
markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n"
|
||||
cell_texts = [cell.text.replace("\n", "<br>") for cell in item.rows[0].cells]
|
||||
markdown_table = f"| {' | '.join(cell_texts)} |\n"
|
||||
markdown_table += f"| {' | '.join(['---'] * len(item.rows[0].cells))} |\n"
|
||||
|
||||
for row in item.rows[1:]:
|
||||
# Replace newlines with <br> in each cell
|
||||
row_cells = [cell.text.replace("\n", "<br>") for cell in row.cells]
|
||||
markdown_table += "| " + " | ".join(row_cells) + " |\n"
|
||||
|
||||
text.append(markdown_table)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
|
||||
continue
|
||||
|
||||
return "\n".join(text)
|
||||
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
|
||||
|
||||
|
||||
@ -68,7 +68,22 @@ class HttpRequestNodeData(BaseNodeData):
|
||||
Code Node Data.
|
||||
"""
|
||||
|
||||
method: Literal["get", "post", "put", "patch", "delete", "head"]
|
||||
method: Literal[
|
||||
"get",
|
||||
"post",
|
||||
"put",
|
||||
"patch",
|
||||
"delete",
|
||||
"head",
|
||||
"options",
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"PATCH",
|
||||
"DELETE",
|
||||
"HEAD",
|
||||
"OPTIONS",
|
||||
]
|
||||
url: str
|
||||
authorization: HttpRequestNodeAuthorization
|
||||
headers: str
|
||||
|
||||
@ -37,7 +37,22 @@ BODY_TYPE_TO_CONTENT_TYPE = {
|
||||
|
||||
|
||||
class Executor:
|
||||
method: Literal["get", "head", "post", "put", "delete", "patch"]
|
||||
method: Literal[
|
||||
"get",
|
||||
"head",
|
||||
"post",
|
||||
"put",
|
||||
"delete",
|
||||
"patch",
|
||||
"options",
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"PATCH",
|
||||
"DELETE",
|
||||
"HEAD",
|
||||
"OPTIONS",
|
||||
]
|
||||
url: str
|
||||
params: list[tuple[str, str]] | None
|
||||
content: str | bytes | None
|
||||
@ -67,12 +82,6 @@ class Executor:
|
||||
node_data.authorization.config.api_key
|
||||
).text
|
||||
|
||||
# check if node_data.url is a valid URL
|
||||
if not node_data.url:
|
||||
raise InvalidURLError("url is required")
|
||||
if not node_data.url.startswith(("http://", "https://")):
|
||||
raise InvalidURLError("url should start with http:// or https://")
|
||||
|
||||
self.url: str = node_data.url
|
||||
self.method = node_data.method
|
||||
self.auth = node_data.authorization
|
||||
@ -99,6 +108,12 @@ class Executor:
|
||||
def _init_url(self):
|
||||
self.url = self.variable_pool.convert_template(self.node_data.url).text
|
||||
|
||||
# check if url is a valid URL
|
||||
if not self.url:
|
||||
raise InvalidURLError("url is required")
|
||||
if not self.url.startswith(("http://", "https://")):
|
||||
raise InvalidURLError("url should start with http:// or https://")
|
||||
|
||||
def _init_params(self):
|
||||
"""
|
||||
Almost same as _init_headers(), difference:
|
||||
@ -158,7 +173,10 @@ class Executor:
|
||||
if len(data) != 1:
|
||||
raise RequestBodyError("json body type should have exactly one item")
|
||||
json_string = self.variable_pool.convert_template(data[0].value).text
|
||||
json_object = json.loads(json_string, strict=False)
|
||||
try:
|
||||
json_object = json.loads(json_string, strict=False)
|
||||
except json.JSONDecodeError as e:
|
||||
raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e
|
||||
self.json = json_object
|
||||
# self.json = self._parse_object_contains_variables(json_object)
|
||||
case "binary":
|
||||
@ -246,7 +264,22 @@ class Executor:
|
||||
"""
|
||||
do http request depending on api bundle
|
||||
"""
|
||||
if self.method not in {"get", "head", "post", "put", "delete", "patch"}:
|
||||
if self.method not in {
|
||||
"get",
|
||||
"head",
|
||||
"post",
|
||||
"put",
|
||||
"delete",
|
||||
"patch",
|
||||
"options",
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"PATCH",
|
||||
"DELETE",
|
||||
"HEAD",
|
||||
"OPTIONS",
|
||||
}:
|
||||
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
|
||||
|
||||
request_args = {
|
||||
@ -263,7 +296,7 @@ class Executor:
|
||||
}
|
||||
# request_args = {k: v for k, v in request_args.items() if v is not None}
|
||||
try:
|
||||
response = getattr(ssrf_proxy, self.method)(**request_args)
|
||||
response = getattr(ssrf_proxy, self.method.lower())(**request_args)
|
||||
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
|
||||
raise HttpRequestNodeError(str(e))
|
||||
# FIXME: fix type ignore, this maybe httpx type issue
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -197,6 +197,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
json: list[dict] = []
|
||||
|
||||
agent_logs: list[AgentLogEvent] = []
|
||||
agent_execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = {}
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
@ -264,6 +265,11 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
if self.node_type == NodeType.AGENT:
|
||||
msg_metadata = message.message.json_object.pop("execution_metadata", {})
|
||||
agent_execution_metadata = {
|
||||
key: value for key, value in msg_metadata.items() if key in NodeRunMetadataKey
|
||||
}
|
||||
json.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
@ -299,6 +305,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
status=message.message.status.value,
|
||||
data=message.message.data,
|
||||
label=message.message.label,
|
||||
metadata=message.message.metadata,
|
||||
)
|
||||
|
||||
# check if the agent log is already in the list
|
||||
@ -309,6 +316,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
log.status = agent_log.status
|
||||
log.error = agent_log.error
|
||||
log.label = agent_log.label
|
||||
log.metadata = agent_log.metadata
|
||||
break
|
||||
else:
|
||||
agent_logs.append(agent_log)
|
||||
@ -319,7 +327,11 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"text": text, "files": files, "json": json, **variables},
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info, NodeRunMetadataKey.AGENT_LOG: agent_logs},
|
||||
metadata={
|
||||
**agent_execution_metadata,
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info,
|
||||
NodeRunMetadataKey.AGENT_LOG: agent_logs,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
)
|
||||
)
|
||||
|
||||
@ -340,6 +340,10 @@ class WorkflowEntry:
|
||||
):
|
||||
raise ValueError(f"Variable key {node_variable} not found in user inputs.")
|
||||
|
||||
# environment variable already exist in variable pool, not from user inputs
|
||||
if variable_pool.get(variable_selector):
|
||||
continue
|
||||
|
||||
# fetch variable node id from variable selector
|
||||
variable_node_id = variable_selector[0]
|
||||
variable_key_list = variable_selector[1:]
|
||||
|
||||
Reference in New Issue
Block a user