mirror of
https://github.com/langgenius/dify.git
synced 2026-04-19 18:27:27 +08:00
feat: implement multi-threading in evaluation_service to collect the target run results (list[node_run_result_mapping])
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Union
|
||||
|
||||
from openpyxl import Workbook, load_workbook
|
||||
@ -18,6 +19,8 @@ from core.evaluation.entities.evaluation_entity import (
|
||||
EvaluationRunRequest,
|
||||
)
|
||||
from core.evaluation.evaluation_manager import EvaluationManager
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.node_events.base import NodeRunResult
|
||||
from models.evaluation import (
|
||||
EvaluationConfiguration,
|
||||
EvaluationRun,
|
||||
@ -446,6 +449,201 @@ class EvaluationService:
|
||||
continue
|
||||
return EvaluationCategory.LLM
|
||||
|
||||
@classmethod
def execute_targets(
    cls,
    tenant_id: str,
    target_type: str,
    target_id: str,
    input_list: list[EvaluationItemInput],
    max_workers: int = 5,
) -> list[dict[str, NodeRunResult]]:
    """Run the evaluation target against every test-data item concurrently.

    :param tenant_id: Workspace / tenant ID.
    :param target_type: ``"app"`` or ``"snippet"``.
    :param target_id: ID of the App or CustomizedSnippet.
    :param input_list: All test-data items parsed from the dataset.
    :param max_workers: Upper bound on parallel worker threads.
    :return: ``{node_id: NodeRunResult}`` mappings in input order — the
        *i*-th element corresponds to ``input_list[i]``. A failed target
        execution yields an empty dict at its position.
    """
    from concurrent.futures import ThreadPoolExecutor

    from flask import Flask, current_app

    # Capture the real app object (not the proxy) so worker threads can
    # push their own application contexts.
    app_obj: Flask = current_app._get_current_object()  # type: ignore

    def run_one(entry: EvaluationItemInput) -> dict[str, NodeRunResult]:
        # Each worker thread gets its own app context and DB session.
        with app_obj.app_context():
            from models.engine import db

            with Session(db.engine, expire_on_commit=False) as local_session:
                try:
                    # Execute the target (workflow app / snippet), then pull
                    # the workflow_run_id out of the blocking response.
                    blocking_response = cls._run_single_target(
                        session=local_session,
                        target_type=target_type,
                        target_id=target_id,
                        item=entry,
                    )
                    run_id = cls._extract_workflow_run_id(blocking_response)
                    if not run_id:
                        logger.warning(
                            "No workflow_run_id for item %d (target=%s)",
                            entry.index,
                            target_id,
                        )
                        return {}

                    # Collect the per-node execution records for this run.
                    return cls._query_node_run_results(
                        session=local_session,
                        tenant_id=tenant_id,
                        app_id=target_id,
                        workflow_run_id=run_id,
                    )
                except Exception:
                    logger.exception(
                        "Target execution failed for item %d (target=%s)",
                        entry.index,
                        target_id,
                    )
                    return {}

    collected: list[dict[str, NodeRunResult]] = []
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [pool.submit(run_one, entry) for entry in input_list]
        # Iterate in submission order so results stay aligned with input_list.
        for fut in pending:
            try:
                collected.append(fut.result())
            except Exception:
                logger.exception("Unexpected error collecting target execution result")
                collected.append({})

    return collected
|
||||
|
||||
@classmethod
def _run_single_target(
    cls,
    session: Session,
    target_type: str,
    target_id: str,
    item: EvaluationItemInput,
) -> Mapping[str, object]:
    """Run one evaluation target against a single test-data item.

    Dispatch is keyed on ``target_type``:

    * ``"snippet"`` → :meth:`SnippetGenerateService.run_published`
    * anything else (``"app"``) → :meth:`WorkflowAppGenerator().generate`
      in blocking mode

    :returns: The blocking response mapping from the workflow engine.
    :raises ValueError: If the target is not found or not published.
    """
    from core.app.apps.workflow.app_generator import WorkflowAppGenerator
    from core.app.entities.app_invoke_entities import InvokeFrom
    from core.evaluation.runners import get_service_account_for_app, get_service_account_for_snippet

    if target_type == "snippet":
        from services.snippet_generate_service import SnippetGenerateService

        snippet_model = session.query(CustomizedSnippet).filter_by(id=target_id).first()
        if not snippet_model:
            raise ValueError(f"Snippet {target_id} not found")

        snippet_account = get_service_account_for_snippet(session, target_id)

        return SnippetGenerateService.run_published(
            snippet=snippet_model,
            user=snippet_account,
            args={"inputs": item.inputs},
            invoke_from=InvokeFrom.SERVICE_API,
        )

    # target_type == "app"
    app_model = session.query(App).filter_by(id=target_id).first()
    if not app_model:
        raise ValueError(f"App {target_id} not found")

    app_account = get_service_account_for_app(session, target_id)

    # The app target runs via its published workflow only.
    published_workflow = WorkflowService().get_published_workflow(app_model=app_model)
    if not published_workflow:
        raise ValueError(f"No published workflow for app {target_id}")

    generated: Mapping[str, object] = WorkflowAppGenerator().generate(
        app_model=app_model,
        workflow=published_workflow,
        user=app_account,
        args={"inputs": item.inputs},
        invoke_from=InvokeFrom.SERVICE_API,
        streaming=False,
        call_depth=0,
    )
    return generated
|
||||
|
||||
@staticmethod
|
||||
def _extract_workflow_run_id(response: Mapping[str, object]) -> str | None:
|
||||
"""Extract ``workflow_run_id`` from a blocking workflow response.
|
||||
"""
|
||||
wf_run_id = response.get("workflow_run_id")
|
||||
if wf_run_id:
|
||||
return str(wf_run_id)
|
||||
data = response.get("data")
|
||||
if isinstance(data, Mapping) and data.get("id"):
|
||||
return str(data["id"])
|
||||
return None
|
||||
|
||||
@staticmethod
def _query_node_run_results(
    session: Session,
    tenant_id: str,
    app_id: str,
    workflow_run_id: str,
) -> dict[str, NodeRunResult]:
    """Query all node execution records for a workflow run.

    :param session: Active SQLAlchemy session bound to the main DB engine.
    :param tenant_id: Workspace / tenant ID the run belongs to.
    :param app_id: App ID the run belongs to.
    :param workflow_run_id: ID of the workflow run whose nodes to collect.
    :return: Mapping of ``node_id`` to its :class:`NodeRunResult`. Records are
        processed in ``created_at`` ascending order, so if a node executed
        more than once the latest record overwrites earlier ones.
    """
    from sqlalchemy import asc, select

    from core.workflow.enums import WorkflowNodeExecutionStatus
    from models.workflow import WorkflowNodeExecutionModel

    stmt = (
        WorkflowNodeExecutionModel.preload_offload_data(select(WorkflowNodeExecutionModel))
        .where(
            WorkflowNodeExecutionModel.tenant_id == tenant_id,
            WorkflowNodeExecutionModel.app_id == app_id,
            WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
        )
        .order_by(asc(WorkflowNodeExecutionModel.created_at))
    )

    node_models: list[WorkflowNodeExecutionModel] = list(session.execute(stmt).scalars().all())

    result: dict[str, NodeRunResult] = {}
    for node in node_models:
        # Convert string-keyed metadata to WorkflowNodeExecutionMetadataKey-keyed.
        # NOTE(review): guard with `or {}` — the sibling accessors below all use
        # `*_dict or {}`, which suggests these properties can be None when the
        # column is empty; without the guard `.items()` would raise. Confirm
        # against the model's contract.
        raw_metadata = node.execution_metadata_dict or {}
        typed_metadata: dict[WorkflowNodeExecutionMetadataKey, object] = {}
        for key, val in raw_metadata.items():
            try:
                typed_metadata[WorkflowNodeExecutionMetadataKey(key)] = val
            except ValueError:
                pass  # skip unknown metadata keys

        result[node.node_id] = NodeRunResult(
            status=WorkflowNodeExecutionStatus(node.status),
            inputs=node.inputs_dict or {},
            process_data=node.process_data_dict or {},
            outputs=node.outputs_dict or {},
            metadata=typed_metadata,
            error=node.error or "",
        )
    return result
|
||||
|
||||
# ---- Dataset Parsing ----
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user