chore(workflow): max thread submit count

This commit is contained in:
takatost
2024-09-02 20:20:32 +08:00
parent 5ca9df65de
commit 955884b87e
8 changed files with 86 additions and 17 deletions

View File

@ -1,12 +1,12 @@
import logging
import queue
import time
import uuid
from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Optional
from flask import Flask, current_app
from uritemplate.variable import VariableValue
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom
@ -15,7 +15,7 @@ from core.workflow.entities.node_entities import (
NodeType,
UserFrom,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
BaseIterationEvent,
@ -47,7 +47,28 @@ from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
logger = logging.getLogger(__name__)
class GraphEngineThreadPool(ThreadPoolExecutor):
    """Thread pool executor that caps the total number of submitted tasks.

    Every call to :meth:`submit` increments ``submit_count``. Once the count
    exceeds ``max_submit_count`` the submission is rejected with a
    ``ValueError``. The counter is never decremented by this class, so the
    limit applies to the total number of submissions made over the pool's
    lifetime rather than to the number of tasks currently in flight.
    """

    def __init__(self, max_workers=None, thread_name_prefix='',
                 initializer=None, initargs=(), max_submit_count=100) -> None:
        """
        Init the pool.

        :param max_workers: forwarded to ``ThreadPoolExecutor``
        :param thread_name_prefix: forwarded to ``ThreadPoolExecutor``
        :param initializer: forwarded to ``ThreadPoolExecutor``
        :param initargs: forwarded to ``ThreadPoolExecutor``
        :param max_submit_count: maximum number of ``submit`` calls accepted
        """
        super().__init__(max_workers, thread_name_prefix, initializer, initargs)
        self.max_submit_count = max_submit_count
        self.submit_count = 0

    def submit(self, fn, *args, **kwargs):
        """
        Count the submission, enforce the limit, then schedule ``fn``.

        :param fn: callable to execute
        :return: the future returned by ``ThreadPoolExecutor.submit``
        :raises ValueError: if this submission exceeds ``max_submit_count``
        """
        # The count is bumped before the check, so a rejected submission is
        # still recorded and every later submission is rejected as well.
        self.submit_count += 1
        self.check_is_full()

        return super().submit(fn, *args, **kwargs)

    def check_is_full(self) -> None:
        """
        Raise if more tasks were submitted than the pool allows.

        :raises ValueError: if ``submit_count`` is greater than ``max_submit_count``
        """
        if self.submit_count > self.max_submit_count:
            raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.")
class GraphEngine:
workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {}
def __init__(
self,
tenant_id: str,
@ -62,10 +83,26 @@ class GraphEngine:
graph_config: Mapping[str, Any],
variable_pool: VariablePool,
max_execution_steps: int,
max_execution_time: int
max_execution_time: int,
thread_pool_id: Optional[str] = None
) -> None:
thread_pool_max_submit_count = 100
thread_pool_max_workers = 10
## init thread pool
self.thread_pool = ThreadPoolExecutor(max_workers=10)
if thread_pool_id:
if not thread_pool_id in GraphEngine.workflow_thread_pool_mapping:
raise ValueError(f"Max submit count {thread_pool_max_submit_count} of workflow thread pool reached.")
self.thread_pool_id = thread_pool_id
self.thread_pool = GraphEngine.workflow_thread_pool_mapping[thread_pool_id]
self.is_main_thread_pool = False
else:
self.thread_pool = GraphEngineThreadPool(max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count)
self.thread_pool_id = str(uuid.uuid4())
self.is_main_thread_pool = True
GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool
self.graph = graph
self.init_params = GraphInitParams(
tenant_id=tenant_id,
@ -144,6 +181,9 @@ class GraphEngine:
logger.exception("Unknown Error when graph running")
yield GraphRunFailedEvent(error=str(e))
raise e
finally:
if self.is_main_thread_pool and self.thread_pool_id in GraphEngine.workflow_thread_pool_mapping:
del GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id]
def _run(
self,
@ -196,7 +236,8 @@ class GraphEngine:
graph_init_params=self.init_params,
graph=self.graph,
graph_runtime_state=self.graph_runtime_state,
previous_node_id=previous_node_id
previous_node_id=previous_node_id,
thread_pool_id=self.thread_pool_id
)
try:
@ -357,10 +398,10 @@ class GraphEngine:
node_id = edge_mappings[0].target_node_id
node_config = self.graph.node_id_config_mapping.get(node_id)
if not node_config:
raise GraphRunFailedError(f'Node {node_id} related parallel not found.')
raise GraphRunFailedError(f'Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches.')
node_title = node_config.get('data', {}).get('title')
raise GraphRunFailedError(f'Node {node_title} related parallel not found.')
raise GraphRunFailedError(f'Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches.')
parallel = self.graph.parallel_mapping.get(parallel_id)
if not parallel:

View File

@ -21,7 +21,8 @@ class BaseNode(ABC):
graph_init_params: GraphInitParams,
graph: Graph,
graph_runtime_state: GraphRuntimeState,
previous_node_id: Optional[str] = None) -> None:
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None) -> None:
self.id = id
self.tenant_id = graph_init_params.tenant_id
self.app_id = graph_init_params.app_id
@ -35,6 +36,7 @@ class BaseNode(ABC):
self.graph = graph
self.graph_runtime_state = graph_runtime_state
self.previous_node_id = previous_node_id
self.thread_pool_id = thread_pool_id
node_id = config.get("id")
if not node_id:

View File

@ -66,6 +66,7 @@ class ToolNode(BaseNode):
user_id=self.user_id,
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
thread_pool_id=self.thread_pool_id,
)
except Exception as e:
return NodeRunResult(

View File

@ -44,7 +44,8 @@ class WorkflowEntry:
user_from: UserFrom,
invoke_from: InvokeFrom,
call_depth: int,
variable_pool: VariablePool
variable_pool: VariablePool,
thread_pool_id: Optional[str] = None
) -> None:
"""
Init workflow entry
@ -59,7 +60,9 @@ class WorkflowEntry:
:param invoke_from: invoke from
:param call_depth: call depth
:param variable_pool: variable pool
:param thread_pool_id: thread pool id
"""
# check call depth
workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
if call_depth > workflow_call_max_depth:
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
@ -78,7 +81,8 @@ class WorkflowEntry:
graph_config=graph_config,
variable_pool=variable_pool,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=thread_pool_id
)
def run(