Mirror of https://github.com/langgenius/dify.git (synced 2026-02-22 19:15:47 +08:00)

feat: Human Input Node (#32060)

The frontend and backend implementation for the Human Input node.

Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zhsama <torvalds@linux.do>
@@ -44,7 +44,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60  # 10 minutes
DSL_MAX_SIZE = 10 * 1024 * 1024  # 10MB
-CURRENT_DSL_VERSION = "0.5.0"
+CURRENT_DSL_VERSION = "0.6.0"


class ImportMode(StrEnum):
@@ -1,7 +1,9 @@
from __future__ import annotations

import logging
import threading
import uuid
-from collections.abc import Generator, Mapping
+from collections.abc import Callable, Generator, Mapping
from typing import TYPE_CHECKING, Any, Union

from configs import dify_config
@@ -9,22 +11,63 @@ from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
from core.app.apps.chat.app_generator import ChatAppGenerator
from core.app.apps.completion.app_generator import CompletionAppGenerator
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.features.rate_limiting import RateLimit
from core.app.features.rate_limiting.rate_limit import rate_limit_context
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
from core.db import session_factory
from enums.quota_type import QuotaType, unlimited
from extensions.otel import AppGenerateHandler, trace_span
from models.model import Account, App, AppMode, EndUser
-from models.workflow import Workflow
+from models.workflow import Workflow, WorkflowRun
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.workflow_service import WorkflowService
from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task

logger = logging.getLogger(__name__)

SSE_TASK_START_FALLBACK_MS = 200

if TYPE_CHECKING:
    from controllers.console.app.workflow import LoopNodeRunPayload


class AppGenerateService:
    @staticmethod
    def _build_streaming_task_on_subscribe(start_task: Callable[[], None]) -> Callable[[], None]:
        started = False
        lock = threading.Lock()

        def _try_start() -> bool:
            nonlocal started
            with lock:
                if started:
                    return True
                try:
                    start_task()
                except Exception:
                    logger.exception("Failed to enqueue streaming task")
                    return False
                started = True
                return True

        # XXX(QuantumGhost): dirty hacks to avoid a race between publisher and SSE subscriber.
        # The Celery task may publish the first event before the API side actually subscribes,
        # causing an "at most once" drop with Redis Pub/Sub. We start the task on subscribe,
        # but also use a short fallback timer so the task still runs if the client never consumes.
        timer = threading.Timer(SSE_TASK_START_FALLBACK_MS / 1000.0, _try_start)
        timer.daemon = True
        timer.start()

        def _on_subscribe() -> None:
            if _try_start():
                timer.cancel()

        return _on_subscribe
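A minimal sketch of how this helper behaves, assuming a stand-in enqueue function in place of the Celery task:

    import time

    def enqueue() -> None:
        # Hypothetical stand-in for workflow_based_app_execution_task.delay(payload_json)
        print("task enqueued")

    on_subscribe = AppGenerateService._build_streaming_task_on_subscribe(enqueue)

    # If the SSE client subscribes promptly, the task starts exactly once and the
    # fallback timer is cancelled.
    on_subscribe()

    # If on_subscribe() were never called, the daemon timer would fire after
    # SSE_TASK_START_FALLBACK_MS (200 ms) and start the task anyway; either way
    # _try_start() runs start_task() at most once, guarded by the lock.
    time.sleep(0.3)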

    @classmethod
    @trace_span(AppGenerateHandler)
    def generate(
@@ -88,15 +131,29 @@
        elif app_model.mode == AppMode.ADVANCED_CHAT:
            workflow_id = args.get("workflow_id")
            workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
            with rate_limit_context(rate_limit, request_id):
                payload = AppExecutionParams.new(
                    app_model=app_model,
                    workflow=workflow,
                    user=user,
                    args=args,
                    invoke_from=invoke_from,
                    streaming=streaming,
                    call_depth=0,
                )
                payload_json = payload.model_dump_json()

                def on_subscribe():
                    workflow_based_app_execution_task.delay(payload_json)

                on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
                generator = AdvancedChatAppGenerator()
                return rate_limit.generate(
-                   AdvancedChatAppGenerator.convert_to_event_stream(
-                       AdvancedChatAppGenerator().generate(
-                           app_model=app_model,
-                           workflow=workflow,
-                           user=user,
-                           args=args,
-                           invoke_from=invoke_from,
-                           streaming=streaming,
+                   generator.convert_to_event_stream(
+                       generator.retrieve_events(
+                           AppMode.ADVANCED_CHAT,
+                           payload.workflow_run_id,
+                           on_subscribe=on_subscribe,
                        ),
                    ),
                    request_id=request_id,
@@ -104,6 +161,40 @@
        elif app_model.mode == AppMode.WORKFLOW:
            workflow_id = args.get("workflow_id")
            workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
            if streaming:
                with rate_limit_context(rate_limit, request_id):
                    payload = AppExecutionParams.new(
                        app_model=app_model,
                        workflow=workflow,
                        user=user,
                        args=args,
                        invoke_from=invoke_from,
                        streaming=True,
                        call_depth=0,
                        root_node_id=root_node_id,
                        workflow_run_id=str(uuid.uuid4()),
                    )
                    payload_json = payload.model_dump_json()

                    def on_subscribe():
                        workflow_based_app_execution_task.delay(payload_json)

                    on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
                    return rate_limit.generate(
                        WorkflowAppGenerator.convert_to_event_stream(
                            MessageBasedAppGenerator.retrieve_events(
                                AppMode.WORKFLOW,
                                payload.workflow_run_id,
                                on_subscribe=on_subscribe,
                            ),
                        ),
                        request_id,
                    )

            pause_config = PauseStateLayerConfig(
                session_factory=session_factory.get_session_maker(),
                state_owner_user_id=workflow.created_by,
            )
            return rate_limit.generate(
                WorkflowAppGenerator.convert_to_event_stream(
                    WorkflowAppGenerator().generate(
@@ -112,9 +203,10 @@
                        user=user,
                        args=args,
                        invoke_from=invoke_from,
-                       streaming=streaming,
+                       streaming=False,
                        root_node_id=root_node_id,
                        call_depth=0,
+                       pause_state_config=pause_config,
                    ),
                ),
                request_id,
@@ -248,3 +340,19 @@
            raise ValueError("Workflow not published")

        return workflow

    @classmethod
    def get_response_generator(
        cls,
        app_model: App,
        workflow_run: WorkflowRun,
    ):
        if workflow_run.status.is_ended():
            # TODO(QuantumGhost): handle the ended scenario.
            pass

        generator = AdvancedChatAppGenerator()

        return generator.convert_to_event_stream(
            generator.retrieve_events(AppMode(app_model.mode), workflow_run.id),
        )
@@ -136,7 +136,7 @@ class AudioService:
        message = db.session.query(Message).where(Message.id == message_id).first()
        if message is None:
            return None
-       if message.answer == "" and message.status == MessageStatus.NORMAL:
+       if message.answer == "" and message.status in {MessageStatus.NORMAL, MessageStatus.PAUSED}:
            return None

        else:
@@ -138,6 +138,8 @@ class FeatureModel(BaseModel):
    is_allow_transfer_workspace: bool = True
    trigger_event: Quota = Quota(usage=0, limit=3000, reset_date=0)
    api_rate_limit: Quota = Quota(usage=0, limit=5000, reset_date=0)
    # Controls whether email delivery is allowed for HumanInput nodes.
    human_input_email_delivery_enabled: bool = False
    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())
    knowledge_pipeline: KnowledgePipeline = KnowledgePipeline()
@@ -191,6 +193,11 @@ class FeatureService:
        features.knowledge_pipeline.publish_enabled = True
        cls._fulfill_params_from_workspace_info(features, tenant_id)

        features.human_input_email_delivery_enabled = cls._resolve_human_input_email_delivery_enabled(
            features=features,
            tenant_id=tenant_id,
        )

        return features

    @classmethod
@@ -203,6 +210,17 @@
        knowledge_rate_limit.subscription_plan = limit_info.get("subscription_plan", CloudPlan.SANDBOX)
        return knowledge_rate_limit

    @classmethod
    def _resolve_human_input_email_delivery_enabled(cls, *, features: FeatureModel, tenant_id: str | None) -> bool:
        if dify_config.ENTERPRISE_ENABLED or not dify_config.BILLING_ENABLED:
            return True
        if not tenant_id:
            return False
        return features.billing.enabled and features.billing.subscription.plan in (
            CloudPlan.PROFESSIONAL,
            CloudPlan.TEAM,
        )

    @classmethod
    def get_system_features(cls, is_authenticated: bool = False) -> SystemFeatureModel:
        system_features = SystemFeatureModel()
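As a quick illustration of the gating logic above, a simplified standalone sketch (it folds features.billing.enabled into the plan argument; the plan string values mirror the CloudPlan members and are assumptions):

    def email_delivery_enabled(enterprise: bool, billing_enabled: bool, tenant_id: str | None, plan: str) -> bool:
        if enterprise or not billing_enabled:
            return True   # enterprise / self-hosted deployments: always allowed
        if not tenant_id:
            return False  # cloud request without a tenant context: deny
        return plan in ("professional", "team")  # assumed CloudPlan values

    assert email_delivery_enabled(True, True, None, "sandbox") is True      # enterprise
    assert email_delivery_enabled(False, False, None, "sandbox") is True    # billing disabled
    assert email_delivery_enabled(False, True, "tenant-1", "sandbox") is False
    assert email_delivery_enabled(False, True, "tenant-1", "team") is True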
api/services/human_input_delivery_test_service.py (new file, 249 lines)
@@ -0,0 +1,249 @@
from __future__ import annotations

from dataclasses import dataclass, field
from enum import StrEnum
from typing import Protocol

from sqlalchemy import Engine, select
from sqlalchemy.orm import sessionmaker

from configs import dify_config
from core.workflow.nodes.human_input.entities import (
    DeliveryChannelConfig,
    EmailDeliveryConfig,
    EmailDeliveryMethod,
    ExternalRecipient,
    MemberRecipient,
)
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from extensions.ext_mail import mail
from libs.email_template_renderer import render_email_template
from models import Account, TenantAccountJoin
from services.feature_service import FeatureService


class DeliveryTestStatus(StrEnum):
    OK = "ok"
    FAILED = "failed"


@dataclass(frozen=True)
class DeliveryTestEmailRecipient:
    email: str
    form_token: str


@dataclass(frozen=True)
class DeliveryTestContext:
    tenant_id: str
    app_id: str
    node_id: str
    node_title: str | None
    rendered_content: str
    template_vars: dict[str, str] = field(default_factory=dict)
    recipients: list[DeliveryTestEmailRecipient] = field(default_factory=list)
    variable_pool: VariablePool | None = None


@dataclass(frozen=True)
class DeliveryTestResult:
    status: DeliveryTestStatus
    delivered_to: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)


class DeliveryTestError(Exception):
    pass


class DeliveryTestUnsupportedError(DeliveryTestError):
    pass


def _build_form_link(token: str | None) -> str | None:
    if not token:
        return None
    base_url = dify_config.APP_WEB_URL
    if not base_url:
        return None
    return f"{base_url.rstrip('/')}/form/{token}"


class DeliveryTestHandler(Protocol):
    def supports(self, method: DeliveryChannelConfig) -> bool: ...

    def send_test(
        self,
        *,
        context: DeliveryTestContext,
        method: DeliveryChannelConfig,
    ) -> DeliveryTestResult: ...


class DeliveryTestRegistry:
    def __init__(self, handlers: list[DeliveryTestHandler] | None = None) -> None:
        self._handlers = list(handlers or [])

    def register(self, handler: DeliveryTestHandler) -> None:
        self._handlers.append(handler)

    def dispatch(
        self,
        *,
        context: DeliveryTestContext,
        method: DeliveryChannelConfig,
    ) -> DeliveryTestResult:
        for handler in self._handlers:
            if handler.supports(method):
                return handler.send_test(context=context, method=method)
        raise DeliveryTestUnsupportedError("Delivery method does not support test send.")

    @classmethod
    def default(cls) -> DeliveryTestRegistry:
        return cls([EmailDeliveryTestHandler()])
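Because handlers are matched through the DeliveryTestHandler Protocol, additional channels can be plugged into the registry. A minimal sketch, assuming a hypothetical webhook channel that is not part of this commit:

    @dataclass(frozen=True)
    class WebhookDeliveryMethod:  # hypothetical stand-in config, not in this commit
        url: str


    class WebhookDeliveryTestHandler:
        def supports(self, method: object) -> bool:
            return isinstance(method, WebhookDeliveryMethod)

        def send_test(self, *, context: DeliveryTestContext, method: object) -> DeliveryTestResult:
            assert isinstance(method, WebhookDeliveryMethod)
            # A real handler would POST context.rendered_content to method.url here.
            return DeliveryTestResult(status=DeliveryTestStatus.OK, delivered_to=[method.url])


    registry = DeliveryTestRegistry.default()
    registry.register(WebhookDeliveryTestHandler())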

class HumanInputDeliveryTestService:
    def __init__(self, registry: DeliveryTestRegistry | None = None) -> None:
        self._registry = registry or DeliveryTestRegistry.default()

    def send_test(
        self,
        *,
        context: DeliveryTestContext,
        method: DeliveryChannelConfig,
    ) -> DeliveryTestResult:
        return self._registry.dispatch(context=context, method=method)


class EmailDeliveryTestHandler:
    def __init__(self, session_factory: sessionmaker | Engine | None = None) -> None:
        if session_factory is None:
            session_factory = sessionmaker(bind=db.engine)
        elif isinstance(session_factory, Engine):
            session_factory = sessionmaker(bind=session_factory)
        self._session_factory = session_factory

    def supports(self, method: DeliveryChannelConfig) -> bool:
        return isinstance(method, EmailDeliveryMethod)

    def send_test(
        self,
        *,
        context: DeliveryTestContext,
        method: DeliveryChannelConfig,
    ) -> DeliveryTestResult:
        if not isinstance(method, EmailDeliveryMethod):
            raise DeliveryTestUnsupportedError("Delivery method does not support test send.")
        features = FeatureService.get_features(context.tenant_id)
        if not features.human_input_email_delivery_enabled:
            raise DeliveryTestError("Email delivery is not available for current plan.")
        if not mail.is_inited():
            raise DeliveryTestError("Mail client is not initialized.")

        recipients = self._resolve_recipients(
            tenant_id=context.tenant_id,
            method=method,
        )
        if not recipients:
            raise DeliveryTestError("No recipients configured for delivery method.")

        delivered: list[str] = []
        for recipient_email in recipients:
            substitutions = self._build_substitutions(
                context=context,
                recipient_email=recipient_email,
            )
            subject = render_email_template(method.config.subject, substitutions)
            templated_body = EmailDeliveryConfig.render_body_template(
                body=method.config.body,
                url=substitutions.get("form_link"),
                variable_pool=context.variable_pool,
            )
            body = render_email_template(templated_body, substitutions)

            mail.send(
                to=recipient_email,
                subject=subject,
                html=body,
            )
            delivered.append(recipient_email)

        return DeliveryTestResult(status=DeliveryTestStatus.OK, delivered_to=delivered)

    def _resolve_recipients(self, *, tenant_id: str, method: EmailDeliveryMethod) -> list[str]:
        recipients = method.config.recipients
        emails: list[str] = []
        member_user_ids: list[str] = []
        for recipient in recipients.items:
            if isinstance(recipient, MemberRecipient):
                member_user_ids.append(recipient.user_id)
            elif isinstance(recipient, ExternalRecipient):
                if recipient.email:
                    emails.append(recipient.email)

        if recipients.whole_workspace:
            member_user_ids = []
            member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=None)
            emails.extend(member_emails.values())
        elif member_user_ids:
            member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=member_user_ids)
            for user_id in member_user_ids:
                email = member_emails.get(user_id)
                if email:
                    emails.append(email)

        return list(dict.fromkeys([email for email in emails if email]))
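The final line relies on dict.fromkeys to de-duplicate while preserving first-seen order:

    emails = ["a@example.com", "b@example.com", "a@example.com"]
    assert list(dict.fromkeys(emails)) == ["a@example.com", "b@example.com"]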

    def _query_workspace_member_emails(
        self,
        *,
        tenant_id: str,
        user_ids: list[str] | None,
    ) -> dict[str, str]:
        if user_ids is None:
            unique_ids = None
        else:
            unique_ids = {user_id for user_id in user_ids if user_id}
            if not unique_ids:
                return {}

        stmt = (
            select(Account.id, Account.email)
            .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
            .where(TenantAccountJoin.tenant_id == tenant_id)
        )
        if unique_ids is not None:
            stmt = stmt.where(Account.id.in_(unique_ids))

        with self._session_factory() as session:
            rows = session.execute(stmt).all()
            return dict(rows)

    @staticmethod
    def _build_substitutions(
        *,
        context: DeliveryTestContext,
        recipient_email: str,
    ) -> dict[str, str]:
        raw_values: dict[str, str | None] = {
            "form_id": "",
            "node_title": context.node_title,
            "workflow_run_id": "",
            "form_token": "",
            "form_link": "",
            "form_content": context.rendered_content,
            "recipient_email": recipient_email,
        }
        substitutions = {key: value or "" for key, value in raw_values.items()}
        if context.template_vars:
            substitutions.update({key: value for key, value in context.template_vars.items() if value is not None})
        token = next(
            (recipient.form_token for recipient in context.recipients if recipient.email == recipient_email),
            None,
        )
        if token:
            substitutions["form_token"] = token
            substitutions["form_link"] = _build_form_link(token) or ""
        return substitutions
api/services/human_input_service.py (new file, 250 lines)
@@ -0,0 +1,250 @@
import logging
from collections.abc import Mapping
from datetime import datetime, timedelta
from typing import Any

from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, sessionmaker

from configs import dify_config
from core.repositories.human_input_repository import (
    HumanInputFormRecord,
    HumanInputFormSubmissionRepository,
)
from core.workflow.nodes.human_input.entities import (
    FormDefinition,
    HumanInputSubmissionValidationError,
    validate_human_input_submission,
)
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from libs.datetime_utils import ensure_naive_utc, naive_utc_now
from libs.exception import BaseHTTPException
from models.human_input import RecipientType
from models.model import App, AppMode
from repositories.factory import DifyAPIRepositoryFactory
from tasks.app_generate.workflow_execute_task import WORKFLOW_BASED_APP_EXECUTION_QUEUE, resume_app_execution


class Form:
    def __init__(self, record: HumanInputFormRecord):
        self._record = record

    def get_definition(self) -> FormDefinition:
        return self._record.definition

    @property
    def submitted(self) -> bool:
        return self._record.submitted

    @property
    def id(self) -> str:
        return self._record.form_id

    @property
    def workflow_run_id(self) -> str | None:
        """Workflow run id for runtime forms; None for delivery tests."""
        return self._record.workflow_run_id

    @property
    def tenant_id(self) -> str:
        return self._record.tenant_id

    @property
    def app_id(self) -> str:
        return self._record.app_id

    @property
    def recipient_id(self) -> str | None:
        return self._record.recipient_id

    @property
    def recipient_type(self) -> RecipientType | None:
        return self._record.recipient_type

    @property
    def status(self) -> HumanInputFormStatus:
        return self._record.status

    @property
    def form_kind(self) -> HumanInputFormKind:
        return self._record.form_kind

    @property
    def created_at(self) -> "datetime":
        return self._record.created_at

    @property
    def expiration_time(self) -> "datetime":
        return self._record.expiration_time


class HumanInputError(Exception):
    pass


class FormSubmittedError(HumanInputError, BaseHTTPException):
    error_code = "human_input_form_submitted"
    description = "This form has already been submitted by another user, form_id={form_id}"
    code = 412

    def __init__(self, form_id: str):
        template = self.description or "This form has already been submitted by another user, form_id={form_id}"
        description = template.format(form_id=form_id)
        super().__init__(description=description)


class FormNotFoundError(HumanInputError, BaseHTTPException):
    error_code = "human_input_form_not_found"
    code = 404


class InvalidFormDataError(HumanInputError, BaseHTTPException):
    error_code = "invalid_form_data"
    code = 400

    def __init__(self, description: str):
        super().__init__(description=description)


class WebAppDeliveryNotEnabledError(HumanInputError, BaseException):
    pass


class FormExpiredError(HumanInputError, BaseHTTPException):
    error_code = "human_input_form_expired"
    code = 412

    def __init__(self, form_id: str):
        super().__init__(description=f"This form has expired, form_id={form_id}")


logger = logging.getLogger(__name__)


class HumanInputService:
    def __init__(
        self,
        session_factory: sessionmaker[Session] | Engine,
        form_repository: HumanInputFormSubmissionRepository | None = None,
    ):
        if isinstance(session_factory, Engine):
            session_factory = sessionmaker(bind=session_factory)
        self._session_factory = session_factory
        self._form_repository = form_repository or HumanInputFormSubmissionRepository(session_factory)

    def get_form_by_token(self, form_token: str) -> Form | None:
        record = self._form_repository.get_by_token(form_token)
        if record is None:
            return None
        return Form(record)

    def get_form_definition_by_token(self, recipient_type: RecipientType, form_token: str) -> Form | None:
        form = self.get_form_by_token(form_token)
        if form is None or form.recipient_type != recipient_type:
            return None
        self._ensure_not_submitted(form)
        return form

    def get_form_definition_by_token_for_console(self, form_token: str) -> Form | None:
        form = self.get_form_by_token(form_token)
        if form is None or form.recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
            return None
        self._ensure_not_submitted(form)
        return form

    def submit_form_by_token(
        self,
        recipient_type: RecipientType,
        form_token: str,
        selected_action_id: str,
        form_data: Mapping[str, Any],
        submission_end_user_id: str | None = None,
        submission_user_id: str | None = None,
    ):
        form = self.get_form_by_token(form_token)
        if form is None or form.recipient_type != recipient_type:
            raise WebAppDeliveryNotEnabledError()

        self.ensure_form_active(form)
        self._validate_submission(form=form, selected_action_id=selected_action_id, form_data=form_data)

        result = self._form_repository.mark_submitted(
            form_id=form.id,
            recipient_id=form.recipient_id,
            selected_action_id=selected_action_id,
            form_data=form_data,
            submission_user_id=submission_user_id,
            submission_end_user_id=submission_end_user_id,
        )

        if result.form_kind != HumanInputFormKind.RUNTIME:
            return
        if result.workflow_run_id is None:
            return
        self.enqueue_resume(result.workflow_run_id)

    def ensure_form_active(self, form: Form) -> None:
        if form.submitted:
            raise FormSubmittedError(form.id)
        if form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}:
            raise FormExpiredError(form.id)
        now = naive_utc_now()
        if ensure_naive_utc(form.expiration_time) <= now:
            raise FormExpiredError(form.id)
        if self._is_globally_expired(form, now=now):
            raise FormExpiredError(form.id)

    def _ensure_not_submitted(self, form: Form) -> None:
        if form.submitted:
            raise FormSubmittedError(form.id)

    def _validate_submission(self, form: Form, selected_action_id: str, form_data: Mapping[str, Any]) -> None:
        definition = form.get_definition()
        try:
            validate_human_input_submission(
                inputs=definition.inputs,
                user_actions=definition.user_actions,
                selected_action_id=selected_action_id,
                form_data=form_data,
            )
        except HumanInputSubmissionValidationError as exc:
            raise InvalidFormDataError(str(exc)) from exc

    def enqueue_resume(self, workflow_run_id: str) -> None:
        workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_factory)
        workflow_run = workflow_run_repo.get_workflow_run_by_id_without_tenant(workflow_run_id)

        if workflow_run is None:
            raise AssertionError(f"WorkflowRun not found, id={workflow_run_id}")
        with self._session_factory(expire_on_commit=False) as session:
            app_query = select(App).where(App.id == workflow_run.app_id)
            app = session.execute(app_query).scalar_one_or_none()
        if app is None:
            logger.error(
                "App not found for WorkflowRun, workflow_run_id=%s, app_id=%s", workflow_run_id, workflow_run.app_id
            )
            return

        if app.mode in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
            payload = {"workflow_run_id": workflow_run_id}
            try:
                resume_app_execution.apply_async(
                    kwargs={"payload": payload},
                    queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE,
                )
            except Exception:  # pragma: no cover
                logger.exception("Failed to enqueue resume task for workflow run %s", workflow_run_id)
            return

        logger.warning("App mode %s does not support resume for workflow run %s", app.mode, workflow_run_id)

    def _is_globally_expired(self, form: Form, *, now: datetime | None = None) -> bool:
        global_timeout_seconds = dify_config.HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS
        if global_timeout_seconds <= 0:
            return False
        if form.workflow_run_id is None:
            return False
        current = now or naive_utc_now()
        created_at = ensure_naive_utc(form.created_at)
        global_deadline = created_at + timedelta(seconds=global_timeout_seconds)
        return global_deadline <= current
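A worked example of the global-deadline arithmetic above (the 3600-second timeout is an assumed value for HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS):

    from datetime import datetime, timedelta

    created_at = datetime(2026, 1, 1, 12, 0, 0)  # form creation time, naive UTC
    global_deadline = created_at + timedelta(seconds=3600)
    assert global_deadline == datetime(2026, 1, 1, 13, 0, 0)
    # From 13:00:00 onward the form is treated as expired, even if its own
    # per-form expiration_time lies further in the future.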
@@ -1,6 +1,9 @@
import json
from collections.abc import Sequence
from typing import Union

from sqlalchemy.orm import sessionmaker

from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.llm_generator import LLMGenerator
@@ -14,6 +17,10 @@ from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
from repositories.sqlalchemy_execution_extra_content_repository import (
    SQLAlchemyExecutionExtraContentRepository,
)
from services.conversation_service import ConversationService
from services.errors.message import (
    FirstMessageNotExistsError,
@@ -24,6 +31,23 @@ from services.errors.message import (
from services.workflow_service import WorkflowService


def _create_execution_extra_content_repository() -> ExecutionExtraContentRepository:
    session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
    return SQLAlchemyExecutionExtraContentRepository(session_maker=session_maker)


def attach_message_extra_contents(messages: Sequence[Message]) -> None:
    if not messages:
        return

    repository = _create_execution_extra_content_repository()
    extra_contents_lists = repository.get_by_message_ids([message.id for message in messages])

    for index, message in enumerate(messages):
        contents = extra_contents_lists[index] if index < len(extra_contents_lists) else []
        message.set_extra_contents([content.model_dump(mode="json", exclude_none=True) for content in contents])
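A brief usage sketch for the helper above (assumes the Flask-SQLAlchemy session used elsewhere in this module):

    page = db.session.query(Message).limit(20).all()
    attach_message_extra_contents(page)
    # Each Message in `page` now carries its extra contents (set via
    # set_extra_contents), ready to be serialized into API responses.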


class MessageService:
    @classmethod
    def pagination_by_first_id(
@@ -85,6 +109,8 @@ class MessageService:
        if order == "asc":
            history_messages = list(reversed(history_messages))

        attach_message_extra_contents(history_messages)

        return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more)

    @classmethod
@@ -10,6 +10,7 @@ from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
@@ -63,6 +64,8 @@ class WorkflowToolManageService:
        if workflow is None:
            raise ValueError(f"Workflow not found for app {workflow_app_id}")

        WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict)

        workflow_tool_provider = WorkflowToolProvider(
            tenant_id=tenant_id,
            user_id=user_id,
@@ -152,6 +155,8 @@
        if workflow is None:
            raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")

        WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict)

        workflow_tool_provider.name = name
        workflow_tool_provider.label = label
        workflow_tool_provider.icon = json.dumps(icon)
@@ -98,6 +98,12 @@ class WorkflowTaskData(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)


class WorkflowResumeTaskData(BaseModel):
    """Payload for workflow resumption tasks."""

    workflow_run_id: str


class AsyncTriggerExecutionResult(BaseModel):
    """Result from async trigger-based workflow execution"""
api/services/workflow_event_snapshot_service.py (new file, 460 lines)
@@ -0,0 +1,460 @@
from __future__ import annotations

import json
import logging
import queue
import threading
import time
from collections.abc import Generator, Mapping, Sequence
from dataclasses import dataclass
from typing import Any

from sqlalchemy import desc, select
from sqlalchemy.orm import Session, sessionmaker

from core.app.apps.message_generator import MessageGenerator
from core.app.entities.task_entities import (
    MessageReplaceStreamResponse,
    NodeFinishStreamResponse,
    NodeStartStreamResponse,
    StreamEvent,
    WorkflowPauseStreamResponse,
    WorkflowStartStreamResponse,
)
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
from core.workflow.entities import WorkflowStartReason
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
from core.workflow.runtime import GraphRuntimeState
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models.model import AppMode, Message
from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.factory import DifyAPIRepositoryFactory

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class MessageContext:
    conversation_id: str
    message_id: str
    created_at: int
    answer: str | None = None


@dataclass
class BufferState:
    queue: queue.Queue[Mapping[str, Any]]
    stop_event: threading.Event
    done_event: threading.Event
    task_id_ready: threading.Event
    task_id_hint: str | None = None


def build_workflow_event_stream(
    *,
    app_mode: AppMode,
    workflow_run: WorkflowRun,
    tenant_id: str,
    app_id: str,
    session_maker: sessionmaker[Session],
    idle_timeout: float = 300,
    ping_interval: float = 10.0,
) -> Generator[Mapping[str, Any] | str, None, None]:
    topic = MessageGenerator.get_response_topic(app_mode, workflow_run.id)
    workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
    node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
    message_context = (
        _get_message_context(session_maker, workflow_run.id) if app_mode == AppMode.ADVANCED_CHAT else None
    )

    pause_entity: WorkflowPauseEntity | None = None
    if workflow_run.status == WorkflowExecutionStatus.PAUSED:
        try:
            pause_entity = workflow_run_repo.get_workflow_pause(workflow_run.id)
        except Exception:
            logger.exception("Failed to load workflow pause for run %s", workflow_run.id)
            pause_entity = None

    resumption_context = _load_resumption_context(pause_entity)
    node_snapshots = node_execution_repo.get_execution_snapshots_by_workflow_run(
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_id=workflow_run.workflow_id,
        # NOTE(QuantumGhost): for events resumption, we only care about
        # the execution records from `WORKFLOW_RUN`.
        #
        # Ideally filtering with `workflow_run_id` is enough. However,
        # due to the index of `WorkflowNodeExecution` table, we have to
        # add a filter condition of `triggered_from` to
        # ensure that we can utilize the index.
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        workflow_run_id=workflow_run.id,
    )

    def _generate() -> Generator[Mapping[str, Any] | str, None, None]:
        # Send a PING event immediately to prevent the connection staying in a pending state for a long time.
        #
        # This simplifies debugging, as the DevTools in Chrome do not
        # provide a complete curl command for pending connections.
        yield StreamEvent.PING.value

        last_msg_time = time.time()
        last_ping_time = last_msg_time

        with topic.subscribe() as sub:
            buffer_state = _start_buffering(sub)
            try:
                task_id = _resolve_task_id(resumption_context, buffer_state, workflow_run.id)

                snapshot_events = _build_snapshot_events(
                    workflow_run=workflow_run,
                    node_snapshots=node_snapshots,
                    task_id=task_id,
                    message_context=message_context,
                    pause_entity=pause_entity,
                    resumption_context=resumption_context,
                )

                for event in snapshot_events:
                    last_msg_time = time.time()
                    last_ping_time = last_msg_time
                    yield event
                    if _is_terminal_event(event, include_paused=True):
                        return

                while True:
                    if buffer_state.done_event.is_set() and buffer_state.queue.empty():
                        return

                    try:
                        event = buffer_state.queue.get(timeout=0.1)
                    except queue.Empty:
                        current_time = time.time()
                        if current_time - last_msg_time > idle_timeout:
                            logger.debug(
                                "No workflow events received for %s seconds, keeping stream open",
                                idle_timeout,
                            )
                            last_msg_time = current_time
                        if current_time - last_ping_time >= ping_interval:
                            yield StreamEvent.PING.value
                            last_ping_time = current_time
                        continue

                    last_msg_time = time.time()
                    last_ping_time = last_msg_time
                    yield event
                    if _is_terminal_event(event, include_paused=True):
                        return
            finally:
                buffer_state.stop_event.set()

    return _generate()
def _get_message_context(session_maker: sessionmaker[Session], workflow_run_id: str) -> MessageContext | None:
    with session_maker() as session:
        stmt = select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(desc(Message.created_at))
        message = session.scalar(stmt)
        if message is None:
            return None
        created_at = int(message.created_at.timestamp()) if message.created_at else 0
        return MessageContext(
            conversation_id=message.conversation_id,
            message_id=message.id,
            created_at=created_at,
            answer=message.answer,
        )


def _load_resumption_context(pause_entity: WorkflowPauseEntity | None) -> WorkflowResumptionContext | None:
    if pause_entity is None:
        return None
    try:
        raw_state = pause_entity.get_state().decode()
        return WorkflowResumptionContext.loads(raw_state)
    except Exception:
        logger.exception("Failed to load resumption context")
        return None


def _resolve_task_id(
    resumption_context: WorkflowResumptionContext | None,
    buffer_state: BufferState | None,
    workflow_run_id: str,
    wait_timeout: float = 0.2,
) -> str:
    if resumption_context is not None:
        generate_entity = resumption_context.get_generate_entity()
        if generate_entity.task_id:
            return generate_entity.task_id
    if buffer_state is None:
        return workflow_run_id
    if buffer_state.task_id_hint is None:
        buffer_state.task_id_ready.wait(timeout=wait_timeout)
    if buffer_state.task_id_hint:
        return buffer_state.task_id_hint
    return workflow_run_id


def _build_snapshot_events(
    *,
    workflow_run: WorkflowRun,
    node_snapshots: Sequence[WorkflowNodeExecutionSnapshot],
    task_id: str,
    message_context: MessageContext | None,
    pause_entity: WorkflowPauseEntity | None,
    resumption_context: WorkflowResumptionContext | None,
) -> list[Mapping[str, Any]]:
    events: list[Mapping[str, Any]] = []

    workflow_started = _build_workflow_started_event(
        workflow_run=workflow_run,
        task_id=task_id,
    )
    _apply_message_context(workflow_started, message_context)
    events.append(workflow_started)

    if message_context is not None and message_context.answer is not None:
        message_replace = _build_message_replace_event(task_id=task_id, answer=message_context.answer)
        _apply_message_context(message_replace, message_context)
        events.append(message_replace)

    for snapshot in node_snapshots:
        node_started = _build_node_started_event(
            workflow_run_id=workflow_run.id,
            task_id=task_id,
            snapshot=snapshot,
        )
        _apply_message_context(node_started, message_context)
        events.append(node_started)

        if snapshot.status != WorkflowNodeExecutionStatus.RUNNING.value:
            node_finished = _build_node_finished_event(
                workflow_run_id=workflow_run.id,
                task_id=task_id,
                snapshot=snapshot,
            )
            _apply_message_context(node_finished, message_context)
            events.append(node_finished)

    if workflow_run.status == WorkflowExecutionStatus.PAUSED and pause_entity is not None:
        pause_event = _build_pause_event(
            workflow_run=workflow_run,
            workflow_run_id=workflow_run.id,
            task_id=task_id,
            pause_entity=pause_entity,
            resumption_context=resumption_context,
        )
        if pause_event is not None:
            _apply_message_context(pause_event, message_context)
            events.append(pause_event)

    return events


def _build_workflow_started_event(
    *,
    workflow_run: WorkflowRun,
    task_id: str,
) -> dict[str, Any]:
    response = WorkflowStartStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run.id,
        data=WorkflowStartStreamResponse.Data(
            id=workflow_run.id,
            workflow_id=workflow_run.workflow_id,
            inputs=workflow_run.inputs_dict or {},
            created_at=int(workflow_run.created_at.timestamp()),
            reason=WorkflowStartReason.INITIAL,
        ),
    )
    payload = response.model_dump(mode="json")
    payload["event"] = response.event.value
    return payload


def _build_message_replace_event(*, task_id: str, answer: str) -> dict[str, Any]:
    response = MessageReplaceStreamResponse(
        task_id=task_id,
        answer=answer,
        reason="",
    )
    payload = response.model_dump(mode="json")
    payload["event"] = response.event.value
    return payload


def _build_node_started_event(
    *,
    workflow_run_id: str,
    task_id: str,
    snapshot: WorkflowNodeExecutionSnapshot,
) -> dict[str, Any]:
    created_at = int(snapshot.created_at.timestamp()) if snapshot.created_at else 0
    response = NodeStartStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run_id,
        data=NodeStartStreamResponse.Data(
            id=snapshot.execution_id,
            node_id=snapshot.node_id,
            node_type=snapshot.node_type,
            title=snapshot.title,
            index=snapshot.index,
            predecessor_node_id=None,
            inputs=None,
            created_at=created_at,
            extras={},
            iteration_id=snapshot.iteration_id,
            loop_id=snapshot.loop_id,
        ),
    )
    return response.to_ignore_detail_dict()


def _build_node_finished_event(
    *,
    workflow_run_id: str,
    task_id: str,
    snapshot: WorkflowNodeExecutionSnapshot,
) -> dict[str, Any]:
    created_at = int(snapshot.created_at.timestamp()) if snapshot.created_at else 0
    finished_at = int(snapshot.finished_at.timestamp()) if snapshot.finished_at else created_at
    response = NodeFinishStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run_id,
        data=NodeFinishStreamResponse.Data(
            id=snapshot.execution_id,
            node_id=snapshot.node_id,
            node_type=snapshot.node_type,
            title=snapshot.title,
            index=snapshot.index,
            predecessor_node_id=None,
            inputs=None,
            process_data=None,
            outputs=None,
            status=WorkflowNodeExecutionStatus(snapshot.status),
            error=None,
            elapsed_time=snapshot.elapsed_time,
            execution_metadata=None,
            created_at=created_at,
            finished_at=finished_at,
            files=[],
            iteration_id=snapshot.iteration_id,
            loop_id=snapshot.loop_id,
        ),
    )
    return response.to_ignore_detail_dict()


def _build_pause_event(
    *,
    workflow_run: WorkflowRun,
    workflow_run_id: str,
    task_id: str,
    pause_entity: WorkflowPauseEntity,
    resumption_context: WorkflowResumptionContext | None,
) -> dict[str, Any] | None:
    paused_nodes: list[str] = []
    outputs: dict[str, Any] = {}
    if resumption_context is not None:
        state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
        paused_nodes = state.get_paused_nodes()
        outputs = dict(WorkflowRuntimeTypeConverter().to_json_encodable(state.outputs or {}))

    reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()]
    response = WorkflowPauseStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run_id,
        data=WorkflowPauseStreamResponse.Data(
            workflow_run_id=workflow_run_id,
            paused_nodes=paused_nodes,
            outputs=outputs,
            reasons=reasons,
            status=workflow_run.status,
            created_at=int(workflow_run.created_at.timestamp()),
            elapsed_time=float(workflow_run.elapsed_time or 0.0),
            total_tokens=int(workflow_run.total_tokens or 0),
            total_steps=int(workflow_run.total_steps or 0),
        ),
    )
    payload = response.model_dump(mode="json")
    payload["event"] = response.event.value
    return payload


def _apply_message_context(payload: dict[str, Any], message_context: MessageContext | None) -> None:
    if message_context is None:
        return
    payload["conversation_id"] = message_context.conversation_id
    payload["message_id"] = message_context.message_id
    payload["created_at"] = message_context.created_at


def _start_buffering(subscription) -> BufferState:
    buffer_state = BufferState(
        queue=queue.Queue(maxsize=2048),
        stop_event=threading.Event(),
        done_event=threading.Event(),
        task_id_ready=threading.Event(),
    )

    def _worker() -> None:
        dropped_count = 0
        try:
            while not buffer_state.stop_event.is_set():
                msg = subscription.receive(timeout=0.1)
                if msg is None:
                    continue
                event = _parse_event_message(msg)
                if event is None:
                    continue
                task_id = event.get("task_id")
                if task_id and buffer_state.task_id_hint is None:
                    buffer_state.task_id_hint = str(task_id)
                    buffer_state.task_id_ready.set()
                try:
                    buffer_state.queue.put_nowait(event)
                except queue.Full:
                    dropped_count += 1
                    try:
                        buffer_state.queue.get_nowait()
                    except queue.Empty:
                        pass
                    try:
                        buffer_state.queue.put_nowait(event)
                    except queue.Full:
                        continue
                    logger.warning("Dropped buffered workflow event, total_dropped=%s", dropped_count)
        except Exception:
            logger.exception("Failed while buffering workflow events")
        finally:
            buffer_state.done_event.set()

    thread = threading.Thread(target=_worker, name=f"workflow-event-buffer-{id(subscription)}", daemon=True)
    thread.start()
    return buffer_state
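The worker above implements a bounded drop-oldest buffer: when the queue is full, the oldest event is evicted to make room for the newest. A minimal standalone sketch of the same policy:

    import queue

    def put_drop_oldest(q: "queue.Queue[int]", item: int) -> None:
        # Same policy as the buffering worker: prefer fresh events, evict the oldest.
        try:
            q.put_nowait(item)
        except queue.Full:
            try:
                q.get_nowait()  # evict oldest
            except queue.Empty:
                pass
            q.put_nowait(item)

    q: "queue.Queue[int]" = queue.Queue(maxsize=2)
    for i in range(4):
        put_drop_oldest(q, i)
    assert [q.get_nowait(), q.get_nowait()] == [2, 3]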


def _parse_event_message(message: bytes) -> Mapping[str, Any] | None:
    try:
        event = json.loads(message)
    except json.JSONDecodeError:
        logger.warning("Failed to decode workflow event payload")
        return None
    if not isinstance(event, dict):
        return None
    return event


def _is_terminal_event(event: Mapping[str, Any] | str, include_paused=False) -> bool:
    if not isinstance(event, Mapping):
        return False
    event_type = event.get("event")
    if event_type == StreamEvent.WORKFLOW_FINISHED.value:
        return True
    if include_paused:
        return event_type == StreamEvent.WORKFLOW_PAUSED.value
    return False
@@ -1,4 +1,5 @@
import json
import logging
import time
import uuid
from collections.abc import Callable, Generator, Mapping, Sequence
@@ -11,21 +12,34 @@ from configs import dify_config
from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.variables import VariableBase
from core.variables.variables import Variable
-from core.workflow.entities import WorkflowNodeExecution
+from core.workflow.entities import GraphInitParams, WorkflowNodeExecution
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.human_input.entities import (
    DeliveryChannelConfig,
    HumanInputNodeData,
    apply_debug_email_recipient,
    validate_human_input_submission,
)
from core.workflow.nodes.human_input.enums import HumanInputFormKind
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.start.entities import StartNodeData
-from core.workflow.runtime import VariablePool
+from core.workflow.repositories.human_input_form_repository import FormCreateParams
+from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry
from enums.cloud_plan import CloudPlan
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
@@ -34,6 +48,8 @@ from extensions.ext_storage import storage
from factories.file_factory import build_from_mapping, build_from_mappings
from libs.datetime_utils import naive_utc_now
from models import Account
from models.enums import UserFrom
from models.human_input import HumanInputFormRecipient, RecipientType
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType
@@ -44,6 +60,13 @@ from services.errors.app import IsDraftWorkflowError, TriggerNodeLimitExceededEr
from services.workflow.workflow_converter import WorkflowConverter

from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
from .human_input_delivery_test_service import (
    DeliveryTestContext,
    DeliveryTestEmailRecipient,
    DeliveryTestError,
    DeliveryTestUnsupportedError,
    HumanInputDeliveryTestService,
)
from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService
@@ -744,6 +767,344 @@ class WorkflowService:

        return workflow_node_execution

    def get_human_input_form_preview(
        self,
        *,
        app_model: App,
        account: Account,
        node_id: str,
        inputs: Mapping[str, Any] | None = None,
    ) -> Mapping[str, Any]:
        """
        Build a human input form preview for a draft workflow.

        Args:
            app_model: Target application model.
            account: Current account.
            node_id: Human input node ID.
            inputs: Values used to fill missing upstream variables referenced in form_content.
        """
        draft_workflow = self.get_draft_workflow(app_model=app_model)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")

        node_config = draft_workflow.get_node_config_by_id(node_id)
        node_type = Workflow.get_node_type_from_node_config(node_config)
        if node_type is not NodeType.HUMAN_INPUT:
            raise ValueError("Node type must be human-input.")

        # inputs: values used to fill missing upstream variables referenced in form_content.
        variable_pool = self._build_human_input_variable_pool(
            app_model=app_model,
            workflow=draft_workflow,
            node_config=node_config,
            manual_inputs=inputs or {},
        )
        node = self._build_human_input_node(
            workflow=draft_workflow,
            account=account,
            node_config=node_config,
            variable_pool=variable_pool,
        )

        rendered_content = node.render_form_content_before_submission()
        resolved_default_values = node.resolve_default_values()
        node_data = node.node_data
        human_input_required = HumanInputRequired(
            form_id=node_id,
            form_content=rendered_content,
            inputs=node_data.inputs,
            actions=node_data.user_actions,
            node_id=node_id,
            node_title=node.title,
            resolved_default_values=resolved_default_values,
            form_token=None,
        )
        return human_input_required.model_dump(mode="json")
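A hedged sketch of how a console endpoint might invoke the preview (constructing WorkflowService directly and the fixture names are assumptions):

    service = WorkflowService()                      # assumed no-arg construction
    preview = service.get_human_input_form_preview(
        app_model=app_model,                         # assumed: App of the current console app
        account=current_account,                     # assumed: authenticated Account
        node_id="human_input_1",                     # assumed node ID in the draft graph
        inputs={"upstream_var": "sample value"},     # fills unresolved upstream references
    )
    # `preview` is the HumanInputRequired payload: form_content, inputs, actions,
    # resolved_default_values, and form_token=None for draft previews.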
|
||||
    def submit_human_input_form_preview(
        self,
        *,
        app_model: App,
        account: Account,
        node_id: str,
        form_inputs: Mapping[str, Any],
        inputs: Mapping[str, Any] | None = None,
        action: str,
    ) -> Mapping[str, Any]:
        """
        Submit a human input form preview for a draft workflow.

        Args:
            app_model: Target application model.
            account: Current account.
            node_id: Human input node ID.
            form_inputs: Values the user provides for the form's own fields.
            inputs: Values used to fill missing upstream variables referenced in form_content.
            action: Selected action ID.
        """
        draft_workflow = self.get_draft_workflow(app_model=app_model)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")

        node_config = draft_workflow.get_node_config_by_id(node_id)
        node_type = Workflow.get_node_type_from_node_config(node_config)
        if node_type is not NodeType.HUMAN_INPUT:
            raise ValueError("Node type must be human-input.")

        # inputs: values used to fill missing upstream variables referenced in form_content.
        # form_inputs: values the user provides for the form's own fields.
        variable_pool = self._build_human_input_variable_pool(
            app_model=app_model,
            workflow=draft_workflow,
            node_config=node_config,
            manual_inputs=inputs or {},
        )
        node = self._build_human_input_node(
            workflow=draft_workflow,
            account=account,
            node_config=node_config,
            variable_pool=variable_pool,
        )
        node_data = node.node_data

        validate_human_input_submission(
            inputs=node_data.inputs,
            user_actions=node_data.user_actions,
            selected_action_id=action,
            form_data=form_inputs,
        )

        rendered_content = node.render_form_content_before_submission()
        outputs: dict[str, Any] = dict(form_inputs)
        outputs["__action_id"] = action
        outputs["__rendered_content"] = node.render_form_content_with_outputs(
            rendered_content, outputs, node_data.outputs_field_names()
        )

        enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
        enclosing_node_id = enclosing_node_type_and_id[1] if enclosing_node_type_and_id else None
        with Session(bind=db.engine) as session, session.begin():
            draft_var_saver = DraftVariableSaver(
                session=session,
                app_id=app_model.id,
                node_id=node_id,
                node_type=NodeType.HUMAN_INPUT,
                node_execution_id=str(uuid.uuid4()),
                user=account,
                enclosing_node_id=enclosing_node_id,
            )
            draft_var_saver.save(outputs=outputs, process_data={})
            # session.begin() commits on successful exit; an explicit commit here is redundant.

        return outputs

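    # Usage sketch (hypothetical caller; `service`, `app`, and `user` stand in for a
    # WorkflowService instance, the App, and the Account from the request context):
    #
    #     outputs = service.submit_human_input_form_preview(
    #         app_model=app,
    #         account=user,
    #         node_id="human-input-1",
    #         form_inputs={"comment": "LGTM"},
    #         inputs={"upstream_var": "stub value"},
    #         action="approve",
    #     )
    #     # `outputs` echoes form_inputs plus the reserved "__action_id" and
    #     # "__rendered_content" keys saved as draft variables above.
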
    def test_human_input_delivery(
        self,
        *,
        app_model: App,
        account: Account,
        node_id: str,
        delivery_method_id: str,
        inputs: Mapping[str, Any] | None = None,
    ) -> None:
        draft_workflow = self.get_draft_workflow(app_model=app_model)
        if not draft_workflow:
            raise ValueError("Workflow not initialized")

        node_config = draft_workflow.get_node_config_by_id(node_id)
        node_type = Workflow.get_node_type_from_node_config(node_config)
        if node_type is not NodeType.HUMAN_INPUT:
            raise ValueError("Node type must be human-input.")

        node_data = HumanInputNodeData.model_validate(node_config.get("data", {}))
        delivery_method = self._resolve_human_input_delivery_method(
            node_data=node_data,
            delivery_method_id=delivery_method_id,
        )
        if delivery_method is None:
            raise ValueError("Delivery method not found.")
        delivery_method = apply_debug_email_recipient(
            delivery_method,
            enabled=True,
            user_id=account.id or "",
        )

        variable_pool = self._build_human_input_variable_pool(
            app_model=app_model,
            workflow=draft_workflow,
            node_config=node_config,
            manual_inputs=inputs or {},
        )
        node = self._build_human_input_node(
            workflow=draft_workflow,
            account=account,
            node_config=node_config,
            variable_pool=variable_pool,
        )
        rendered_content = node.render_form_content_before_submission()
        resolved_default_values = node.resolve_default_values()
        form_id, recipients = self._create_human_input_delivery_test_form(
            app_model=app_model,
            node_id=node_id,
            node_data=node_data,
            delivery_method=delivery_method,
            rendered_content=rendered_content,
            resolved_default_values=resolved_default_values,
        )
        test_service = HumanInputDeliveryTestService()
        context = DeliveryTestContext(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            node_id=node_id,
            node_title=node_data.title,
            rendered_content=rendered_content,
            template_vars={"form_id": form_id},
            recipients=recipients,
            variable_pool=variable_pool,
        )
        try:
            test_service.send_test(context=context, method=delivery_method)
        except DeliveryTestUnsupportedError as exc:
            raise ValueError("Delivery method does not support test send.") from exc
        except DeliveryTestError as exc:
            raise ValueError(str(exc)) from exc

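    # Error-handling sketch: delivery failures are normalized to ValueError above, so a
    # controller can surface them directly (caller and method names here are illustrative):
    #
    #     try:
    #         service.test_human_input_delivery(
    #             app_model=app,
    #             account=user,
    #             node_id="human-input-1",
    #             delivery_method_id="email-method-1",
    #         )
    #     except ValueError as e:
    #         ...  # e.g. map to an HTTP 400 response with str(e)
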
    @staticmethod
    def _resolve_human_input_delivery_method(
        *,
        node_data: HumanInputNodeData,
        delivery_method_id: str,
    ) -> DeliveryChannelConfig | None:
        for method in node_data.delivery_methods:
            if str(method.id) == delivery_method_id:
                return method
        return None

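    # The lookup above is a plain linear scan; an equivalent one-liner, for reference:
    #
    #     next((m for m in node_data.delivery_methods if str(m.id) == delivery_method_id), None)
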
    def _create_human_input_delivery_test_form(
        self,
        *,
        app_model: App,
        node_id: str,
        node_data: HumanInputNodeData,
        delivery_method: DeliveryChannelConfig,
        rendered_content: str,
        resolved_default_values: Mapping[str, Any],
    ) -> tuple[str, list[DeliveryTestEmailRecipient]]:
        repo = HumanInputFormRepositoryImpl(session_factory=db.engine, tenant_id=app_model.tenant_id)
        params = FormCreateParams(
            app_id=app_model.id,
            workflow_execution_id=None,
            node_id=node_id,
            form_config=node_data,
            rendered_content=rendered_content,
            delivery_methods=[delivery_method],
            display_in_ui=False,
            resolved_default_values=resolved_default_values,
            form_kind=HumanInputFormKind.DELIVERY_TEST,
        )
        form_entity = repo.create_form(params)
        return form_entity.id, self._load_email_recipients(form_entity.id)

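    # Note: the test form is persisted like a real form but marked
    # form_kind=HumanInputFormKind.DELIVERY_TEST with display_in_ui=False, so recipient
    # access tokens exist for the test email while the form never surfaces in the app UI.
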
    @staticmethod
    def _load_email_recipients(form_id: str) -> list[DeliveryTestEmailRecipient]:
        logger = logging.getLogger(__name__)

        with Session(bind=db.engine) as session:
            recipients = session.scalars(
                select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_id)
            ).all()

        recipients_data: list[DeliveryTestEmailRecipient] = []
        for recipient in recipients:
            if recipient.recipient_type not in {RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL}:
                continue
            if not recipient.access_token:
                continue
            try:
                payload = json.loads(recipient.recipient_payload)
            except Exception:
                logger.exception("Failed to parse human input recipient payload for delivery test.")
                continue
            email = payload.get("email")
            if email:
                recipients_data.append(DeliveryTestEmailRecipient(email=email, form_token=recipient.access_token))
        return recipients_data

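    # Sketch of the recipient_payload JSON this loader expects (inferred from the
    # `payload.get("email")` access above; any additional keys are ignored):
    #
    #     {"email": "reviewer@example.com"}
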
    def _build_human_input_node(
        self,
        *,
        workflow: Workflow,
        account: Account,
        node_config: Mapping[str, Any],
        variable_pool: VariablePool,
    ) -> HumanInputNode:
        graph_init_params = GraphInitParams(
            tenant_id=workflow.tenant_id,
            app_id=workflow.app_id,
            workflow_id=workflow.id,
            graph_config=workflow.graph_dict,
            user_id=account.id,
            user_from=UserFrom.ACCOUNT.value,
            invoke_from=InvokeFrom.DEBUGGER.value,
            call_depth=0,
        )
        graph_runtime_state = GraphRuntimeState(
            variable_pool=variable_pool,
            start_at=time.perf_counter(),
        )
        node = HumanInputNode(
            id=node_config.get("id", str(uuid.uuid4())),
            config=node_config,
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
        )
        return node

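    # The node built here is ephemeral: it gets fresh GraphInitParams and a fresh
    # GraphRuntimeState, and is attributed to the DEBUGGER invoke source, so form
    # rendering and default-value resolution run against the supplied variable pool
    # without executing the rest of the graph.
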
    def _build_human_input_variable_pool(
        self,
        *,
        app_model: App,
        workflow: Workflow,
        node_config: Mapping[str, Any],
        manual_inputs: Mapping[str, Any],
    ) -> VariablePool:
        with Session(bind=db.engine, expire_on_commit=False) as session, session.begin():
            draft_var_srv = WorkflowDraftVariableService(session)
            draft_var_srv.prefill_conversation_variable_default_values(workflow)

        variable_pool = VariablePool(
            system_variables=SystemVariable.default(),
            user_inputs={},
            environment_variables=workflow.environment_variables,
            conversation_variables=[],
        )

        variable_loader = DraftVarLoader(
            engine=db.engine,
            app_id=app_model.id,
            tenant_id=app_model.tenant_id,
        )
        variable_mapping = HumanInputNode.extract_variable_selector_to_variable_mapping(
            graph_config=workflow.graph_dict,
            config=node_config,
        )
        normalized_user_inputs: dict[str, Any] = dict(manual_inputs)

        load_into_variable_pool(
            variable_loader=variable_loader,
            variable_pool=variable_pool,
            variable_mapping=variable_mapping,
            user_inputs=normalized_user_inputs,
        )
        WorkflowEntry.mapping_user_inputs_to_variable_pool(
            variable_mapping=variable_mapping,
            user_inputs=normalized_user_inputs,
            variable_pool=variable_pool,
            tenant_id=app_model.tenant_id,
        )

        return variable_pool

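    # Resolution order for variables the form references: draft values recorded by
    # previous debug runs are loaded first (load_into_variable_pool), then the caller's
    # `manual_inputs` are mapped on top (mapping_user_inputs_to_variable_pool), letting
    # users fill in upstream variables that have no recorded draft value.
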
    def run_free_workflow_node(
        self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
    ) -> WorkflowNodeExecution:
@ -945,6 +1306,13 @@ class WorkflowService:
        if any(nt.is_trigger_node for nt in node_types):
            raise ValueError("Start node and trigger nodes cannot coexist in the same workflow")

        for node in node_configs:
            node_data = node.get("data", {})
            node_type = node_data.get("type")

            if node_type == NodeType.HUMAN_INPUT:
                self._validate_human_input_node_data(node_data)

    def validate_features_structure(self, app_model: App, features: dict):
        if app_model.mode == AppMode.ADVANCED_CHAT:
            return AdvancedChatAppConfigManager.config_validate(
@ -957,6 +1325,23 @@ class WorkflowService:
        else:
            raise ValueError(f"Invalid app mode: {app_model.mode}")

    def _validate_human_input_node_data(self, node_data: dict) -> None:
        """
        Validate HumanInput node data format.

        Args:
            node_data: The node data dictionary

        Raises:
            ValueError: If the node data format is invalid
        """
        from core.workflow.nodes.human_input.entities import HumanInputNodeData

        try:
            HumanInputNodeData.model_validate(node_data)
        except Exception as e:
            raise ValueError(f"Invalid HumanInput node data: {e}") from e

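    # Failure-mode sketch (assuming HumanInputNodeData is a pydantic model, as the
    # model_validate call suggests): a malformed node config surfaces as ValueError with
    # the validation message attached, e.g.
    #
    #     self._validate_human_input_node_data({"type": "human-input", "inputs": "oops"})
    #     # may raise: ValueError: Invalid HumanInput node data: ...
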
    def update_workflow(
        self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
    ) -> Workflow | None: