mirror of
https://github.com/langgenius/dify.git
synced 2026-03-10 18:06:14 +08:00
WIP: resume
This commit is contained in:
@ -8,18 +8,21 @@ from typing import Annotated, Any, TypeAlias, Union
|
||||
from celery import shared_task
|
||||
from flask import current_app, json
|
||||
from pydantic import BaseModel, Discriminator, Field, Tag
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from extensions.ext_database import db
|
||||
from libs.flask_utils import set_login_user
|
||||
from models.account import Account
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.workflow import Workflow
|
||||
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
|
||||
from models.model import App, AppMode, Conversation, EndUser, Message
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom, WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -129,6 +132,11 @@ class _ChatflowRunner:
|
||||
workflow = session.get(Workflow, exec_params.workflow_id)
|
||||
app = session.get(App, workflow.app_id)
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=self._session_factory,
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
|
||||
user = self._resolve_user()
|
||||
|
||||
chat_generator = AdvancedChatAppGenerator()
|
||||
@ -144,6 +152,7 @@ class _ChatflowRunner:
|
||||
invoke_from=exec_params.invoke_from,
|
||||
streaming=exec_params.streaming,
|
||||
workflow_run_id=workflow_run_id,
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
if not exec_params.streaming:
|
||||
return response
|
||||
@ -174,11 +183,135 @@ class _ChatflowRunner:
|
||||
return user
|
||||
|
||||
|
||||
def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Account | EndUser | None:
|
||||
role = CreatorUserRole(workflow_run.created_by_role)
|
||||
if role == CreatorUserRole.ACCOUNT:
|
||||
user = session.get(Account, workflow_run.created_by)
|
||||
if user:
|
||||
user.set_tenant_id(workflow_run.tenant_id)
|
||||
return user
|
||||
|
||||
return session.get(EndUser, workflow_run.created_by)
|
||||
|
||||
|
||||
@shared_task(queue="chatflow_execute")
|
||||
def chatflow_execute_task(payload: str) -> Mapping[str, Any] | None:
|
||||
exec_params = ChatflowExecutionParams.model_validate_json(payload)
|
||||
|
||||
print("chatflow_execute_task run with params", exec_params)
|
||||
logger.info("chatflow_execute_task run with params: %s", exec_params)
|
||||
|
||||
runner = _ChatflowRunner(db.engine, exec_params=exec_params)
|
||||
return runner.run()
|
||||
|
||||
|
||||
@shared_task(queue="chatflow_execute", name="resume_chatflow_execution")
|
||||
def resume_chatflow_execution(payload: dict[str, Any]) -> None:
|
||||
workflow_run_id = payload["workflow_run_id"]
|
||||
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_factory)
|
||||
|
||||
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
|
||||
if pause_entity is None:
|
||||
logger.warning("No pause entity found for workflow run %s", workflow_run_id)
|
||||
return
|
||||
|
||||
try:
|
||||
resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
|
||||
except Exception:
|
||||
logger.exception("Failed to load resumption context for workflow run %s", workflow_run_id)
|
||||
return
|
||||
|
||||
generate_entity = resumption_context.get_generate_entity()
|
||||
if not isinstance(generate_entity, AdvancedChatAppGenerateEntity):
|
||||
logger.error(
|
||||
"Resumption entity is not AdvancedChatAppGenerateEntity for workflow run %s (found %s)",
|
||||
workflow_run_id,
|
||||
type(generate_entity),
|
||||
)
|
||||
return
|
||||
|
||||
graph_runtime_state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = session.get(WorkflowRun, workflow_run_id)
|
||||
if workflow_run is None:
|
||||
logger.warning("Workflow run %s not found during resume", workflow_run_id)
|
||||
return
|
||||
|
||||
workflow = session.get(Workflow, workflow_run.workflow_id)
|
||||
if workflow is None:
|
||||
logger.warning("Workflow %s not found during resume", workflow_run.workflow_id)
|
||||
return
|
||||
|
||||
app_model = session.get(App, workflow_run.app_id)
|
||||
if app_model is None:
|
||||
logger.warning("App %s not found during resume", workflow_run.app_id)
|
||||
return
|
||||
|
||||
if generate_entity.conversation_id is None:
|
||||
logger.warning("Conversation id missing in resumption context for workflow run %s", workflow_run_id)
|
||||
return
|
||||
|
||||
conversation = session.get(Conversation, generate_entity.conversation_id)
|
||||
if conversation is None:
|
||||
logger.warning(
|
||||
"Conversation %s not found for workflow run %s", generate_entity.conversation_id, workflow_run_id
|
||||
)
|
||||
return
|
||||
|
||||
message = session.scalar(
|
||||
select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(Message.created_at.desc())
|
||||
)
|
||||
if message is None:
|
||||
logger.warning("Message not found for workflow run %s", workflow_run_id)
|
||||
return
|
||||
|
||||
user = _resolve_user_for_run(session, workflow_run)
|
||||
if user is None:
|
||||
logger.warning("User %s not found for workflow run %s", workflow_run.created_by, workflow_run_id)
|
||||
return
|
||||
|
||||
workflow_run_repo.resume_workflow_pause(workflow_run_id, pause_entity)
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=session_factory,
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
|
||||
try:
|
||||
triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from)
|
||||
except ValueError:
|
||||
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
|
||||
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=app_model.id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
generator = AdvancedChatAppGenerator()
|
||||
|
||||
try:
|
||||
generator.resume(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
application_generate_entity=generate_entity,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id)
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user