mirror of
https://github.com/langgenius/dify.git
synced 2026-04-25 05:06:15 +08:00
Merge branch 'fix/auto-activate-credential-on-create' into deploy/dev
This commit is contained in:
@ -23,7 +23,7 @@ from dify_graph.variables.types import SegmentType
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models import App, AppMode
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
@ -100,6 +100,18 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
|
||||
}
|
||||
|
||||
|
||||
def _ensure_variable_access(
|
||||
variable: WorkflowDraftVariable | None,
|
||||
app_id: str,
|
||||
variable_id: str,
|
||||
) -> WorkflowDraftVariable:
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_id or variable.user_id != current_user.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
return variable
|
||||
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
||||
"id": fields.String,
|
||||
"type": fields.String(attribute=lambda model: model.get_variable_type()),
|
||||
@ -238,6 +250,7 @@ class WorkflowVariableCollectionApi(Resource):
|
||||
app_id=app_model.id,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return workflow_vars
|
||||
@ -250,7 +263,7 @@ class WorkflowVariableCollectionApi(Resource):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
draft_var_srv.delete_workflow_variables(app_model.id)
|
||||
draft_var_srv.delete_user_workflow_variables(app_model.id, user_id=current_user.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
@ -287,7 +300,7 @@ class NodeVariableCollectionApi(Resource):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
node_vars = draft_var_srv.list_node_variables(app_model.id, node_id)
|
||||
node_vars = draft_var_srv.list_node_variables(app_model.id, node_id, user_id=current_user.id)
|
||||
|
||||
return node_vars
|
||||
|
||||
@ -298,7 +311,7 @@ class NodeVariableCollectionApi(Resource):
|
||||
def delete(self, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(app_model.id, node_id)
|
||||
srv.delete_node_variables(app_model.id, node_id, user_id=current_user.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
@ -319,11 +332,11 @@ class VariableApi(Resource):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
return variable
|
||||
|
||||
@console_ns.doc("update_variable")
|
||||
@ -360,11 +373,11 @@ class VariableApi(Resource):
|
||||
)
|
||||
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
|
||||
new_name = args_model.name
|
||||
raw_value = args_model.value
|
||||
@ -397,11 +410,11 @@ class VariableApi(Resource):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
@ -427,11 +440,11 @@ class VariableResetApi(Resource):
|
||||
raise NotFoundError(
|
||||
f"Draft workflow not found, app_id={app_model.id}",
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
db.session.commit()
|
||||
@ -447,11 +460,15 @@ def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
|
||||
session=session,
|
||||
)
|
||||
if node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_conversation_variables(app_model.id)
|
||||
draft_vars = draft_var_srv.list_conversation_variables(app_model.id, user_id=current_user.id)
|
||||
elif node_id == SYSTEM_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_system_variables(app_model.id)
|
||||
draft_vars = draft_var_srv.list_system_variables(app_model.id, user_id=current_user.id)
|
||||
else:
|
||||
draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id)
|
||||
draft_vars = draft_var_srv.list_node_variables(
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
return draft_vars
|
||||
|
||||
|
||||
@ -472,7 +489,7 @@ class ConversationVariableCollectionApi(Resource):
|
||||
if draft_workflow is None:
|
||||
raise NotFoundError(description=f"draft workflow not found, id={app_model.id}")
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow, user_id=current_user.id)
|
||||
db.session.commit()
|
||||
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
@ -102,6 +102,7 @@ class RagPipelineVariableCollectionApi(Resource):
|
||||
app_id=pipeline.id,
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return workflow_vars
|
||||
@ -111,7 +112,7 @@ class RagPipelineVariableCollectionApi(Resource):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
draft_var_srv.delete_workflow_variables(pipeline.id)
|
||||
draft_var_srv.delete_user_workflow_variables(pipeline.id, user_id=current_user.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
@ -144,7 +145,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
node_vars = draft_var_srv.list_node_variables(pipeline.id, node_id)
|
||||
node_vars = draft_var_srv.list_node_variables(pipeline.id, node_id, user_id=current_user.id)
|
||||
|
||||
return node_vars
|
||||
|
||||
@ -152,7 +153,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
|
||||
def delete(self, pipeline: Pipeline, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(pipeline.id, node_id)
|
||||
srv.delete_node_variables(pipeline.id, node_id, user_id=current_user.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
@ -283,11 +284,11 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList
|
||||
session=session,
|
||||
)
|
||||
if node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_conversation_variables(pipeline.id)
|
||||
draft_vars = draft_var_srv.list_conversation_variables(pipeline.id, user_id=current_user.id)
|
||||
elif node_id == SYSTEM_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_system_variables(pipeline.id)
|
||||
draft_vars = draft_var_srv.list_system_variables(pipeline.id, user_id=current_user.id)
|
||||
else:
|
||||
draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id)
|
||||
draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id, user_id=current_user.id)
|
||||
return draft_vars
|
||||
|
||||
|
||||
|
||||
@ -330,9 +330,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
engine=db.engine,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
user_id=user.id,
|
||||
)
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
|
||||
|
||||
return self._generate(
|
||||
workflow=workflow,
|
||||
@ -413,9 +414,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
engine=db.engine,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
user_id=user.id,
|
||||
)
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
|
||||
|
||||
return self._generate(
|
||||
workflow=workflow,
|
||||
|
||||
@ -419,11 +419,12 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
|
||||
var_loader = DraftVarLoader(
|
||||
engine=db.engine,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
user_id=user.id,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
@ -514,11 +515,12 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
|
||||
var_loader = DraftVarLoader(
|
||||
engine=db.engine,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
user_id=user.id,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
|
||||
@ -417,11 +417,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
|
||||
var_loader = DraftVarLoader(
|
||||
engine=db.engine,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
user_id=user.id,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
@ -500,11 +501,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
|
||||
var_loader = DraftVarLoader(
|
||||
engine=db.engine,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||
user_id=user.id,
|
||||
)
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
|
||||
@ -0,0 +1,69 @@
|
||||
"""add user_id and switch workflow_draft_variables unique key to user scope
|
||||
|
||||
Revision ID: 6b5f9f8b1a2c
|
||||
Revises: 0ec65df55790
|
||||
Create Date: 2026-03-04 16:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
import models as models
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6b5f9f8b1a2c"
|
||||
down_revision = "0ec65df55790"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def _is_pg(conn) -> bool:
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
|
||||
def upgrade():
|
||||
conn = op.get_bind()
|
||||
table_name = "workflow_draft_variables"
|
||||
|
||||
with op.batch_alter_table(table_name, schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column("user_id", models.types.StringUUID(), nullable=True))
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.get_context().autocommit_block():
|
||||
op.create_index(
|
||||
"workflow_draft_variables_app_id_user_id_key",
|
||||
"workflow_draft_variables",
|
||||
["app_id", "user_id", "node_id", "name"],
|
||||
unique=True,
|
||||
postgresql_concurrently=True,
|
||||
)
|
||||
else:
|
||||
op.create_index(
|
||||
"workflow_draft_variables_app_id_user_id_key",
|
||||
"workflow_draft_variables",
|
||||
["app_id", "user_id", "node_id", "name"],
|
||||
unique=True,
|
||||
)
|
||||
|
||||
with op.batch_alter_table(table_name, schema=None) as batch_op:
|
||||
batch_op.drop_constraint(op.f("workflow_draft_variables_app_id_key"), type_="unique")
|
||||
|
||||
|
||||
def downgrade():
|
||||
conn = op.get_bind()
|
||||
|
||||
with op.batch_alter_table("workflow_draft_variables", schema=None) as batch_op:
|
||||
batch_op.create_unique_constraint(
|
||||
op.f("workflow_draft_variables_app_id_key"),
|
||||
["app_id", "node_id", "name"],
|
||||
)
|
||||
|
||||
if _is_pg(conn):
|
||||
with op.get_context().autocommit_block():
|
||||
op.drop_index("workflow_draft_variables_app_id_user_id_key", postgresql_concurrently=True)
|
||||
else:
|
||||
op.drop_index("workflow_draft_variables_app_id_user_id_key", table_name="workflow_draft_variables")
|
||||
|
||||
with op.batch_alter_table("workflow_draft_variables", schema=None) as batch_op:
|
||||
batch_op.drop_column("user_id")
|
||||
@ -1286,16 +1286,17 @@ class WorkflowDraftVariable(Base):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def unique_app_id_node_id_name() -> list[str]:
|
||||
def unique_app_id_user_id_node_id_name() -> list[str]:
|
||||
return [
|
||||
"app_id",
|
||||
"user_id",
|
||||
"node_id",
|
||||
"name",
|
||||
]
|
||||
|
||||
__tablename__ = "workflow_draft_variables"
|
||||
__table_args__ = (
|
||||
UniqueConstraint(*unique_app_id_node_id_name()),
|
||||
UniqueConstraint(*unique_app_id_user_id_node_id_name()),
|
||||
Index("workflow_draft_variable_file_id_idx", "file_id"),
|
||||
)
|
||||
# Required for instance variable annotation.
|
||||
@ -1321,6 +1322,11 @@ class WorkflowDraftVariable(Base):
|
||||
|
||||
# "`app_id` maps to the `id` field in the `model.App` model."
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# Owner of this draft variable.
|
||||
#
|
||||
# This field is nullable during migration and will be migrated to NOT NULL
|
||||
# in a follow-up release.
|
||||
user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
|
||||
# `last_edited_at` records when the value of a given draft variable
|
||||
# is edited.
|
||||
@ -1573,6 +1579,7 @@ class WorkflowDraftVariable(Base):
|
||||
cls,
|
||||
*,
|
||||
app_id: str,
|
||||
user_id: str | None,
|
||||
node_id: str,
|
||||
name: str,
|
||||
value: Segment,
|
||||
@ -1586,6 +1593,7 @@ class WorkflowDraftVariable(Base):
|
||||
variable.updated_at = naive_utc_now()
|
||||
variable.description = description
|
||||
variable.app_id = app_id
|
||||
variable.user_id = user_id
|
||||
variable.node_id = node_id
|
||||
variable.name = name
|
||||
variable.set_value(value)
|
||||
@ -1599,12 +1607,14 @@ class WorkflowDraftVariable(Base):
|
||||
cls,
|
||||
*,
|
||||
app_id: str,
|
||||
user_id: str | None = None,
|
||||
name: str,
|
||||
value: Segment,
|
||||
description: str = "",
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = cls._new(
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
node_id=CONVERSATION_VARIABLE_NODE_ID,
|
||||
name=name,
|
||||
value=value,
|
||||
@ -1619,6 +1629,7 @@ class WorkflowDraftVariable(Base):
|
||||
cls,
|
||||
*,
|
||||
app_id: str,
|
||||
user_id: str | None = None,
|
||||
name: str,
|
||||
value: Segment,
|
||||
node_execution_id: str,
|
||||
@ -1626,6 +1637,7 @@ class WorkflowDraftVariable(Base):
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = cls._new(
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
node_id=SYSTEM_VARIABLE_NODE_ID,
|
||||
name=name,
|
||||
node_execution_id=node_execution_id,
|
||||
@ -1639,6 +1651,7 @@ class WorkflowDraftVariable(Base):
|
||||
cls,
|
||||
*,
|
||||
app_id: str,
|
||||
user_id: str | None = None,
|
||||
node_id: str,
|
||||
name: str,
|
||||
value: Segment,
|
||||
@ -1649,6 +1662,7 @@ class WorkflowDraftVariable(Base):
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = cls._new(
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
node_id=node_id,
|
||||
name=name,
|
||||
node_execution_id=node_execution_id,
|
||||
|
||||
@ -304,7 +304,7 @@ class AppDslService:
|
||||
)
|
||||
|
||||
draft_var_srv = WorkflowDraftVariableService(session=self._session)
|
||||
draft_var_srv.delete_workflow_variables(app_id=app.id)
|
||||
draft_var_srv.delete_app_workflow_variables(app_id=app.id)
|
||||
return Import(
|
||||
id=import_id,
|
||||
status=status,
|
||||
|
||||
@ -30,7 +30,7 @@ from core.plugin.impl.debugging import PluginDebuggingClient
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider import Provider, ProviderCredential
|
||||
from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider
|
||||
from models.provider_ids import GenericProviderID
|
||||
from services.enterprise.plugin_manager_service import (
|
||||
PluginManagerService,
|
||||
@ -569,6 +569,13 @@ class PluginService:
|
||||
)
|
||||
)
|
||||
|
||||
session.execute(
|
||||
delete(TenantPreferredModelProvider).where(
|
||||
TenantPreferredModelProvider.tenant_id == tenant_id,
|
||||
TenantPreferredModelProvider.provider_name.like(f"{plugin_id}/%"),
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Completed deleting credentials and cleaning provider associations for plugin: %s",
|
||||
plugin_id,
|
||||
|
||||
@ -472,6 +472,7 @@ class RagPipelineService:
|
||||
engine=db.engine,
|
||||
app_id=pipeline.id,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
user_id=account.id,
|
||||
),
|
||||
),
|
||||
start_at=start_at,
|
||||
@ -1237,6 +1238,7 @@ class RagPipelineService:
|
||||
engine=db.engine,
|
||||
app_id=pipeline.id,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
user_id=current_user.id,
|
||||
),
|
||||
),
|
||||
start_at=start_at,
|
||||
|
||||
@ -77,6 +77,7 @@ class DraftVarLoader(VariableLoader):
|
||||
_engine: Engine
|
||||
# Application ID for which variables are being loaded.
|
||||
_app_id: str
|
||||
_user_id: str
|
||||
_tenant_id: str
|
||||
_fallback_variables: Sequence[VariableBase]
|
||||
|
||||
@ -85,10 +86,12 @@ class DraftVarLoader(VariableLoader):
|
||||
engine: Engine,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
fallback_variables: Sequence[VariableBase] | None = None,
|
||||
):
|
||||
self._engine = engine
|
||||
self._app_id = app_id
|
||||
self._user_id = user_id
|
||||
self._tenant_id = tenant_id
|
||||
self._fallback_variables = fallback_variables or []
|
||||
|
||||
@ -104,7 +107,7 @@ class DraftVarLoader(VariableLoader):
|
||||
|
||||
with Session(bind=self._engine, expire_on_commit=False) as session:
|
||||
srv = WorkflowDraftVariableService(session)
|
||||
draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors)
|
||||
draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors, user_id=self._user_id)
|
||||
|
||||
# Important:
|
||||
files: list[File] = []
|
||||
@ -218,6 +221,7 @@ class WorkflowDraftVariableService:
|
||||
self,
|
||||
app_id: str,
|
||||
selectors: Sequence[list[str]],
|
||||
user_id: str,
|
||||
) -> list[WorkflowDraftVariable]:
|
||||
"""
|
||||
Retrieve WorkflowDraftVariable instances based on app_id and selectors.
|
||||
@ -238,22 +242,30 @@ class WorkflowDraftVariableService:
|
||||
# Alternatively, a `SELECT` statement could be constructed for each selector and
|
||||
# combined using `UNION` to fetch all rows.
|
||||
# Benchmarking indicates that both approaches yield comparable performance.
|
||||
variables = (
|
||||
query = (
|
||||
self._session.query(WorkflowDraftVariable)
|
||||
.options(
|
||||
orm.selectinload(WorkflowDraftVariable.variable_file).selectinload(
|
||||
WorkflowDraftVariableFile.upload_file
|
||||
)
|
||||
)
|
||||
.where(WorkflowDraftVariable.app_id == app_id, or_(*ors))
|
||||
.all()
|
||||
.where(
|
||||
WorkflowDraftVariable.app_id == app_id,
|
||||
WorkflowDraftVariable.user_id == user_id,
|
||||
or_(*ors),
|
||||
)
|
||||
)
|
||||
return variables
|
||||
return query.all()
|
||||
|
||||
def list_variables_without_values(self, app_id: str, page: int, limit: int) -> WorkflowDraftVariableList:
|
||||
criteria = WorkflowDraftVariable.app_id == app_id
|
||||
def list_variables_without_values(
|
||||
self, app_id: str, page: int, limit: int, user_id: str
|
||||
) -> WorkflowDraftVariableList:
|
||||
criteria = [
|
||||
WorkflowDraftVariable.app_id == app_id,
|
||||
WorkflowDraftVariable.user_id == user_id,
|
||||
]
|
||||
total = None
|
||||
query = self._session.query(WorkflowDraftVariable).where(criteria)
|
||||
query = self._session.query(WorkflowDraftVariable).where(*criteria)
|
||||
if page == 1:
|
||||
total = query.count()
|
||||
variables = (
|
||||
@ -269,11 +281,12 @@ class WorkflowDraftVariableService:
|
||||
|
||||
return WorkflowDraftVariableList(variables=variables, total=total)
|
||||
|
||||
def _list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList:
|
||||
criteria = (
|
||||
def _list_node_variables(self, app_id: str, node_id: str, user_id: str) -> WorkflowDraftVariableList:
|
||||
criteria = [
|
||||
WorkflowDraftVariable.app_id == app_id,
|
||||
WorkflowDraftVariable.node_id == node_id,
|
||||
)
|
||||
WorkflowDraftVariable.user_id == user_id,
|
||||
]
|
||||
query = self._session.query(WorkflowDraftVariable).where(*criteria)
|
||||
variables = (
|
||||
query.options(orm.selectinload(WorkflowDraftVariable.variable_file))
|
||||
@ -282,36 +295,36 @@ class WorkflowDraftVariableService:
|
||||
)
|
||||
return WorkflowDraftVariableList(variables=variables)
|
||||
|
||||
def list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList:
|
||||
return self._list_node_variables(app_id, node_id)
|
||||
def list_node_variables(self, app_id: str, node_id: str, user_id: str) -> WorkflowDraftVariableList:
|
||||
return self._list_node_variables(app_id, node_id, user_id=user_id)
|
||||
|
||||
def list_conversation_variables(self, app_id: str) -> WorkflowDraftVariableList:
|
||||
return self._list_node_variables(app_id, CONVERSATION_VARIABLE_NODE_ID)
|
||||
def list_conversation_variables(self, app_id: str, user_id: str) -> WorkflowDraftVariableList:
|
||||
return self._list_node_variables(app_id, CONVERSATION_VARIABLE_NODE_ID, user_id=user_id)
|
||||
|
||||
def list_system_variables(self, app_id: str) -> WorkflowDraftVariableList:
|
||||
return self._list_node_variables(app_id, SYSTEM_VARIABLE_NODE_ID)
|
||||
def list_system_variables(self, app_id: str, user_id: str) -> WorkflowDraftVariableList:
|
||||
return self._list_node_variables(app_id, SYSTEM_VARIABLE_NODE_ID, user_id=user_id)
|
||||
|
||||
def get_conversation_variable(self, app_id: str, name: str) -> WorkflowDraftVariable | None:
|
||||
return self._get_variable(app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, name=name)
|
||||
def get_conversation_variable(self, app_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None:
|
||||
return self._get_variable(app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, name=name, user_id=user_id)
|
||||
|
||||
def get_system_variable(self, app_id: str, name: str) -> WorkflowDraftVariable | None:
|
||||
return self._get_variable(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name)
|
||||
def get_system_variable(self, app_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None:
|
||||
return self._get_variable(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name, user_id=user_id)
|
||||
|
||||
def get_node_variable(self, app_id: str, node_id: str, name: str) -> WorkflowDraftVariable | None:
|
||||
return self._get_variable(app_id, node_id, name)
|
||||
def get_node_variable(self, app_id: str, node_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None:
|
||||
return self._get_variable(app_id, node_id, name, user_id=user_id)
|
||||
|
||||
def _get_variable(self, app_id: str, node_id: str, name: str) -> WorkflowDraftVariable | None:
|
||||
variable = (
|
||||
def _get_variable(self, app_id: str, node_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None:
|
||||
return (
|
||||
self._session.query(WorkflowDraftVariable)
|
||||
.options(orm.selectinload(WorkflowDraftVariable.variable_file))
|
||||
.where(
|
||||
WorkflowDraftVariable.app_id == app_id,
|
||||
WorkflowDraftVariable.node_id == node_id,
|
||||
WorkflowDraftVariable.name == name,
|
||||
WorkflowDraftVariable.user_id == user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return variable
|
||||
|
||||
def update_variable(
|
||||
self,
|
||||
@ -462,7 +475,17 @@ class WorkflowDraftVariableService:
|
||||
self._session.delete(upload_file)
|
||||
self._session.delete(variable)
|
||||
|
||||
def delete_workflow_variables(self, app_id: str):
|
||||
def delete_user_workflow_variables(self, app_id: str, user_id: str):
|
||||
(
|
||||
self._session.query(WorkflowDraftVariable)
|
||||
.where(
|
||||
WorkflowDraftVariable.app_id == app_id,
|
||||
WorkflowDraftVariable.user_id == user_id,
|
||||
)
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
|
||||
def delete_app_workflow_variables(self, app_id: str):
|
||||
(
|
||||
self._session.query(WorkflowDraftVariable)
|
||||
.where(WorkflowDraftVariable.app_id == app_id)
|
||||
@ -501,28 +524,35 @@ class WorkflowDraftVariableService:
|
||||
self._session.delete(upload_file)
|
||||
self._session.delete(variable_file)
|
||||
|
||||
def delete_node_variables(self, app_id: str, node_id: str):
|
||||
return self._delete_node_variables(app_id, node_id)
|
||||
def delete_node_variables(self, app_id: str, node_id: str, user_id: str):
|
||||
return self._delete_node_variables(app_id, node_id, user_id=user_id)
|
||||
|
||||
def _delete_node_variables(self, app_id: str, node_id: str):
|
||||
self._session.query(WorkflowDraftVariable).where(
|
||||
WorkflowDraftVariable.app_id == app_id,
|
||||
WorkflowDraftVariable.node_id == node_id,
|
||||
).delete()
|
||||
def _delete_node_variables(self, app_id: str, node_id: str, user_id: str):
|
||||
(
|
||||
self._session.query(WorkflowDraftVariable)
|
||||
.where(
|
||||
WorkflowDraftVariable.app_id == app_id,
|
||||
WorkflowDraftVariable.node_id == node_id,
|
||||
WorkflowDraftVariable.user_id == user_id,
|
||||
)
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
|
||||
def _get_conversation_id_from_draft_variable(self, app_id: str) -> str | None:
|
||||
def _get_conversation_id_from_draft_variable(self, app_id: str, user_id: str) -> str | None:
|
||||
draft_var = self._get_variable(
|
||||
app_id=app_id,
|
||||
node_id=SYSTEM_VARIABLE_NODE_ID,
|
||||
name=str(SystemVariableKey.CONVERSATION_ID),
|
||||
user_id=user_id,
|
||||
)
|
||||
if draft_var is None:
|
||||
return None
|
||||
segment = draft_var.get_value()
|
||||
if not isinstance(segment, StringSegment):
|
||||
logger.warning(
|
||||
"sys.conversation_id variable is not a string: app_id=%s, id=%s",
|
||||
"sys.conversation_id variable is not a string: app_id=%s, user_id=%s, id=%s",
|
||||
app_id,
|
||||
user_id,
|
||||
draft_var.id,
|
||||
)
|
||||
return None
|
||||
@ -543,7 +573,7 @@ class WorkflowDraftVariableService:
|
||||
|
||||
If no such conversation exists, a new conversation is created and its ID is returned.
|
||||
"""
|
||||
conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id)
|
||||
conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id, account_id)
|
||||
|
||||
if conv_id is not None:
|
||||
conversation = (
|
||||
@ -580,12 +610,13 @@ class WorkflowDraftVariableService:
|
||||
self._session.flush()
|
||||
return conversation.id
|
||||
|
||||
def prefill_conversation_variable_default_values(self, workflow: Workflow):
|
||||
def prefill_conversation_variable_default_values(self, workflow: Workflow, user_id: str):
|
||||
""""""
|
||||
draft_conv_vars: list[WorkflowDraftVariable] = []
|
||||
for conv_var in workflow.conversation_variables:
|
||||
draft_var = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=workflow.app_id,
|
||||
user_id=user_id,
|
||||
name=conv_var.name,
|
||||
value=conv_var,
|
||||
description=conv_var.description,
|
||||
@ -635,7 +666,7 @@ def _batch_upsert_draft_variable(
|
||||
stmt = pg_insert(WorkflowDraftVariable).values([_model_to_insertion_dict(v) for v in draft_vars])
|
||||
if policy == _UpsertPolicy.OVERWRITE:
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=WorkflowDraftVariable.unique_app_id_node_id_name(),
|
||||
index_elements=WorkflowDraftVariable.unique_app_id_user_id_node_id_name(),
|
||||
set_={
|
||||
# Refresh creation timestamp to ensure updated variables
|
||||
# appear first in chronologically sorted result sets.
|
||||
@ -652,7 +683,9 @@ def _batch_upsert_draft_variable(
|
||||
},
|
||||
)
|
||||
elif policy == _UpsertPolicy.IGNORE:
|
||||
stmt = stmt.on_conflict_do_nothing(index_elements=WorkflowDraftVariable.unique_app_id_node_id_name())
|
||||
stmt = stmt.on_conflict_do_nothing(
|
||||
index_elements=WorkflowDraftVariable.unique_app_id_user_id_node_id_name()
|
||||
)
|
||||
else:
|
||||
stmt = mysql_insert(WorkflowDraftVariable).values([_model_to_insertion_dict(v) for v in draft_vars]) # type: ignore[assignment]
|
||||
if policy == _UpsertPolicy.OVERWRITE:
|
||||
@ -682,6 +715,7 @@ def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]:
|
||||
d: dict[str, Any] = {
|
||||
"id": model.id,
|
||||
"app_id": model.app_id,
|
||||
"user_id": model.user_id,
|
||||
"last_edited_at": None,
|
||||
"node_id": model.node_id,
|
||||
"name": model.name,
|
||||
@ -807,6 +841,7 @@ class DraftVariableSaver:
|
||||
def _create_dummy_output_variable(self):
|
||||
return WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._app_id,
|
||||
user_id=self._user.id,
|
||||
node_id=self._node_id,
|
||||
name=self._DUMMY_OUTPUT_IDENTITY,
|
||||
node_execution_id=self._node_execution_id,
|
||||
@ -842,6 +877,7 @@ class DraftVariableSaver:
|
||||
draft_vars.append(
|
||||
WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=self._app_id,
|
||||
user_id=self._user.id,
|
||||
name=item.name,
|
||||
value=segment,
|
||||
)
|
||||
@ -862,6 +898,7 @@ class DraftVariableSaver:
|
||||
draft_vars.append(
|
||||
WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._app_id,
|
||||
user_id=self._user.id,
|
||||
node_id=self._node_id,
|
||||
name=name,
|
||||
node_execution_id=self._node_execution_id,
|
||||
@ -884,6 +921,7 @@ class DraftVariableSaver:
|
||||
draft_vars.append(
|
||||
WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=self._app_id,
|
||||
user_id=self._user.id,
|
||||
name=name,
|
||||
node_execution_id=self._node_execution_id,
|
||||
value=value_seg,
|
||||
@ -1019,6 +1057,7 @@ class DraftVariableSaver:
|
||||
# Create the draft variable
|
||||
draft_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._app_id,
|
||||
user_id=self._user.id,
|
||||
node_id=self._node_id,
|
||||
name=name,
|
||||
node_execution_id=self._node_execution_id,
|
||||
@ -1032,6 +1071,7 @@ class DraftVariableSaver:
|
||||
# Create the draft variable
|
||||
draft_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._app_id,
|
||||
user_id=self._user.id,
|
||||
node_id=self._node_id,
|
||||
name=name,
|
||||
node_execution_id=self._node_execution_id,
|
||||
|
||||
@ -697,7 +697,7 @@ class WorkflowService:
|
||||
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session, session.begin():
|
||||
draft_var_srv = WorkflowDraftVariableService(session)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow, user_id=account.id)
|
||||
|
||||
node_config = draft_workflow.get_node_config_by_id(node_id)
|
||||
node_type = Workflow.get_node_type_from_node_config(node_config)
|
||||
@ -740,6 +740,7 @@ class WorkflowService:
|
||||
engine=db.engine,
|
||||
app_id=app_model.id,
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=account.id,
|
||||
)
|
||||
|
||||
enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
|
||||
@ -831,6 +832,7 @@ class WorkflowService:
|
||||
workflow=draft_workflow,
|
||||
node_config=node_config,
|
||||
manual_inputs=inputs or {},
|
||||
user_id=account.id,
|
||||
)
|
||||
node = self._build_human_input_node(
|
||||
workflow=draft_workflow,
|
||||
@ -891,6 +893,7 @@ class WorkflowService:
|
||||
workflow=draft_workflow,
|
||||
node_config=node_config,
|
||||
manual_inputs=inputs or {},
|
||||
user_id=account.id,
|
||||
)
|
||||
node = self._build_human_input_node(
|
||||
workflow=draft_workflow,
|
||||
@ -967,6 +970,7 @@ class WorkflowService:
|
||||
workflow=draft_workflow,
|
||||
node_config=node_config,
|
||||
manual_inputs=inputs or {},
|
||||
user_id=account.id,
|
||||
)
|
||||
node = self._build_human_input_node(
|
||||
workflow=draft_workflow,
|
||||
@ -1102,10 +1106,11 @@ class WorkflowService:
|
||||
workflow: Workflow,
|
||||
node_config: NodeConfigDict,
|
||||
manual_inputs: Mapping[str, Any],
|
||||
user_id: str,
|
||||
) -> VariablePool:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session, session.begin():
|
||||
draft_var_srv = WorkflowDraftVariableService(session)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user_id)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.default(),
|
||||
@ -1118,6 +1123,7 @@ class WorkflowService:
|
||||
engine=db.engine,
|
||||
app_id=app_model.id,
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
variable_mapping = HumanInputNode.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict,
|
||||
|
||||
@ -30,6 +30,7 @@ from services.workflow_draft_variable_service import (
|
||||
class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
_test_app_id: str
|
||||
_session: Session
|
||||
_test_user_id: str
|
||||
_node1_id = "test_node_1"
|
||||
_node2_id = "test_node_2"
|
||||
_node_exec_id = str(uuid.uuid4())
|
||||
@ -99,13 +100,13 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
|
||||
def test_list_variables(self):
|
||||
srv = self._get_test_srv()
|
||||
var_list = srv.list_variables_without_values(self._test_app_id, page=1, limit=2)
|
||||
var_list = srv.list_variables_without_values(self._test_app_id, page=1, limit=2, user_id=self._test_user_id)
|
||||
assert var_list.total == 5
|
||||
assert len(var_list.variables) == 2
|
||||
page1_var_ids = {v.id for v in var_list.variables}
|
||||
assert page1_var_ids.issubset(self._variable_ids)
|
||||
|
||||
var_list_2 = srv.list_variables_without_values(self._test_app_id, page=2, limit=2)
|
||||
var_list_2 = srv.list_variables_without_values(self._test_app_id, page=2, limit=2, user_id=self._test_user_id)
|
||||
assert var_list_2.total is None
|
||||
assert len(var_list_2.variables) == 2
|
||||
page2_var_ids = {v.id for v in var_list_2.variables}
|
||||
@ -114,7 +115,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
|
||||
def test_get_node_variable(self):
|
||||
srv = self._get_test_srv()
|
||||
node_var = srv.get_node_variable(self._test_app_id, self._node1_id, "str_var")
|
||||
node_var = srv.get_node_variable(self._test_app_id, self._node1_id, "str_var", user_id=self._test_user_id)
|
||||
assert node_var is not None
|
||||
assert node_var.id == self._node1_str_var_id
|
||||
assert node_var.name == "str_var"
|
||||
@ -122,7 +123,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
|
||||
def test_get_system_variable(self):
|
||||
srv = self._get_test_srv()
|
||||
sys_var = srv.get_system_variable(self._test_app_id, "sys_var")
|
||||
sys_var = srv.get_system_variable(self._test_app_id, "sys_var", user_id=self._test_user_id)
|
||||
assert sys_var is not None
|
||||
assert sys_var.id == self._sys_var_id
|
||||
assert sys_var.name == "sys_var"
|
||||
@ -130,7 +131,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
|
||||
def test_get_conversation_variable(self):
|
||||
srv = self._get_test_srv()
|
||||
conv_var = srv.get_conversation_variable(self._test_app_id, "conv_var")
|
||||
conv_var = srv.get_conversation_variable(self._test_app_id, "conv_var", user_id=self._test_user_id)
|
||||
assert conv_var is not None
|
||||
assert conv_var.id == self._conv_var_id
|
||||
assert conv_var.name == "conv_var"
|
||||
@ -138,7 +139,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
|
||||
def test_delete_node_variables(self):
|
||||
srv = self._get_test_srv()
|
||||
srv.delete_node_variables(self._test_app_id, self._node2_id)
|
||||
srv.delete_node_variables(self._test_app_id, self._node2_id, user_id=self._test_user_id)
|
||||
node2_var_count = (
|
||||
self._session.query(WorkflowDraftVariable)
|
||||
.where(
|
||||
@ -162,7 +163,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
|
||||
def test__list_node_variables(self):
|
||||
srv = self._get_test_srv()
|
||||
node_vars = srv._list_node_variables(self._test_app_id, self._node2_id)
|
||||
node_vars = srv._list_node_variables(self._test_app_id, self._node2_id, user_id=self._test_user_id)
|
||||
assert len(node_vars.variables) == 2
|
||||
assert {v.id for v in node_vars.variables} == set(self._node2_var_ids)
|
||||
|
||||
@ -173,7 +174,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
[self._node2_id, "str_var"],
|
||||
[self._node2_id, "int_var"],
|
||||
]
|
||||
variables = srv.get_draft_variables_by_selectors(self._test_app_id, selectors)
|
||||
variables = srv.get_draft_variables_by_selectors(self._test_app_id, selectors, user_id=self._test_user_id)
|
||||
assert len(variables) == 3
|
||||
assert {v.id for v in variables} == {self._node1_str_var_id} | set(self._node2_var_ids)
|
||||
|
||||
@ -206,19 +207,23 @@ class TestDraftVariableLoader(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self._test_app_id = str(uuid.uuid4())
|
||||
self._test_tenant_id = str(uuid.uuid4())
|
||||
self._test_user_id = str(uuid.uuid4())
|
||||
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=self._test_app_id,
|
||||
user_id=self._test_user_id,
|
||||
name="sys_var",
|
||||
value=build_segment("sys_value"),
|
||||
node_execution_id=self._node_exec_id,
|
||||
)
|
||||
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=self._test_app_id,
|
||||
user_id=self._test_user_id,
|
||||
name="conv_var",
|
||||
value=build_segment("conv_value"),
|
||||
)
|
||||
node_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
user_id=self._test_user_id,
|
||||
node_id=self._node1_id,
|
||||
name="str_var",
|
||||
value=build_segment("str_value"),
|
||||
@ -248,12 +253,22 @@ class TestDraftVariableLoader(unittest.TestCase):
|
||||
session.commit()
|
||||
|
||||
def test_variable_loader_with_empty_selector(self):
|
||||
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
|
||||
var_loader = DraftVarLoader(
|
||||
engine=db.engine,
|
||||
app_id=self._test_app_id,
|
||||
tenant_id=self._test_tenant_id,
|
||||
user_id=self._test_user_id,
|
||||
)
|
||||
variables = var_loader.load_variables([])
|
||||
assert len(variables) == 0
|
||||
|
||||
def test_variable_loader_with_non_empty_selector(self):
|
||||
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
|
||||
var_loader = DraftVarLoader(
|
||||
engine=db.engine,
|
||||
app_id=self._test_app_id,
|
||||
tenant_id=self._test_tenant_id,
|
||||
user_id=self._test_user_id,
|
||||
)
|
||||
variables = var_loader.load_variables(
|
||||
[
|
||||
[SYSTEM_VARIABLE_NODE_ID, "sys_var"],
|
||||
@ -296,7 +311,12 @@ class TestDraftVariableLoader(unittest.TestCase):
|
||||
session.commit()
|
||||
|
||||
# Now test loading using DraftVarLoader
|
||||
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
|
||||
var_loader = DraftVarLoader(
|
||||
engine=db.engine,
|
||||
app_id=self._test_app_id,
|
||||
tenant_id=self._test_tenant_id,
|
||||
user_id=setup_account.id,
|
||||
)
|
||||
|
||||
# Load the variable using the standard workflow
|
||||
variables = var_loader.load_variables([["test_offload_node", "offloaded_string_var"]])
|
||||
@ -313,7 +333,7 @@ class TestDraftVariableLoader(unittest.TestCase):
|
||||
# Clean up - delete all draft variables for this app
|
||||
with Session(bind=db.engine) as session:
|
||||
service = WorkflowDraftVariableService(session)
|
||||
service.delete_workflow_variables(self._test_app_id)
|
||||
service.delete_app_workflow_variables(self._test_app_id)
|
||||
session.commit()
|
||||
|
||||
def test_load_offloaded_variable_object_type_integration(self):
|
||||
@ -364,6 +384,7 @@ class TestDraftVariableLoader(unittest.TestCase):
|
||||
# Now create the offloaded draft variable with the correct file_id
|
||||
offloaded_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
user_id=self._test_user_id,
|
||||
node_id="test_offload_node",
|
||||
name="offloaded_object_var",
|
||||
value=build_segment({"truncated": True}),
|
||||
@ -379,7 +400,9 @@ class TestDraftVariableLoader(unittest.TestCase):
|
||||
# Use the service method that properly preloads relationships
|
||||
service = WorkflowDraftVariableService(session)
|
||||
draft_vars = service.get_draft_variables_by_selectors(
|
||||
self._test_app_id, [["test_offload_node", "offloaded_object_var"]]
|
||||
self._test_app_id,
|
||||
[["test_offload_node", "offloaded_object_var"]],
|
||||
user_id=self._test_user_id,
|
||||
)
|
||||
|
||||
assert len(draft_vars) == 1
|
||||
@ -387,7 +410,12 @@ class TestDraftVariableLoader(unittest.TestCase):
|
||||
assert loaded_var.is_truncated()
|
||||
|
||||
# Create DraftVarLoader and test loading
|
||||
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
|
||||
var_loader = DraftVarLoader(
|
||||
engine=db.engine,
|
||||
app_id=self._test_app_id,
|
||||
tenant_id=self._test_tenant_id,
|
||||
user_id=self._test_user_id,
|
||||
)
|
||||
|
||||
# Test the _load_offloaded_variable method
|
||||
selector_tuple, variable = var_loader._load_offloaded_variable(loaded_var)
|
||||
@ -459,6 +487,7 @@ class TestDraftVariableLoader(unittest.TestCase):
|
||||
# Now create the offloaded draft variable with the correct file_id
|
||||
offloaded_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
user_id=self._test_user_id,
|
||||
node_id="test_integration_node",
|
||||
name="offloaded_integration_var",
|
||||
value=build_segment("truncated"),
|
||||
@ -473,7 +502,12 @@ class TestDraftVariableLoader(unittest.TestCase):
|
||||
|
||||
# Test load_variables with both regular and offloaded variables
|
||||
# This method should handle the relationship preloading internally
|
||||
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
|
||||
var_loader = DraftVarLoader(
|
||||
engine=db.engine,
|
||||
app_id=self._test_app_id,
|
||||
tenant_id=self._test_tenant_id,
|
||||
user_id=self._test_user_id,
|
||||
)
|
||||
|
||||
variables = var_loader.load_variables(
|
||||
[
|
||||
@ -572,6 +606,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
|
||||
# Create test variables
|
||||
self._node_var_with_exec = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
user_id=self._test_user_id,
|
||||
node_id=self._node_id,
|
||||
name="test_var",
|
||||
value=build_segment("old_value"),
|
||||
@ -581,6 +616,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
|
||||
|
||||
self._node_var_without_exec = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
user_id=self._test_user_id,
|
||||
node_id=self._node_id,
|
||||
name="no_exec_var",
|
||||
value=build_segment("some_value"),
|
||||
@ -591,6 +627,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
|
||||
|
||||
self._node_var_missing_exec = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
user_id=self._test_user_id,
|
||||
node_id=self._node_id,
|
||||
name="missing_exec_var",
|
||||
value=build_segment("some_value"),
|
||||
@ -599,6 +636,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
|
||||
|
||||
self._conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=self._test_app_id,
|
||||
user_id=self._test_user_id,
|
||||
name="conv_var_1",
|
||||
value=build_segment("old_conv_value"),
|
||||
)
|
||||
@ -764,6 +802,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
|
||||
# Create a system variable
|
||||
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=self._test_app_id,
|
||||
user_id=self._test_user_id,
|
||||
name="sys_var",
|
||||
value=build_segment("sys_value"),
|
||||
node_execution_id=self._node_exec_id,
|
||||
|
||||
@ -122,6 +122,7 @@ class TestWorkflowDraftVariableService:
|
||||
name,
|
||||
value,
|
||||
variable_type: DraftVariableType = DraftVariableType.CONVERSATION,
|
||||
user_id: str | None = None,
|
||||
fake=None,
|
||||
):
|
||||
"""
|
||||
@ -144,10 +145,15 @@ class TestWorkflowDraftVariableService:
|
||||
WorkflowDraftVariable: Created test variable instance with proper type configuration
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
if user_id is None:
|
||||
app = db_session_with_containers.query(App).filter_by(id=app_id).first()
|
||||
assert app is not None
|
||||
user_id = app.created_by
|
||||
if variable_type == "conversation":
|
||||
# Create conversation variable using the appropriate factory method
|
||||
variable = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
value=value,
|
||||
description=fake.text(max_nb_chars=20),
|
||||
@ -156,6 +162,7 @@ class TestWorkflowDraftVariableService:
|
||||
# Create system variable with editable flag and execution context
|
||||
variable = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
value=value,
|
||||
node_execution_id=fake.uuid4(),
|
||||
@ -165,6 +172,7 @@ class TestWorkflowDraftVariableService:
|
||||
# Create node variable with visibility and editability settings
|
||||
variable = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
node_id=node_id,
|
||||
name=name,
|
||||
value=value,
|
||||
@ -189,7 +197,13 @@ class TestWorkflowDraftVariableService:
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
test_value = StringSegment(value=fake.word())
|
||||
variable = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
"test_var",
|
||||
test_value,
|
||||
user_id=app.created_by,
|
||||
fake=fake,
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_variable = service.get_variable(variable.id)
|
||||
@ -250,7 +264,7 @@ class TestWorkflowDraftVariableService:
|
||||
["test_node_1", "var3"],
|
||||
]
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_variables = service.get_draft_variables_by_selectors(app.id, selectors)
|
||||
retrieved_variables = service.get_draft_variables_by_selectors(app.id, selectors, user_id=app.created_by)
|
||||
assert len(retrieved_variables) == 3
|
||||
var_names = [var.name for var in retrieved_variables]
|
||||
assert "var1" in var_names
|
||||
@ -288,7 +302,7 @@ class TestWorkflowDraftVariableService:
|
||||
fake=fake,
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
result = service.list_variables_without_values(app.id, page=1, limit=3)
|
||||
result = service.list_variables_without_values(app.id, page=1, limit=3, user_id=app.created_by)
|
||||
assert result.total == 5
|
||||
assert len(result.variables) == 3
|
||||
assert result.variables[0].created_at >= result.variables[1].created_at
|
||||
@ -339,7 +353,7 @@ class TestWorkflowDraftVariableService:
|
||||
fake=fake,
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
result = service.list_node_variables(app.id, node_id)
|
||||
result = service.list_node_variables(app.id, node_id, user_id=app.created_by)
|
||||
assert len(result.variables) == 2
|
||||
for var in result.variables:
|
||||
assert var.node_id == node_id
|
||||
@ -381,7 +395,7 @@ class TestWorkflowDraftVariableService:
|
||||
fake=fake,
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
result = service.list_conversation_variables(app.id)
|
||||
result = service.list_conversation_variables(app.id, user_id=app.created_by)
|
||||
assert len(result.variables) == 2
|
||||
for var in result.variables:
|
||||
assert var.node_id == CONVERSATION_VARIABLE_NODE_ID
|
||||
@ -559,7 +573,7 @@ class TestWorkflowDraftVariableService:
|
||||
assert len(app_variables) == 3
|
||||
assert len(other_app_variables) == 1
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
service.delete_workflow_variables(app.id)
|
||||
service.delete_user_workflow_variables(app.id, user_id=app.created_by)
|
||||
app_variables_after = db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
|
||||
other_app_variables_after = (
|
||||
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
|
||||
@ -567,6 +581,69 @@ class TestWorkflowDraftVariableService:
|
||||
assert len(app_variables_after) == 0
|
||||
assert len(other_app_variables_after) == 1
|
||||
|
||||
def test_draft_variables_are_isolated_between_users(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test draft variable isolation for different users in the same app.
|
||||
|
||||
This test verifies that:
|
||||
1. Query APIs return only variables owned by the target user.
|
||||
2. User-scoped deletion only removes variables for that user and keeps
|
||||
other users' variables in the same app untouched.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
user_a = app.created_by
|
||||
user_b = fake.uuid4()
|
||||
|
||||
# Use identical variable names on purpose to verify uniqueness scope includes user_id.
|
||||
self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
"shared_name",
|
||||
StringSegment(value="value_a"),
|
||||
user_id=user_a,
|
||||
fake=fake,
|
||||
)
|
||||
self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
"shared_name",
|
||||
StringSegment(value="value_b"),
|
||||
user_id=user_b,
|
||||
fake=fake,
|
||||
)
|
||||
self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
"only_a",
|
||||
StringSegment(value="only_a"),
|
||||
user_id=user_a,
|
||||
fake=fake,
|
||||
)
|
||||
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
|
||||
user_a_vars = service.list_conversation_variables(app.id, user_id=user_a)
|
||||
user_b_vars = service.list_conversation_variables(app.id, user_id=user_b)
|
||||
assert {v.name for v in user_a_vars.variables} == {"shared_name", "only_a"}
|
||||
assert {v.name for v in user_b_vars.variables} == {"shared_name"}
|
||||
|
||||
service.delete_user_workflow_variables(app.id, user_id=user_a)
|
||||
|
||||
user_a_remaining = (
|
||||
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, user_id=user_a).count()
|
||||
)
|
||||
user_b_remaining = (
|
||||
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, user_id=user_b).count()
|
||||
)
|
||||
assert user_a_remaining == 0
|
||||
assert user_b_remaining == 1
|
||||
|
||||
def test_delete_node_variables_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
@ -627,7 +704,7 @@ class TestWorkflowDraftVariableService:
|
||||
assert len(other_node_variables) == 1
|
||||
assert len(conv_variables) == 1
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
service.delete_node_variables(app.id, node_id)
|
||||
service.delete_node_variables(app.id, node_id, user_id=app.created_by)
|
||||
target_node_variables_after = (
|
||||
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
|
||||
)
|
||||
@ -675,7 +752,7 @@ class TestWorkflowDraftVariableService:
|
||||
|
||||
db_session_with_containers.commit()
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
service.prefill_conversation_variable_default_values(workflow)
|
||||
service.prefill_conversation_variable_default_values(workflow, user_id="00000000-0000-0000-0000-000000000001")
|
||||
draft_variables = (
|
||||
db_session_with_containers.query(WorkflowDraftVariable)
|
||||
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
|
||||
@ -715,7 +792,7 @@ class TestWorkflowDraftVariableService:
|
||||
fake=fake,
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id)
|
||||
retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id, app.created_by)
|
||||
assert retrieved_conv_id == conversation_id
|
||||
|
||||
def test_get_conversation_id_from_draft_variable_not_found(
|
||||
@ -731,7 +808,7 @@ class TestWorkflowDraftVariableService:
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id)
|
||||
retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id, app.created_by)
|
||||
assert retrieved_conv_id is None
|
||||
|
||||
def test_list_system_variables_success(
|
||||
@ -772,7 +849,7 @@ class TestWorkflowDraftVariableService:
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var", conv_var_value, fake=fake
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
result = service.list_system_variables(app.id)
|
||||
result = service.list_system_variables(app.id, user_id=app.created_by)
|
||||
assert len(result.variables) == 2
|
||||
for var in result.variables:
|
||||
assert var.node_id == SYSTEM_VARIABLE_NODE_ID
|
||||
@ -819,15 +896,15 @@ class TestWorkflowDraftVariableService:
|
||||
fake=fake,
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_conv_var = service.get_conversation_variable(app.id, "test_conv_var")
|
||||
retrieved_conv_var = service.get_conversation_variable(app.id, "test_conv_var", user_id=app.created_by)
|
||||
assert retrieved_conv_var is not None
|
||||
assert retrieved_conv_var.name == "test_conv_var"
|
||||
assert retrieved_conv_var.node_id == CONVERSATION_VARIABLE_NODE_ID
|
||||
retrieved_sys_var = service.get_system_variable(app.id, "test_sys_var")
|
||||
retrieved_sys_var = service.get_system_variable(app.id, "test_sys_var", user_id=app.created_by)
|
||||
assert retrieved_sys_var is not None
|
||||
assert retrieved_sys_var.name == "test_sys_var"
|
||||
assert retrieved_sys_var.node_id == SYSTEM_VARIABLE_NODE_ID
|
||||
retrieved_node_var = service.get_node_variable(app.id, "test_node", "test_node_var")
|
||||
retrieved_node_var = service.get_node_variable(app.id, "test_node", "test_node_var", user_id=app.created_by)
|
||||
assert retrieved_node_var is not None
|
||||
assert retrieved_node_var.name == "test_node_var"
|
||||
assert retrieved_node_var.node_id == "test_node"
|
||||
@ -845,9 +922,14 @@ class TestWorkflowDraftVariableService:
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_conv_var = service.get_conversation_variable(app.id, "non_existent_conv_var")
|
||||
retrieved_conv_var = service.get_conversation_variable(app.id, "non_existent_conv_var", user_id=app.created_by)
|
||||
assert retrieved_conv_var is None
|
||||
retrieved_sys_var = service.get_system_variable(app.id, "non_existent_sys_var")
|
||||
retrieved_sys_var = service.get_system_variable(app.id, "non_existent_sys_var", user_id=app.created_by)
|
||||
assert retrieved_sys_var is None
|
||||
retrieved_node_var = service.get_node_variable(app.id, "test_node", "non_existent_node_var")
|
||||
retrieved_node_var = service.get_node_variable(
|
||||
app.id,
|
||||
"test_node",
|
||||
"non_existent_node_var",
|
||||
user_id=app.created_by,
|
||||
)
|
||||
assert retrieved_node_var is None
|
||||
|
||||
@ -398,6 +398,7 @@ class TestWorkflowDraftVariableEndpoints:
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "current_user", SimpleNamespace(id="user-1"))
|
||||
|
||||
class DummySession:
|
||||
def __enter__(self):
|
||||
|
||||
@ -234,6 +234,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
captured: dict[str, object] = {}
|
||||
prefill_calls: list[object] = []
|
||||
var_loader = SimpleNamespace(loader="draft")
|
||||
workflow = SimpleNamespace(id="workflow-id")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config",
|
||||
@ -260,8 +261,8 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
def __init__(self, session):
|
||||
_ = session
|
||||
|
||||
def prefill_conversation_variable_default_values(self, workflow):
|
||||
prefill_calls.append(workflow)
|
||||
def prefill_conversation_variable_default_values(self, workflow, user_id):
|
||||
prefill_calls.append((workflow, user_id))
|
||||
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.WorkflowDraftVariableService", _DraftVarService)
|
||||
|
||||
@ -273,7 +274,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
|
||||
result = generator.single_iteration_generate(
|
||||
app_model=SimpleNamespace(id="app", tenant_id="tenant"),
|
||||
workflow=SimpleNamespace(id="workflow-id"),
|
||||
workflow=workflow,
|
||||
node_id="node-1",
|
||||
user=SimpleNamespace(id="user-id"),
|
||||
args={"inputs": {"foo": "bar"}},
|
||||
@ -281,7 +282,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
)
|
||||
|
||||
assert result == {"ok": True}
|
||||
assert prefill_calls
|
||||
assert prefill_calls == [(workflow, "user-id")]
|
||||
assert captured["variable_loader"] is var_loader
|
||||
assert captured["application_generate_entity"].single_iteration_run.node_id == "node-1"
|
||||
|
||||
@ -291,6 +292,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
captured: dict[str, object] = {}
|
||||
prefill_calls: list[object] = []
|
||||
var_loader = SimpleNamespace(loader="draft")
|
||||
workflow = SimpleNamespace(id="workflow-id")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config",
|
||||
@ -317,8 +319,8 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
def __init__(self, session):
|
||||
_ = session
|
||||
|
||||
def prefill_conversation_variable_default_values(self, workflow):
|
||||
prefill_calls.append(workflow)
|
||||
def prefill_conversation_variable_default_values(self, workflow, user_id):
|
||||
prefill_calls.append((workflow, user_id))
|
||||
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.WorkflowDraftVariableService", _DraftVarService)
|
||||
|
||||
@ -330,7 +332,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
|
||||
result = generator.single_loop_generate(
|
||||
app_model=SimpleNamespace(id="app", tenant_id="tenant"),
|
||||
workflow=SimpleNamespace(id="workflow-id"),
|
||||
workflow=workflow,
|
||||
node_id="node-2",
|
||||
user=SimpleNamespace(id="user-id"),
|
||||
args=SimpleNamespace(inputs={"foo": "bar"}),
|
||||
@ -338,7 +340,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
)
|
||||
|
||||
assert result == {"ok": True}
|
||||
assert prefill_calls
|
||||
assert prefill_calls == [(workflow, "user-id")]
|
||||
assert captured["variable_loader"] is var_loader
|
||||
assert captured["application_generate_entity"].single_loop_run.node_id == "node-2"
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ import unittest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.rag.datasource.vdb.weaviate import weaviate_vector as weaviate_vector_module
|
||||
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
|
||||
from core.rag.models.document import Document
|
||||
|
||||
@ -19,6 +20,7 @@ class TestWeaviateVector(unittest.TestCase):
|
||||
"""Tests for WeaviateVector class with focus on doc_type metadata handling."""
|
||||
|
||||
def setUp(self):
|
||||
weaviate_vector_module._weaviate_client = None
|
||||
self.config = WeaviateConfig(
|
||||
endpoint="http://localhost:8080",
|
||||
api_key="test-key",
|
||||
@ -27,6 +29,9 @@ class TestWeaviateVector(unittest.TestCase):
|
||||
self.collection_name = "Test_Collection_Node"
|
||||
self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
|
||||
|
||||
def tearDown(self):
|
||||
weaviate_vector_module._weaviate_client = None
|
||||
|
||||
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
|
||||
def _create_weaviate_vector(self, mock_weaviate_module):
|
||||
"""Helper to create a WeaviateVector instance with mocked client."""
|
||||
|
||||
@ -263,7 +263,7 @@ def test_import_app_completed_uses_declared_dependencies(monkeypatch):
|
||||
|
||||
assert result.status == ImportStatus.COMPLETED
|
||||
assert result.app_id == "app-new"
|
||||
draft_var_service.delete_workflow_variables.assert_called_once_with(app_id="app-new")
|
||||
draft_var_service.delete_app_workflow_variables.assert_called_once_with(app_id="app-new")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("has_workflow", [True, False])
|
||||
@ -305,7 +305,7 @@ def test_import_app_legacy_versions_extract_dependencies(monkeypatch, has_workfl
|
||||
account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content=_yaml_dump(data)
|
||||
)
|
||||
assert result.status == ImportStatus.COMPLETED_WITH_WARNINGS
|
||||
draft_var_service.delete_workflow_variables.assert_called_once_with(app_id="app-legacy")
|
||||
draft_var_service.delete_app_workflow_variables.assert_called_once_with(app_id="app-legacy")
|
||||
|
||||
|
||||
def test_import_app_yaml_error_returns_failed(monkeypatch):
|
||||
|
||||
@ -24,7 +24,11 @@ class TestDraftVarLoaderSimple:
|
||||
def draft_var_loader(self, mock_engine):
|
||||
"""Create DraftVarLoader instance for testing."""
|
||||
return DraftVarLoader(
|
||||
engine=mock_engine, app_id="test-app-id", tenant_id="test-tenant-id", fallback_variables=[]
|
||||
engine=mock_engine,
|
||||
app_id="test-app-id",
|
||||
tenant_id="test-tenant-id",
|
||||
user_id="test-user-id",
|
||||
fallback_variables=[],
|
||||
)
|
||||
|
||||
def test_load_offloaded_variable_string_type_unit(self, draft_var_loader):
|
||||
@ -323,7 +327,9 @@ class TestDraftVarLoaderSimple:
|
||||
|
||||
# Verify service method was called
|
||||
mock_service.get_draft_variables_by_selectors.assert_called_once_with(
|
||||
draft_var_loader._app_id, selectors
|
||||
draft_var_loader._app_id,
|
||||
selectors,
|
||||
user_id=draft_var_loader._user_id,
|
||||
)
|
||||
|
||||
# Verify offloaded variable loading was called
|
||||
|
||||
@ -8,7 +8,7 @@ from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from dify_graph.enums import BuiltinNodeTypes
|
||||
from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
from dify_graph.variables.types import SegmentType
|
||||
from libs.uuid_utils import uuidv7
|
||||
@ -182,6 +182,42 @@ class TestDraftVariableSaver:
|
||||
draft_vars = mock_batch_upsert.call_args[0][1]
|
||||
assert len(draft_vars) == 2
|
||||
|
||||
@patch("services.workflow_draft_variable_service._batch_upsert_draft_variable", autospec=True)
|
||||
def test_start_node_save_persists_sys_timestamp_and_workflow_run_id(self, mock_batch_upsert):
|
||||
"""Start node should persist common `sys.*` variables, not only `sys.files`."""
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_user = MagicMock(spec=Account)
|
||||
mock_user.id = "test-user-id"
|
||||
mock_user.tenant_id = "test-tenant-id"
|
||||
|
||||
saver = DraftVariableSaver(
|
||||
session=mock_session,
|
||||
app_id="test-app-id",
|
||||
node_id="start-node-id",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
node_execution_id="exec-id",
|
||||
user=mock_user,
|
||||
)
|
||||
|
||||
outputs = {
|
||||
f"{SYSTEM_VARIABLE_NODE_ID}.{SystemVariableKey.TIMESTAMP}": 1700000000,
|
||||
f"{SYSTEM_VARIABLE_NODE_ID}.{SystemVariableKey.WORKFLOW_EXECUTION_ID}": "run-id-123",
|
||||
}
|
||||
|
||||
saver.save(outputs=outputs)
|
||||
|
||||
mock_batch_upsert.assert_called_once()
|
||||
draft_vars = mock_batch_upsert.call_args[0][1]
|
||||
|
||||
# plus one dummy output because there are no non-sys Start inputs
|
||||
assert len(draft_vars) == 3
|
||||
|
||||
sys_vars = [v for v in draft_vars if v.node_id == SYSTEM_VARIABLE_NODE_ID]
|
||||
assert {v.name for v in sys_vars} == {
|
||||
str(SystemVariableKey.TIMESTAMP),
|
||||
str(SystemVariableKey.WORKFLOW_EXECUTION_ID),
|
||||
}
|
||||
|
||||
|
||||
class TestWorkflowDraftVariableService:
|
||||
def _get_test_app_id(self):
|
||||
|
||||
@ -245,6 +245,7 @@ class TestWorkflowService:
|
||||
workflow=workflow,
|
||||
node_config=node_config,
|
||||
manual_inputs={"#node-0.result#": "LLM output"},
|
||||
user_id="account-1",
|
||||
)
|
||||
|
||||
node.render_form_content_with_outputs.assert_called_once()
|
||||
|
||||
139
web/__tests__/check-components-diff-coverage.test.ts
Normal file
139
web/__tests__/check-components-diff-coverage.test.ts
Normal file
@ -0,0 +1,139 @@
|
||||
import {
|
||||
getChangedBranchCoverage,
|
||||
getChangedStatementCoverage,
|
||||
getIgnoredChangedLinesFromSource,
|
||||
normalizeToRepoRelative,
|
||||
parseChangedLineMap,
|
||||
} from '../scripts/check-components-diff-coverage-lib.mjs'
|
||||
|
||||
describe('check-components-diff-coverage helpers', () => {
|
||||
it('should parse changed line maps from unified diffs', () => {
|
||||
const diff = [
|
||||
'diff --git a/web/app/components/share/a.ts b/web/app/components/share/a.ts',
|
||||
'+++ b/web/app/components/share/a.ts',
|
||||
'@@ -10,0 +11,2 @@',
|
||||
'+const a = 1',
|
||||
'+const b = 2',
|
||||
'diff --git a/web/app/components/base/b.ts b/web/app/components/base/b.ts',
|
||||
'+++ b/web/app/components/base/b.ts',
|
||||
'@@ -20 +21 @@',
|
||||
'+const c = 3',
|
||||
'diff --git a/web/README.md b/web/README.md',
|
||||
'+++ b/web/README.md',
|
||||
'@@ -1 +1 @@',
|
||||
'+ignore me',
|
||||
].join('\n')
|
||||
|
||||
const lineMap = parseChangedLineMap(diff, (filePath: string) => filePath.startsWith('web/app/components/'))
|
||||
|
||||
expect([...lineMap.entries()]).toEqual([
|
||||
['web/app/components/share/a.ts', new Set([11, 12])],
|
||||
['web/app/components/base/b.ts', new Set([21])],
|
||||
])
|
||||
})
|
||||
|
||||
it('should normalize coverage and absolute paths to repo-relative paths', () => {
|
||||
const repoRoot = '/repo'
|
||||
const webRoot = '/repo/web'
|
||||
|
||||
expect(normalizeToRepoRelative('web/app/components/share/a.ts', {
|
||||
appComponentsCoveragePrefix: 'app/components/',
|
||||
appComponentsPrefix: 'web/app/components/',
|
||||
repoRoot,
|
||||
sharedTestPrefix: 'web/__tests__/',
|
||||
webRoot,
|
||||
})).toBe('web/app/components/share/a.ts')
|
||||
|
||||
expect(normalizeToRepoRelative('app/components/share/a.ts', {
|
||||
appComponentsCoveragePrefix: 'app/components/',
|
||||
appComponentsPrefix: 'web/app/components/',
|
||||
repoRoot,
|
||||
sharedTestPrefix: 'web/__tests__/',
|
||||
webRoot,
|
||||
})).toBe('web/app/components/share/a.ts')
|
||||
|
||||
expect(normalizeToRepoRelative('/repo/web/app/components/share/a.ts', {
|
||||
appComponentsCoveragePrefix: 'app/components/',
|
||||
appComponentsPrefix: 'web/app/components/',
|
||||
repoRoot,
|
||||
sharedTestPrefix: 'web/__tests__/',
|
||||
webRoot,
|
||||
})).toBe('web/app/components/share/a.ts')
|
||||
})
|
||||
|
||||
it('should calculate changed statement coverage from changed lines', () => {
|
||||
const entry = {
|
||||
s: { 0: 1, 1: 0 },
|
||||
statementMap: {
|
||||
0: { start: { line: 10 }, end: { line: 10 } },
|
||||
1: { start: { line: 12 }, end: { line: 13 } },
|
||||
},
|
||||
}
|
||||
|
||||
const coverage = getChangedStatementCoverage(entry, new Set([10, 12]))
|
||||
|
||||
expect(coverage).toEqual({
|
||||
covered: 1,
|
||||
total: 2,
|
||||
uncoveredLines: [12],
|
||||
})
|
||||
})
|
||||
|
||||
it('should fail changed lines when a source file has no coverage entry', () => {
|
||||
const coverage = getChangedStatementCoverage(undefined, new Set([42, 43]))
|
||||
|
||||
expect(coverage).toEqual({
|
||||
covered: 0,
|
||||
total: 2,
|
||||
uncoveredLines: [42, 43],
|
||||
})
|
||||
})
|
||||
|
||||
it('should calculate changed branch coverage using changed branch definitions', () => {
|
||||
const entry = {
|
||||
b: {
|
||||
0: [1, 0],
|
||||
},
|
||||
branchMap: {
|
||||
0: {
|
||||
line: 20,
|
||||
loc: { start: { line: 20 }, end: { line: 20 } },
|
||||
locations: [
|
||||
{ start: { line: 20 }, end: { line: 20 } },
|
||||
{ start: { line: 21 }, end: { line: 21 } },
|
||||
],
|
||||
type: 'if',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
const coverage = getChangedBranchCoverage(entry, new Set([20]))
|
||||
|
||||
expect(coverage).toEqual({
|
||||
covered: 1,
|
||||
total: 2,
|
||||
uncoveredBranches: [
|
||||
{ armIndex: 1, line: 21 },
|
||||
],
|
||||
})
|
||||
})
|
||||
|
||||
it('should ignore changed lines with valid pragma reasons and report invalid pragmas', () => {
|
||||
const sourceCode = [
|
||||
'const a = 1',
|
||||
'const b = 2 // diff-coverage-ignore-line: defensive fallback',
|
||||
'const c = 3 // diff-coverage-ignore-line:',
|
||||
'const d = 4 // diff-coverage-ignore-line: not changed',
|
||||
].join('\n')
|
||||
|
||||
const result = getIgnoredChangedLinesFromSource(sourceCode, new Set([2, 3]))
|
||||
|
||||
expect([...result.effectiveChangedLines]).toEqual([3])
|
||||
expect([...result.ignoredLines.entries()]).toEqual([
|
||||
[2, 'defensive fallback'],
|
||||
])
|
||||
expect(result.invalidPragmas).toEqual([
|
||||
{ line: 3, reason: 'missing ignore reason' },
|
||||
])
|
||||
})
|
||||
})
|
||||
@ -275,7 +275,7 @@ describe('useTextGenerationBatch', () => {
|
||||
})
|
||||
|
||||
act(() => {
|
||||
result.current.handleCompleted({ answer: 'failed' } as unknown as string, 1, false)
|
||||
result.current.handleCompleted('{"answer":"failed"}', 1, false)
|
||||
})
|
||||
|
||||
expect(result.current.allFailedTaskList).toEqual([
|
||||
@ -291,7 +291,7 @@ describe('useTextGenerationBatch', () => {
|
||||
{
|
||||
'Name': 'Alice',
|
||||
'Score': '',
|
||||
'generation.completionResult': JSON.stringify({ answer: 'failed' }),
|
||||
'generation.completionResult': '{"answer":"failed"}',
|
||||
},
|
||||
])
|
||||
|
||||
|
||||
@ -241,10 +241,7 @@ export const useTextGenerationBatch = ({
|
||||
result[variable.name] = String(task.params.inputs[variable.key] ?? '')
|
||||
})
|
||||
|
||||
let completionValue = batchCompletionMap[String(task.id)]
|
||||
if (typeof completionValue === 'object')
|
||||
completionValue = JSON.stringify(completionValue)
|
||||
|
||||
const completionValue = batchCompletionMap[String(task.id)] ?? ''
|
||||
result[t('generation.completionResult', { ns: 'share' })] = completionValue
|
||||
return result
|
||||
})
|
||||
|
||||
@ -0,0 +1,334 @@
|
||||
import type { PromptConfig } from '@/models/debug'
|
||||
import type { SiteInfo } from '@/models/share'
|
||||
import type { IOtherOptions } from '@/service/base'
|
||||
import type { VisionSettings } from '@/types/app'
|
||||
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { AppSourceType } from '@/service/share'
|
||||
import { Resolution, TransferMethod } from '@/types/app'
|
||||
import Result from '../index'
|
||||
|
||||
const {
|
||||
notifyMock,
|
||||
sendCompletionMessageMock,
|
||||
sendWorkflowMessageMock,
|
||||
stopChatMessageRespondingMock,
|
||||
textGenerationResPropsSpy,
|
||||
} = vi.hoisted(() => ({
|
||||
notifyMock: vi.fn(),
|
||||
sendCompletionMessageMock: vi.fn(),
|
||||
sendWorkflowMessageMock: vi.fn(),
|
||||
stopChatMessageRespondingMock: vi.fn(),
|
||||
textGenerationResPropsSpy: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('i18next', () => ({
|
||||
t: (key: string) => key,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
default: {
|
||||
notify: notifyMock,
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/utils', async () => {
|
||||
const actual = await vi.importActual<typeof import('@/utils')>('@/utils')
|
||||
return {
|
||||
...actual,
|
||||
sleep: () => new Promise<void>(() => {}),
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/service/share', async () => {
|
||||
const actual = await vi.importActual<typeof import('@/service/share')>('@/service/share')
|
||||
return {
|
||||
...actual,
|
||||
sendCompletionMessage: (...args: Parameters<typeof actual.sendCompletionMessage>) => sendCompletionMessageMock(...args),
|
||||
sendWorkflowMessage: (...args: Parameters<typeof actual.sendWorkflowMessage>) => sendWorkflowMessageMock(...args),
|
||||
stopChatMessageResponding: (...args: Parameters<typeof actual.stopChatMessageResponding>) => stopChatMessageRespondingMock(...args),
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/app/components/app/text-generate/item', () => ({
|
||||
default: (props: Record<string, unknown>) => {
|
||||
textGenerationResPropsSpy(props)
|
||||
return (
|
||||
<div data-testid="text-generation-res">
|
||||
{typeof props.content === 'string' ? props.content : JSON.stringify(props.content ?? null)}
|
||||
</div>
|
||||
)
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/share/text-generation/no-data', () => ({
|
||||
default: () => <div data-testid="no-data">No data</div>,
|
||||
}))
|
||||
|
||||
const promptConfig: PromptConfig = {
|
||||
prompt_template: 'template',
|
||||
prompt_variables: [
|
||||
{ key: 'name', name: 'Name', type: 'string', required: true },
|
||||
],
|
||||
}
|
||||
|
||||
const siteInfo: SiteInfo = {
|
||||
title: 'Share title',
|
||||
description: 'Share description',
|
||||
icon_type: 'emoji',
|
||||
icon: 'robot',
|
||||
}
|
||||
|
||||
const visionConfig: VisionSettings = {
|
||||
enabled: false,
|
||||
number_limits: 2,
|
||||
detail: Resolution.low,
|
||||
transfer_methods: [TransferMethod.local_file],
|
||||
}
|
||||
|
||||
const baseProps = {
|
||||
appId: 'app-1',
|
||||
appSourceType: AppSourceType.webApp,
|
||||
completionFiles: [],
|
||||
controlRetry: 0,
|
||||
controlSend: 0,
|
||||
controlStopResponding: 0,
|
||||
handleSaveMessage: vi.fn(),
|
||||
inputs: { name: 'Alice' },
|
||||
isCallBatchAPI: false,
|
||||
isError: false,
|
||||
isMobile: false,
|
||||
isPC: true,
|
||||
isShowTextToSpeech: true,
|
||||
isWorkflow: false,
|
||||
moreLikeThisEnabled: true,
|
||||
onCompleted: vi.fn(),
|
||||
onRunControlChange: vi.fn(),
|
||||
onRunStart: vi.fn(),
|
||||
onShowRes: vi.fn(),
|
||||
promptConfig,
|
||||
siteInfo,
|
||||
visionConfig,
|
||||
}
|
||||
|
||||
describe('Result', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
stopChatMessageRespondingMock.mockResolvedValue(undefined)
|
||||
})
|
||||
|
||||
it('should render no data before the first execution', () => {
|
||||
render(<Result {...baseProps} />)
|
||||
|
||||
expect(screen.getByTestId('no-data')).toBeTruthy()
|
||||
expect(screen.queryByTestId('text-generation-res')).toBeNull()
|
||||
})
|
||||
|
||||
it('should stream completion results and stop the current task', async () => {
|
||||
let completionHandlers: {
|
||||
onCompleted: () => void
|
||||
onData: (chunk: string, isFirstMessage: boolean, info: { messageId: string, taskId?: string }) => void
|
||||
onError: () => void
|
||||
onMessageReplace: (messageReplace: { answer: string }) => void
|
||||
} | null = null
|
||||
|
||||
sendCompletionMessageMock.mockImplementation(async (_data, handlers) => {
|
||||
completionHandlers = handlers
|
||||
})
|
||||
|
||||
const onCompleted = vi.fn()
|
||||
const onRunControlChange = vi.fn()
|
||||
const { rerender } = render(
|
||||
<Result
|
||||
{...baseProps}
|
||||
onCompleted={onCompleted}
|
||||
onRunControlChange={onRunControlChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
rerender(
|
||||
<Result
|
||||
{...baseProps}
|
||||
controlSend={1}
|
||||
onCompleted={onCompleted}
|
||||
onRunControlChange={onRunControlChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(sendCompletionMessageMock).toHaveBeenCalledTimes(1)
|
||||
expect(screen.getByRole('status', { name: 'appApi.loading' })).toBeTruthy()
|
||||
|
||||
await act(async () => {
|
||||
completionHandlers?.onData('Hello', false, {
|
||||
messageId: 'message-1',
|
||||
taskId: 'task-1',
|
||||
})
|
||||
})
|
||||
|
||||
expect(screen.getByTestId('text-generation-res').textContent).toContain('Hello')
|
||||
|
||||
await waitFor(() => {
|
||||
expect(onRunControlChange).toHaveBeenLastCalledWith(expect.objectContaining({
|
||||
isStopping: false,
|
||||
}))
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'operation.stopResponding' }))
|
||||
await waitFor(() => {
|
||||
expect(stopChatMessageRespondingMock).toHaveBeenCalledWith('app-1', 'task-1', AppSourceType.webApp, 'app-1')
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
completionHandlers?.onCompleted()
|
||||
})
|
||||
|
||||
expect(onCompleted).toHaveBeenCalledWith('Hello', undefined, true)
|
||||
expect(textGenerationResPropsSpy).toHaveBeenLastCalledWith(expect.objectContaining({
|
||||
messageId: 'message-1',
|
||||
}))
|
||||
})
|
||||
|
||||
it('should render workflow results after workflow completion', async () => {
|
||||
let workflowHandlers: IOtherOptions | null = null
|
||||
sendWorkflowMessageMock.mockImplementation(async (_data, handlers) => {
|
||||
workflowHandlers = handlers
|
||||
})
|
||||
|
||||
const onCompleted = vi.fn()
|
||||
const { rerender } = render(
|
||||
<Result
|
||||
{...baseProps}
|
||||
isWorkflow
|
||||
onCompleted={onCompleted}
|
||||
/>,
|
||||
)
|
||||
|
||||
rerender(
|
||||
<Result
|
||||
{...baseProps}
|
||||
isWorkflow
|
||||
controlSend={1}
|
||||
onCompleted={onCompleted}
|
||||
/>,
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
workflowHandlers?.onWorkflowStarted?.({
|
||||
workflow_run_id: 'run-1',
|
||||
task_id: 'task-1',
|
||||
event: 'workflow_started',
|
||||
data: {
|
||||
id: 'run-1',
|
||||
workflow_id: 'wf-1',
|
||||
created_at: 0,
|
||||
},
|
||||
})
|
||||
workflowHandlers?.onTextChunk?.({
|
||||
task_id: 'task-1',
|
||||
workflow_run_id: 'run-1',
|
||||
event: 'text_chunk',
|
||||
data: {
|
||||
text: 'Hello',
|
||||
},
|
||||
})
|
||||
workflowHandlers?.onWorkflowFinished?.({
|
||||
task_id: 'task-1',
|
||||
workflow_run_id: 'run-1',
|
||||
event: 'workflow_finished',
|
||||
data: {
|
||||
id: 'run-1',
|
||||
workflow_id: 'wf-1',
|
||||
status: 'succeeded',
|
||||
outputs: {
|
||||
answer: 'Hello',
|
||||
},
|
||||
error: '',
|
||||
elapsed_time: 0,
|
||||
total_tokens: 0,
|
||||
total_steps: 0,
|
||||
created_at: 0,
|
||||
created_by: {
|
||||
id: 'user-1',
|
||||
name: 'User',
|
||||
email: 'user@example.com',
|
||||
},
|
||||
finished_at: 0,
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
expect(screen.getByTestId('text-generation-res').textContent).toContain('{"answer":"Hello"}')
|
||||
expect(textGenerationResPropsSpy).toHaveBeenLastCalledWith(expect.objectContaining({
|
||||
workflowProcessData: expect.objectContaining({
|
||||
resultText: 'Hello',
|
||||
status: 'succeeded',
|
||||
}),
|
||||
}))
|
||||
expect(onCompleted).toHaveBeenCalledWith('{"answer":"Hello"}', undefined, true)
|
||||
})
|
||||
|
||||
it('should render batch task ids for both short and long indexes', () => {
|
||||
const { rerender } = render(
|
||||
<Result
|
||||
{...baseProps}
|
||||
isCallBatchAPI
|
||||
taskId={3}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(textGenerationResPropsSpy).toHaveBeenLastCalledWith(expect.objectContaining({
|
||||
taskId: '03',
|
||||
}))
|
||||
|
||||
rerender(
|
||||
<Result
|
||||
{...baseProps}
|
||||
isCallBatchAPI
|
||||
taskId={12}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(textGenerationResPropsSpy).toHaveBeenLastCalledWith(expect.objectContaining({
|
||||
taskId: '12',
|
||||
}))
|
||||
})
|
||||
|
||||
it('should render the mobile stop button layout while a batch run is responding', async () => {
|
||||
let completionHandlers: {
|
||||
onData: (chunk: string, isFirstMessage: boolean, info: { messageId: string, taskId?: string }) => void
|
||||
} | null = null
|
||||
|
||||
sendCompletionMessageMock.mockImplementation(async (_data, handlers) => {
|
||||
completionHandlers = handlers
|
||||
})
|
||||
|
||||
const { rerender } = render(
|
||||
<Result
|
||||
{...baseProps}
|
||||
isCallBatchAPI
|
||||
isMobile
|
||||
isPC={false}
|
||||
taskId={2}
|
||||
/>,
|
||||
)
|
||||
|
||||
rerender(
|
||||
<Result
|
||||
{...baseProps}
|
||||
controlSend={1}
|
||||
isCallBatchAPI
|
||||
isMobile
|
||||
isPC={false}
|
||||
taskId={2}
|
||||
/>,
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
completionHandlers?.onData('Hello', false, {
|
||||
messageId: 'message-batch',
|
||||
taskId: 'task-batch',
|
||||
})
|
||||
})
|
||||
|
||||
expect(screen.getByRole('button', { name: 'operation.stopResponding' }).parentElement?.className).toContain('justify-center')
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,293 @@
|
||||
import type { FileEntity } from '@/app/components/base/file-uploader/types'
|
||||
import type { PromptConfig } from '@/models/debug'
|
||||
import type { VisionFile, VisionSettings } from '@/types/app'
|
||||
import { Resolution, TransferMethod } from '@/types/app'
|
||||
import { buildResultRequestData, validateResultRequest } from '../result-request'
|
||||
|
||||
const createTranslator = () => vi.fn((key: string) => key)
|
||||
|
||||
const createFileEntity = (overrides: Partial<FileEntity> = {}): FileEntity => ({
|
||||
id: 'file-1',
|
||||
name: 'example.txt',
|
||||
size: 128,
|
||||
type: 'text/plain',
|
||||
progress: 100,
|
||||
transferMethod: TransferMethod.local_file,
|
||||
supportFileType: 'document',
|
||||
uploadedId: 'uploaded-1',
|
||||
url: 'https://example.com/file.txt',
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const createVisionFile = (overrides: Partial<VisionFile> = {}): VisionFile => ({
|
||||
type: 'image',
|
||||
transfer_method: TransferMethod.local_file,
|
||||
upload_file_id: 'upload-1',
|
||||
url: 'https://example.com/image.png',
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const promptConfig: PromptConfig = {
|
||||
prompt_template: 'template',
|
||||
prompt_variables: [
|
||||
{ key: 'name', name: 'Name', type: 'string', required: true },
|
||||
{ key: 'enabled', name: 'Enabled', type: 'boolean', required: true },
|
||||
{ key: 'file', name: 'File', type: 'file', required: false },
|
||||
{ key: 'files', name: 'Files', type: 'file-list', required: false },
|
||||
],
|
||||
}
|
||||
|
||||
const visionConfig: VisionSettings = {
|
||||
enabled: true,
|
||||
number_limits: 2,
|
||||
detail: Resolution.low,
|
||||
transfer_methods: [TransferMethod.local_file],
|
||||
}
|
||||
|
||||
describe('result-request', () => {
|
||||
it('should reject missing required non-boolean inputs', () => {
|
||||
const t = createTranslator()
|
||||
|
||||
const result = validateResultRequest({
|
||||
completionFiles: [],
|
||||
inputs: {
|
||||
enabled: false,
|
||||
},
|
||||
isCallBatchAPI: false,
|
||||
promptConfig,
|
||||
t,
|
||||
})
|
||||
|
||||
expect(result).toEqual({
|
||||
canSend: false,
|
||||
notification: {
|
||||
type: 'error',
|
||||
message: 'errorMessage.valueOfVarRequired',
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
it('should allow required number inputs with a value of zero', () => {
|
||||
const result = validateResultRequest({
|
||||
completionFiles: [],
|
||||
inputs: {
|
||||
count: 0,
|
||||
},
|
||||
isCallBatchAPI: false,
|
||||
promptConfig: {
|
||||
prompt_template: 'template',
|
||||
prompt_variables: [
|
||||
{ key: 'count', name: 'Count', type: 'number', required: true },
|
||||
],
|
||||
},
|
||||
t: createTranslator(),
|
||||
})
|
||||
|
||||
expect(result).toEqual({ canSend: true })
|
||||
})
|
||||
|
||||
it('should reject required text inputs that only contain whitespace', () => {
|
||||
const result = validateResultRequest({
|
||||
completionFiles: [],
|
||||
inputs: {
|
||||
name: ' ',
|
||||
},
|
||||
isCallBatchAPI: false,
|
||||
promptConfig: {
|
||||
prompt_template: 'template',
|
||||
prompt_variables: [
|
||||
{ key: 'name', name: 'Name', type: 'string', required: true },
|
||||
],
|
||||
},
|
||||
t: createTranslator(),
|
||||
})
|
||||
|
||||
expect(result).toEqual({
|
||||
canSend: false,
|
||||
notification: {
|
||||
type: 'error',
|
||||
message: 'errorMessage.valueOfVarRequired',
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
it('should reject required file lists when no files are selected', () => {
|
||||
const result = validateResultRequest({
|
||||
completionFiles: [],
|
||||
inputs: {
|
||||
files: [],
|
||||
},
|
||||
isCallBatchAPI: false,
|
||||
promptConfig: {
|
||||
prompt_template: 'template',
|
||||
prompt_variables: [
|
||||
{ key: 'files', name: 'Files', type: 'file-list', required: true },
|
||||
],
|
||||
},
|
||||
t: createTranslator(),
|
||||
})
|
||||
|
||||
expect(result).toEqual({
|
||||
canSend: false,
|
||||
notification: {
|
||||
type: 'error',
|
||||
message: 'errorMessage.valueOfVarRequired',
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
it('should allow required file inputs when a file is selected', () => {
|
||||
const result = validateResultRequest({
|
||||
completionFiles: [],
|
||||
inputs: {
|
||||
file: createFileEntity(),
|
||||
},
|
||||
isCallBatchAPI: false,
|
||||
promptConfig: {
|
||||
prompt_template: 'template',
|
||||
prompt_variables: [
|
||||
{ key: 'file', name: 'File', type: 'file', required: true },
|
||||
],
|
||||
},
|
||||
t: createTranslator(),
|
||||
})
|
||||
|
||||
expect(result).toEqual({ canSend: true })
|
||||
})
|
||||
|
||||
it('should reject pending local uploads outside batch mode', () => {
|
||||
const t = createTranslator()
|
||||
|
||||
const result = validateResultRequest({
|
||||
completionFiles: [
|
||||
createVisionFile({ upload_file_id: '' }),
|
||||
],
|
||||
inputs: {
|
||||
name: 'Alice',
|
||||
},
|
||||
isCallBatchAPI: false,
|
||||
promptConfig,
|
||||
t,
|
||||
})
|
||||
|
||||
expect(result).toEqual({
|
||||
canSend: false,
|
||||
notification: {
|
||||
type: 'info',
|
||||
message: 'errorMessage.waitForFileUpload',
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle missing prompt metadata with and without pending uploads', () => {
|
||||
const t = createTranslator()
|
||||
|
||||
const blocked = validateResultRequest({
|
||||
completionFiles: [
|
||||
createVisionFile({ upload_file_id: '' }),
|
||||
],
|
||||
inputs: {},
|
||||
isCallBatchAPI: false,
|
||||
promptConfig: null,
|
||||
t,
|
||||
})
|
||||
|
||||
const allowed = validateResultRequest({
|
||||
completionFiles: [],
|
||||
inputs: {},
|
||||
isCallBatchAPI: false,
|
||||
promptConfig: null,
|
||||
t,
|
||||
})
|
||||
|
||||
expect(blocked).toEqual({
|
||||
canSend: false,
|
||||
notification: {
|
||||
type: 'info',
|
||||
message: 'errorMessage.waitForFileUpload',
|
||||
},
|
||||
})
|
||||
expect(allowed).toEqual({ canSend: true })
|
||||
})
|
||||
|
||||
it('should skip validation in batch mode', () => {
|
||||
const result = validateResultRequest({
|
||||
completionFiles: [
|
||||
createVisionFile({ upload_file_id: '' }),
|
||||
],
|
||||
inputs: {},
|
||||
isCallBatchAPI: true,
|
||||
promptConfig,
|
||||
t: createTranslator(),
|
||||
})
|
||||
|
||||
expect(result).toEqual({ canSend: true })
|
||||
})
|
||||
|
||||
it('should build request data for single and list file inputs', () => {
|
||||
const file = createFileEntity()
|
||||
const secondFile = createFileEntity({
|
||||
id: 'file-2',
|
||||
name: 'second.txt',
|
||||
uploadedId: 'uploaded-2',
|
||||
url: 'https://example.com/second.txt',
|
||||
})
|
||||
|
||||
const result = buildResultRequestData({
|
||||
completionFiles: [
|
||||
createVisionFile(),
|
||||
createVisionFile({
|
||||
transfer_method: TransferMethod.remote_url,
|
||||
upload_file_id: '',
|
||||
url: 'https://example.com/remote.png',
|
||||
}),
|
||||
],
|
||||
inputs: {
|
||||
enabled: true,
|
||||
file,
|
||||
files: [file, secondFile],
|
||||
name: 'Alice',
|
||||
},
|
||||
promptConfig,
|
||||
visionConfig,
|
||||
})
|
||||
|
||||
expect(result).toEqual({
|
||||
files: [
|
||||
expect.objectContaining({
|
||||
transfer_method: TransferMethod.local_file,
|
||||
upload_file_id: 'upload-1',
|
||||
url: '',
|
||||
}),
|
||||
expect.objectContaining({
|
||||
transfer_method: TransferMethod.remote_url,
|
||||
url: 'https://example.com/remote.png',
|
||||
}),
|
||||
],
|
||||
inputs: {
|
||||
enabled: true,
|
||||
file: {
|
||||
type: 'document',
|
||||
transfer_method: TransferMethod.local_file,
|
||||
upload_file_id: 'uploaded-1',
|
||||
url: 'https://example.com/file.txt',
|
||||
},
|
||||
files: [
|
||||
{
|
||||
type: 'document',
|
||||
transfer_method: TransferMethod.local_file,
|
||||
upload_file_id: 'uploaded-1',
|
||||
url: 'https://example.com/file.txt',
|
||||
},
|
||||
{
|
||||
type: 'document',
|
||||
transfer_method: TransferMethod.local_file,
|
||||
upload_file_id: 'uploaded-2',
|
||||
url: 'https://example.com/second.txt',
|
||||
},
|
||||
],
|
||||
name: 'Alice',
|
||||
},
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,901 @@
|
||||
import type { WorkflowProcess } from '@/app/components/base/chat/types'
|
||||
import type { IOtherOptions } from '@/service/base'
|
||||
import type { HumanInputFormData, HumanInputFormTimeoutData, NodeTracing } from '@/types/workflow'
|
||||
import { act } from '@testing-library/react'
|
||||
import { BlockEnum, NodeRunningStatus, WorkflowRunningStatus } from '@/app/components/workflow/types'
|
||||
import {
|
||||
appendParallelNext,
|
||||
appendParallelStart,
|
||||
appendResultText,
|
||||
applyWorkflowFinishedState,
|
||||
applyWorkflowOutputs,
|
||||
applyWorkflowPaused,
|
||||
createWorkflowStreamHandlers,
|
||||
finishParallelTrace,
|
||||
finishWorkflowNode,
|
||||
markNodesStopped,
|
||||
replaceResultText,
|
||||
updateHumanInputFilled,
|
||||
updateHumanInputRequired,
|
||||
updateHumanInputTimeout,
|
||||
upsertWorkflowNode,
|
||||
} from '../workflow-stream-handlers'
|
||||
|
||||
const sseGetMock = vi.fn()
|
||||
|
||||
type TraceOverrides = Omit<Partial<NodeTracing>, 'execution_metadata'> & {
|
||||
execution_metadata?: Partial<NonNullable<NodeTracing['execution_metadata']>>
|
||||
}
|
||||
|
||||
vi.mock('@/service/base', async () => {
|
||||
const actual = await vi.importActual<typeof import('@/service/base')>('@/service/base')
|
||||
return {
|
||||
...actual,
|
||||
sseGet: (...args: Parameters<typeof actual.sseGet>) => sseGetMock(...args),
|
||||
}
|
||||
})
|
||||
|
||||
const createTrace = (overrides: TraceOverrides = {}): NodeTracing => {
|
||||
const { execution_metadata, ...restOverrides } = overrides
|
||||
|
||||
return {
|
||||
id: 'trace-1',
|
||||
index: 0,
|
||||
predecessor_node_id: '',
|
||||
node_id: 'node-1',
|
||||
node_type: BlockEnum.LLM,
|
||||
title: 'Node',
|
||||
inputs: {},
|
||||
inputs_truncated: false,
|
||||
process_data: {},
|
||||
process_data_truncated: false,
|
||||
outputs: {},
|
||||
outputs_truncated: false,
|
||||
status: NodeRunningStatus.Running,
|
||||
elapsed_time: 0,
|
||||
metadata: {
|
||||
iterator_length: 0,
|
||||
iterator_index: 0,
|
||||
loop_length: 0,
|
||||
loop_index: 0,
|
||||
},
|
||||
created_at: 0,
|
||||
created_by: {
|
||||
id: 'user-1',
|
||||
name: 'User',
|
||||
email: 'user@example.com',
|
||||
},
|
||||
finished_at: 0,
|
||||
details: [[]],
|
||||
execution_metadata: {
|
||||
total_tokens: 0,
|
||||
total_price: 0,
|
||||
currency: 'USD',
|
||||
...execution_metadata,
|
||||
},
|
||||
...restOverrides,
|
||||
}
|
||||
}
|
||||
|
||||
const createWorkflowProcess = (): WorkflowProcess => ({
|
||||
status: WorkflowRunningStatus.Running,
|
||||
tracing: [],
|
||||
expand: false,
|
||||
resultText: '',
|
||||
})
|
||||
|
||||
const createHumanInput = (overrides: Partial<HumanInputFormData> = {}): HumanInputFormData => ({
|
||||
form_id: 'form-1',
|
||||
node_id: 'node-1',
|
||||
node_title: 'Node',
|
||||
form_content: 'content',
|
||||
inputs: [],
|
||||
actions: [],
|
||||
form_token: 'token-1',
|
||||
resolved_default_values: {},
|
||||
display_in_ui: true,
|
||||
expiration_time: 100,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
describe('workflow-stream-handlers helpers', () => {
|
||||
// End-to-end pass over the pure helpers: parallel trace bookkeeping, result
// text append/replace, the human-input lifecycle, then pausing the workflow.
// The order of calls matters — each step feeds the next step's input.
it('should update tracing, result text, and human input state', () => {
  const parallelTrace = createTrace({
    node_id: 'parallel-node',
    execution_metadata: { parallel_id: 'parallel-1' },
    details: [[]],
  })

  // Start from `undefined` to exercise the no-existing-process branch.
  let workflowProcessData = appendParallelStart(undefined, parallelTrace)
  workflowProcessData = appendParallelNext(workflowProcessData, parallelTrace)
  workflowProcessData = finishParallelTrace(workflowProcessData, createTrace({
    node_id: 'parallel-node',
    execution_metadata: { parallel_id: 'parallel-1' },
    error: 'failed',
  }))
  workflowProcessData = upsertWorkflowNode(workflowProcessData, createTrace({
    node_id: 'node-1',
    execution_metadata: { parallel_id: 'parallel-2' },
  }))!
  workflowProcessData = appendResultText(workflowProcessData, 'Hello ')
  workflowProcessData = replaceResultText(workflowProcessData, 'Hello world')
  workflowProcessData = updateHumanInputRequired(workflowProcessData, createHumanInput())
  workflowProcessData = updateHumanInputFilled(workflowProcessData, {
    action_id: 'action-1',
    action_text: 'Submit',
    node_id: 'node-1',
    node_title: 'Node',
    rendered_content: 'Done',
  })
  workflowProcessData = updateHumanInputTimeout(workflowProcessData, {
    node_id: 'node-1',
    node_title: 'Node',
    expiration_time: 200,
  } satisfies HumanInputFormTimeoutData)
  workflowProcessData = applyWorkflowPaused(workflowProcessData)

  expect(workflowProcessData.expand).toBe(false)
  // replaceResultText must overwrite the earlier appendResultText value.
  expect(workflowProcessData.resultText).toBe('Hello world')
  expect(workflowProcessData.humanInputFilledFormDataList).toEqual([
    expect.objectContaining({
      action_text: 'Submit',
    }),
  ])
  expect(workflowProcessData.tracing[0]).toEqual(expect.objectContaining({
    node_id: 'parallel-node',
    expand: true,
  }))
})
|
||||
|
||||
it('should initialize missing parallel details on start and next events', () => {
|
||||
const parallelTrace = createTrace({
|
||||
node_id: 'parallel-node',
|
||||
execution_metadata: { parallel_id: 'parallel-1' },
|
||||
})
|
||||
|
||||
const startedProcess = appendParallelStart(undefined, parallelTrace)
|
||||
const nextProcess = appendParallelNext(startedProcess, parallelTrace)
|
||||
|
||||
expect(startedProcess.tracing[0]?.details).toEqual([[]])
|
||||
expect(nextProcess.tracing[0]?.details).toEqual([[], []])
|
||||
})
|
||||
|
||||
it('should leave tracing unchanged when a parallel next event has no matching trace', () => {
|
||||
const process = createWorkflowProcess()
|
||||
process.tracing = [
|
||||
createTrace({
|
||||
node_id: 'parallel-node',
|
||||
execution_metadata: { parallel_id: 'parallel-1' },
|
||||
details: [[]],
|
||||
}),
|
||||
]
|
||||
|
||||
const nextProcess = appendParallelNext(process, createTrace({
|
||||
node_id: 'missing-node',
|
||||
execution_metadata: { parallel_id: 'parallel-2' },
|
||||
}))
|
||||
|
||||
expect(nextProcess.tracing).toEqual(process.tracing)
|
||||
expect(nextProcess.expand).toBe(true)
|
||||
})
|
||||
|
||||
it('should mark running nodes as stopped recursively', () => {
|
||||
const workflowProcessData = createWorkflowProcess()
|
||||
workflowProcessData.tracing = [
|
||||
createTrace({
|
||||
status: NodeRunningStatus.Running,
|
||||
details: [[createTrace({ status: NodeRunningStatus.Waiting })]],
|
||||
}),
|
||||
]
|
||||
|
||||
const stoppedWorkflow = applyWorkflowFinishedState(workflowProcessData, WorkflowRunningStatus.Stopped)
|
||||
markNodesStopped(stoppedWorkflow.tracing)
|
||||
|
||||
expect(stoppedWorkflow.status).toBe(WorkflowRunningStatus.Stopped)
|
||||
expect(stoppedWorkflow.tracing[0].status).toBe(NodeRunningStatus.Stopped)
|
||||
expect(stoppedWorkflow.tracing[0].details?.[0][0].status).toBe(NodeRunningStatus.Stopped)
|
||||
})
|
||||
|
||||
// Branch-coverage sweep: probes every helper with unmatched ids, iteration/
// loop-scoped traces that must be ignored, replacement of an existing node,
// `undefined` process inputs, and extras/error preservation on finish.
it('should cover unmatched and replacement helper branches', () => {
  const process = createWorkflowProcess()
  process.tracing = [
    createTrace({
      node_id: 'node-1',
      parallel_id: 'parallel-1',
      extras: {
        source: 'extra',
      },
      status: NodeRunningStatus.Succeeded,
    }),
  ]
  process.humanInputFormDataList = [
    createHumanInput({ node_id: 'node-1' }),
  ]
  process.humanInputFilledFormDataList = [
    {
      action_id: 'action-0',
      action_text: 'Existing',
      node_id: 'node-0',
      node_title: 'Node 0',
      rendered_content: 'Existing',
    },
  ]

  const parallelMatched = appendParallelNext(process, createTrace({
    node_id: 'node-1',
    execution_metadata: {
      parallel_id: 'parallel-1',
    },
  }))
  // Unknown node/parallel ids: tracing must pass through untouched.
  const notFinished = finishParallelTrace(process, createTrace({
    node_id: 'missing',
    execution_metadata: {
      parallel_id: 'parallel-missing',
    },
  }))
  // Iteration- and loop-scoped traces are ignored by upsert/finish.
  const ignoredIteration = upsertWorkflowNode(process, createTrace({
    iteration_id: 'iteration-1',
  }))
  const replacedNode = upsertWorkflowNode(process, createTrace({
    node_id: 'node-1',
  }))
  const ignoredFinish = finishWorkflowNode(process, createTrace({
    loop_id: 'loop-1',
  }))
  const unmatchedFinish = finishWorkflowNode(process, createTrace({
    node_id: 'missing',
    execution_metadata: {
      parallel_id: 'missing',
    },
  }))
  const finishedWithExtras = finishWorkflowNode(process, createTrace({
    node_id: 'node-1',
    execution_metadata: {
      parallel_id: 'parallel-1',
    },
    error: 'failed',
  }))
  const succeededWorkflow = applyWorkflowFinishedState(process, WorkflowRunningStatus.Succeeded)
  const outputlessWorkflow = applyWorkflowOutputs(undefined, null)
  const updatedHumanInput = updateHumanInputRequired(process, createHumanInput({
    node_id: 'node-1',
    expiration_time: 300,
  }))
  const appendedHumanInput = updateHumanInputRequired(process, createHumanInput({
    node_id: 'node-2',
  }))
  const noListFilled = updateHumanInputFilled(undefined, {
    action_id: 'action-1',
    action_text: 'Submit',
    node_id: 'node-1',
    node_title: 'Node',
    rendered_content: 'Done',
  })
  const appendedFilled = updateHumanInputFilled(process, {
    action_id: 'action-2',
    action_text: 'Append',
    node_id: 'node-2',
    node_title: 'Node 2',
    rendered_content: 'More',
  })
  const timeoutWithoutList = updateHumanInputTimeout(undefined, {
    node_id: 'node-1',
    node_title: 'Node',
    expiration_time: 200,
  })
  const timeoutWithMatch = updateHumanInputTimeout(process, {
    node_id: 'node-1',
    node_title: 'Node',
    expiration_time: 400,
  })

  // Must be a no-op rather than throwing on undefined input.
  markNodesStopped(undefined)

  expect(parallelMatched.tracing[0].details).toHaveLength(2)
  expect(notFinished).toEqual(expect.objectContaining({
    expand: true,
    tracing: process.tracing,
  }))
  expect(ignoredIteration).toEqual(process)
  expect(replacedNode?.tracing[0]).toEqual(expect.objectContaining({
    node_id: 'node-1',
    status: NodeRunningStatus.Running,
  }))
  expect(ignoredFinish).toEqual(process)
  expect(unmatchedFinish).toEqual(process)
  // Finishing keeps existing extras while recording the incoming error.
  expect(finishedWithExtras?.tracing[0]).toEqual(expect.objectContaining({
    extras: {
      source: 'extra',
    },
    error: 'failed',
  }))
  expect(succeededWorkflow.status).toBe(WorkflowRunningStatus.Succeeded)
  expect(outputlessWorkflow.files).toEqual([])
  expect(updatedHumanInput.humanInputFormDataList?.[0].expiration_time).toBe(300)
  expect(appendedHumanInput.humanInputFormDataList).toHaveLength(2)
  expect(noListFilled.humanInputFilledFormDataList).toHaveLength(1)
  expect(appendedFilled.humanInputFilledFormDataList).toHaveLength(2)
  expect(timeoutWithoutList).toEqual(expect.objectContaining({
    status: WorkflowRunningStatus.Running,
    tracing: [],
  }))
  expect(timeoutWithMatch.humanInputFormDataList?.[0].expiration_time).toBe(400)
})
|
||||
})
|
||||
|
||||
describe('createWorkflowStreamHandlers', () => {
|
||||
// Reset all mock call history between tests so per-test assertions are isolated.
beforeEach(() => {
  vi.clearAllMocks()
})
|
||||
|
||||
// Builds a createWorkflowStreamHandlers instance wired to closure-backed state,
// mimicking React setState semantics (values or updater functions), and returns
// both the handlers and read accessors/spies for asserting the resulting state.
const setupHandlers = (overrides: { isTimedOut?: () => boolean } = {}) => {
  // Closure state standing in for the component's useState values.
  let completionRes = ''
  let currentTaskId: string | null = null
  let isStopping = false
  let messageId: string | null = null
  let workflowProcessData: WorkflowProcess | undefined

  // Setters accept a value or an updater, like React's setState.
  const setCurrentTaskId = vi.fn((value: string | null | ((prev: string | null) => string | null)) => {
    currentTaskId = typeof value === 'function' ? value(currentTaskId) : value
  })
  const setIsStopping = vi.fn((value: boolean | ((prev: boolean) => boolean)) => {
    isStopping = typeof value === 'function' ? value(isStopping) : value
  })
  const setMessageId = vi.fn((value: string | null | ((prev: string | null) => string | null)) => {
    messageId = typeof value === 'function' ? value(messageId) : value
  })
  const setWorkflowProcessData = vi.fn((value: WorkflowProcess | undefined) => {
    workflowProcessData = value
  })
  const setCompletionRes = vi.fn((value: string) => {
    completionRes = value
  })
  const notify = vi.fn()
  const onCompleted = vi.fn()
  const resetRunState = vi.fn()
  const setRespondingFalse = vi.fn()
  const markEnded = vi.fn()

  const handlers = createWorkflowStreamHandlers({
    getCompletionRes: () => completionRes,
    getWorkflowProcessData: () => workflowProcessData,
    // Default: never timed out, unless a test overrides it.
    isTimedOut: overrides.isTimedOut ?? (() => false),
    markEnded,
    notify,
    onCompleted,
    resetRunState,
    setCompletionRes,
    setCurrentTaskId,
    setIsStopping,
    setMessageId,
    setRespondingFalse,
    setWorkflowProcessData,
    // Identity translator so assertions can match on raw i18n keys.
    t: (key: string) => key,
    taskId: 3,
  })

  // State is exposed through getters so tests read the latest value,
  // not a snapshot taken at setup time.
  return {
    currentTaskId: () => currentTaskId,
    handlers,
    isStopping: () => isStopping,
    messageId: () => messageId,
    notify,
    onCompleted,
    resetRunState,
    setCompletionRes,
    setCurrentTaskId,
    setMessageId,
    setRespondingFalse,
    workflowProcessData: () => workflowProcessData,
  }
}
|
||||
|
||||
// Happy path: fires every stream event once (start, node, iteration, loop,
// text, human-input, pause, finish) and checks the accumulated state plus the
// pause-triggered sseGet subscription to the run's event stream.
it('should process workflow success and paused events', () => {
  const setup = setupHandlers()
  const handlers = setup.handlers as Required<Pick<IOtherOptions, 'onWorkflowStarted' | 'onTextChunk' | 'onHumanInputRequired' | 'onHumanInputFormFilled' | 'onHumanInputFormTimeout' | 'onWorkflowPaused' | 'onWorkflowFinished' | 'onNodeStarted' | 'onNodeFinished' | 'onIterationStart' | 'onIterationNext' | 'onIterationFinish' | 'onLoopStart' | 'onLoopNext' | 'onLoopFinish'>>

  act(() => {
    handlers.onWorkflowStarted({
      workflow_run_id: 'run-1',
      task_id: 'task-1',
      event: 'workflow_started',
      data: { id: 'run-1', workflow_id: 'wf-1', created_at: 0 },
    })
    handlers.onNodeStarted({
      task_id: 'task-1',
      workflow_run_id: 'run-1',
      event: 'node_started',
      data: createTrace({ node_id: 'node-1' }),
    })
    handlers.onNodeFinished({
      task_id: 'task-1',
      workflow_run_id: 'run-1',
      event: 'node_finished',
      data: createTrace({ node_id: 'node-1', error: '' }),
    })
    handlers.onIterationStart({
      task_id: 'task-1',
      workflow_run_id: 'run-1',
      event: 'iteration_start',
      data: createTrace({
        node_id: 'iter-1',
        execution_metadata: { parallel_id: 'parallel-1' },
        details: [[]],
      }),
    })
    handlers.onIterationNext({
      task_id: 'task-1',
      workflow_run_id: 'run-1',
      event: 'iteration_next',
      data: createTrace({
        node_id: 'iter-1',
        execution_metadata: { parallel_id: 'parallel-1' },
        details: [[]],
      }),
    })
    handlers.onIterationFinish({
      task_id: 'task-1',
      workflow_run_id: 'run-1',
      event: 'iteration_finish',
      data: createTrace({
        node_id: 'iter-1',
        execution_metadata: { parallel_id: 'parallel-1' },
      }),
    })
    handlers.onLoopStart({
      task_id: 'task-1',
      workflow_run_id: 'run-1',
      event: 'loop_start',
      data: createTrace({
        node_id: 'loop-1',
        execution_metadata: { parallel_id: 'parallel-2' },
        details: [[]],
      }),
    })
    handlers.onLoopNext({
      task_id: 'task-1',
      workflow_run_id: 'run-1',
      event: 'loop_next',
      data: createTrace({
        node_id: 'loop-1',
        execution_metadata: { parallel_id: 'parallel-2' },
        details: [[]],
      }),
    })
    handlers.onLoopFinish({
      task_id: 'task-1',
      workflow_run_id: 'run-1',
      event: 'loop_finish',
      data: createTrace({
        node_id: 'loop-1',
        execution_metadata: { parallel_id: 'parallel-2' },
      }),
    })
    handlers.onTextChunk({
      task_id: 'task-1',
      workflow_run_id: 'run-1',
      event: 'text_chunk',
      data: { text: 'Hello' },
    })
    handlers.onHumanInputRequired({
      task_id: 'task-1',
      workflow_run_id: 'run-1',
      event: 'human_input_required',
      data: createHumanInput({ node_id: 'node-1' }),
    })
    handlers.onHumanInputFormFilled({
      task_id: 'task-1',
      workflow_run_id: 'run-1',
      event: 'human_input_form_filled',
      data: {
        node_id: 'node-1',
        node_title: 'Node',
        rendered_content: 'Done',
        action_id: 'action-1',
        action_text: 'Submit',
      },
    })
    handlers.onHumanInputFormTimeout({
      task_id: 'task-1',
      workflow_run_id: 'run-1',
      event: 'human_input_form_timeout',
      data: {
        node_id: 'node-1',
        node_title: 'Node',
        expiration_time: 200,
      },
    })
    handlers.onWorkflowPaused({
      task_id: 'task-1',
      workflow_run_id: 'run-1',
      event: 'workflow_paused',
      data: {
        outputs: {},
        paused_nodes: [],
        reasons: [],
        workflow_run_id: 'run-1',
      },
    })
    handlers.onWorkflowFinished({
      task_id: 'task-1',
      workflow_run_id: 'run-1',
      event: 'workflow_finished',
      data: {
        id: 'run-1',
        workflow_id: 'wf-1',
        status: WorkflowRunningStatus.Succeeded,
        outputs: { answer: 'Hello' },
        error: '',
        elapsed_time: 0,
        total_tokens: 0,
        total_steps: 0,
        created_at: 0,
        created_by: {
          id: 'user-1',
          name: 'User',
          email: 'user@example.com',
        },
        finished_at: 0,
      },
    })
  })

  expect(setup.currentTaskId()).toBe('task-1')
  expect(setup.isStopping()).toBe(false)
  expect(setup.workflowProcessData()).toEqual(expect.objectContaining({
    resultText: 'Hello',
    status: WorkflowRunningStatus.Succeeded,
  }))
  // The paused event subscribes to the run's SSE event stream.
  expect(sseGetMock).toHaveBeenCalledWith('/workflow/run-1/events', {}, expect.any(Object))
  expect(setup.messageId()).toBe('run-1')
  // Object outputs are JSON-serialized; taskId 3 comes from setupHandlers.
  expect(setup.onCompleted).toHaveBeenCalledWith('{"answer":"Hello"}', 3, true)
  expect(setup.setRespondingFalse).toHaveBeenCalled()
  expect(setup.resetRunState).toHaveBeenCalled()
})
|
||||
|
||||
// Error paths: a timed-out run notifies with the timeout warning key even on a
// "Succeeded" finish; a Failed run surfaces its error and completes with success=false.
it('should handle timeout and workflow failures', () => {
  const timeoutSetup = setupHandlers({
    isTimedOut: () => true,
  })
  const timeoutHandlers = timeoutSetup.handlers as Required<Pick<IOtherOptions, 'onWorkflowFinished'>>

  act(() => {
    timeoutHandlers.onWorkflowFinished({
      task_id: 'task-1',
      workflow_run_id: 'run-1',
      event: 'workflow_finished',
      data: {
        id: 'run-1',
        workflow_id: 'wf-1',
        status: WorkflowRunningStatus.Succeeded,
        outputs: null,
        error: '',
        elapsed_time: 0,
        total_tokens: 0,
        total_steps: 0,
        created_at: 0,
        created_by: {
          id: 'user-1',
          name: 'User',
          email: 'user@example.com',
        },
        finished_at: 0,
      },
    })
  })

  // `t` is the identity in setupHandlers, so the raw i18n key is asserted.
  expect(timeoutSetup.notify).toHaveBeenCalledWith({
    type: 'warning',
    message: 'warningMessage.timeoutExceeded',
  })

  const failureSetup = setupHandlers()
  const failureHandlers = failureSetup.handlers as Required<Pick<IOtherOptions, 'onWorkflowStarted' | 'onWorkflowFinished'>>

  act(() => {
    failureHandlers.onWorkflowStarted({
      workflow_run_id: 'run-2',
      task_id: 'task-2',
      event: 'workflow_started',
      data: { id: 'run-2', workflow_id: 'wf-2', created_at: 0 },
    })
    failureHandlers.onWorkflowFinished({
      task_id: 'task-2',
      workflow_run_id: 'run-2',
      event: 'workflow_finished',
      data: {
        id: 'run-2',
        workflow_id: 'wf-2',
        status: WorkflowRunningStatus.Failed,
        outputs: null,
        error: 'failed',
        elapsed_time: 0,
        total_tokens: 0,
        total_steps: 0,
        created_at: 0,
        created_by: {
          id: 'user-1',
          name: 'User',
          email: 'user@example.com',
        },
        finished_at: 0,
      },
    })
  })

  expect(failureSetup.notify).toHaveBeenCalledWith({
    type: 'error',
    message: 'failed',
  })
  expect(failureSetup.onCompleted).toHaveBeenCalledWith('', 3, false)
})
|
||||
|
||||
// Covers resuming over an existing (Paused) process, a Stopped finish marking
// nodes stopped, null outputs producing an empty completion, and object
// outputs being JSON-serialized. Also exercises empty/undefined task_id starts.
it('should cover existing workflow starts, stopped runs, and non-string outputs', () => {
  const setup = setupHandlers()
  let existingProcess: WorkflowProcess = {
    status: WorkflowRunningStatus.Paused,
    tracing: [
      createTrace({
        node_id: 'existing-node',
        status: NodeRunningStatus.Waiting,
      }),
    ],
    expand: false,
    resultText: '',
  }

  // Hand-wired handlers so setWorkflowProcessData writes back into the local
  // `existingProcess`, letting the test observe in-place evolution.
  const handlers = createWorkflowStreamHandlers({
    getCompletionRes: () => '',
    getWorkflowProcessData: () => existingProcess,
    isTimedOut: () => false,
    markEnded: vi.fn(),
    notify: setup.notify,
    onCompleted: setup.onCompleted,
    resetRunState: setup.resetRunState,
    setCompletionRes: setup.setCompletionRes,
    setCurrentTaskId: setup.setCurrentTaskId,
    setIsStopping: vi.fn(),
    setMessageId: setup.setMessageId,
    setRespondingFalse: setup.setRespondingFalse,
    setWorkflowProcessData: (value) => {
      existingProcess = value!
    },
    t: (key: string) => key,
    taskId: 5,
  }) as Required<Pick<IOtherOptions, 'onWorkflowStarted' | 'onWorkflowFinished' | 'onTextReplace'>>

  act(() => {
    handlers.onWorkflowStarted({
      workflow_run_id: 'run-existing',
      task_id: '',
      event: 'workflow_started',
      data: { id: 'run-existing', workflow_id: 'wf-1', created_at: 0 },
    })
    handlers.onTextReplace({
      task_id: 'task-existing',
      workflow_run_id: 'run-existing',
      event: 'text_replace',
      data: { text: 'Replaced text' },
    })
  })

  // Restarting an existing Paused process flips it back to Running.
  expect(existingProcess).toEqual(expect.objectContaining({
    expand: true,
    status: WorkflowRunningStatus.Running,
    resultText: 'Replaced text',
  }))

  act(() => {
    handlers.onWorkflowFinished({
      task_id: 'task-existing',
      workflow_run_id: 'run-existing',
      event: 'workflow_finished',
      data: {
        id: 'run-existing',
        workflow_id: 'wf-1',
        status: WorkflowRunningStatus.Stopped,
        outputs: null,
        error: '',
        elapsed_time: 0,
        total_tokens: 0,
        total_steps: 0,
        created_at: 0,
        created_by: {
          id: 'user-1',
          name: 'User',
          email: 'user@example.com',
        },
        finished_at: 0,
      },
    })
  })

  expect(existingProcess.status).toBe(WorkflowRunningStatus.Stopped)
  // A Stopped finish marks in-flight node traces as stopped too.
  expect(existingProcess.tracing[0].status).toBe(NodeRunningStatus.Stopped)
  expect(setup.onCompleted).toHaveBeenCalledWith('', 5, false)

  const noOutputSetup = setupHandlers()
  const noOutputHandlers = noOutputSetup.handlers as Required<Pick<IOtherOptions, 'onWorkflowStarted' | 'onWorkflowFinished' | 'onTextReplace'>>

  act(() => {
    noOutputHandlers.onWorkflowStarted({
      workflow_run_id: 'run-no-output',
      task_id: '',
      event: 'workflow_started',
      data: { id: 'run-no-output', workflow_id: 'wf-2', created_at: 0 },
    })
    noOutputHandlers.onTextReplace({
      task_id: 'task-no-output',
      workflow_run_id: 'run-no-output',
      event: 'text_replace',
      data: { text: 'Draft' },
    })
    noOutputHandlers.onWorkflowFinished({
      task_id: 'task-no-output',
      workflow_run_id: 'run-no-output',
      event: 'workflow_finished',
      data: {
        id: 'run-no-output',
        workflow_id: 'wf-2',
        status: WorkflowRunningStatus.Succeeded,
        outputs: null,
        error: '',
        elapsed_time: 0,
        total_tokens: 0,
        total_steps: 0,
        created_at: 0,
        created_by: {
          id: 'user-1',
          name: 'User',
          email: 'user@example.com',
        },
        finished_at: 0,
      },
    })
  })

  expect(noOutputSetup.setCompletionRes).toHaveBeenCalledWith('')

  const objectOutputSetup = setupHandlers()
  const objectOutputHandlers = objectOutputSetup.handlers as Required<Pick<IOtherOptions, 'onWorkflowStarted' | 'onWorkflowFinished'>>

  act(() => {
    objectOutputHandlers.onWorkflowStarted({
      workflow_run_id: 'run-object',
      // Deliberately undefined: the task id should stay unset (null).
      task_id: undefined as unknown as string,
      event: 'workflow_started',
      data: { id: 'run-object', workflow_id: 'wf-3', created_at: 0 },
    })
    objectOutputHandlers.onWorkflowFinished({
      task_id: 'task-object',
      workflow_run_id: 'run-object',
      event: 'workflow_finished',
      data: {
        id: 'run-object',
        workflow_id: 'wf-3',
        status: WorkflowRunningStatus.Succeeded,
        outputs: {
          answer: 'Hello',
          meta: {
            mode: 'object',
          },
        },
        error: '',
        elapsed_time: 0,
        total_tokens: 0,
        total_steps: 0,
        created_at: 0,
        created_by: {
          id: 'user-1',
          name: 'User',
          email: 'user@example.com',
        },
        finished_at: 0,
      },
    })
  })

  expect(objectOutputSetup.currentTaskId()).toBeNull()
  expect(objectOutputSetup.setCompletionRes).toHaveBeenCalledWith('{"answer":"Hello","meta":{"mode":"object"}}')
  expect(objectOutputSetup.workflowProcessData()).toEqual(expect.objectContaining({
    status: WorkflowRunningStatus.Succeeded,
    resultText: '',
  }))
})
|
||||
|
||||
// Output serialization on finish: null → '', plain string passes through
// unchanged, and a circular object (unserializable by JSON.stringify) falls
// back to its String() form.
it('should serialize empty, string, and circular workflow outputs', () => {
  const noOutputSetup = setupHandlers()
  const noOutputHandlers = noOutputSetup.handlers as Required<Pick<IOtherOptions, 'onWorkflowFinished'>>

  act(() => {
    noOutputHandlers.onWorkflowFinished({
      task_id: 'task-empty',
      workflow_run_id: 'run-empty',
      event: 'workflow_finished',
      data: {
        id: 'run-empty',
        workflow_id: 'wf-empty',
        status: WorkflowRunningStatus.Succeeded,
        outputs: null,
        error: '',
        elapsed_time: 0,
        total_tokens: 0,
        total_steps: 0,
        created_at: 0,
        created_by: {
          id: 'user-1',
          name: 'User',
          email: 'user@example.com',
        },
        finished_at: 0,
      },
    })
  })

  expect(noOutputSetup.setCompletionRes).toHaveBeenCalledWith('')

  const stringOutputSetup = setupHandlers()
  const stringOutputHandlers = stringOutputSetup.handlers as Required<Pick<IOtherOptions, 'onWorkflowFinished'>>

  act(() => {
    stringOutputHandlers.onWorkflowFinished({
      task_id: 'task-string',
      workflow_run_id: 'run-string',
      event: 'workflow_finished',
      data: {
        id: 'run-string',
        workflow_id: 'wf-string',
        status: WorkflowRunningStatus.Succeeded,
        outputs: 'plain text output',
        error: '',
        elapsed_time: 0,
        total_tokens: 0,
        total_steps: 0,
        created_at: 0,
        created_by: {
          id: 'user-1',
          name: 'User',
          email: 'user@example.com',
        },
        finished_at: 0,
      },
    })
  })

  expect(stringOutputSetup.setCompletionRes).toHaveBeenCalledWith('plain text output')

  const circularOutputSetup = setupHandlers()
  const circularOutputHandlers = circularOutputSetup.handlers as Required<Pick<IOtherOptions, 'onWorkflowFinished'>>
  const circularOutputs: Record<string, unknown> = {
    answer: 'Hello',
  }
  // Self-reference makes the object impossible to JSON.stringify.
  circularOutputs.self = circularOutputs

  act(() => {
    circularOutputHandlers.onWorkflowFinished({
      task_id: 'task-circular',
      workflow_run_id: 'run-circular',
      event: 'workflow_finished',
      data: {
        id: 'run-circular',
        workflow_id: 'wf-circular',
        status: WorkflowRunningStatus.Succeeded,
        outputs: circularOutputs,
        error: '',
        elapsed_time: 0,
        total_tokens: 0,
        total_steps: 0,
        created_at: 0,
        created_by: {
          id: 'user-1',
          name: 'User',
          email: 'user@example.com',
        },
        finished_at: 0,
      },
    })
  })

  expect(circularOutputSetup.setCompletionRes).toHaveBeenCalledWith('[object Object]')
})
|
||||
})
|
||||
@@ -0,0 +1,200 @@
|
||||
import type { FeedbackType } from '@/app/components/base/chat/chat/type'
|
||||
import { act, renderHook, waitFor } from '@testing-library/react'
|
||||
import { AppSourceType } from '@/service/share'
|
||||
import { useResultRunState } from '../use-result-run-state'
|
||||
|
||||
// vi.hoisted creates these spies before the hoisted vi.mock factory below runs,
// so the factory can safely close over them.
const {
  stopChatMessageRespondingMock,
  stopWorkflowMessageMock,
  updateFeedbackMock,
} = vi.hoisted(() => ({
  stopChatMessageRespondingMock: vi.fn(),
  stopWorkflowMessageMock: vi.fn(),
  updateFeedbackMock: vi.fn(),
}))
|
||||
|
||||
// Mock only the three network calls used by useResultRunState; keep the rest
// of '@/service/share' (enums like AppSourceType) real.
vi.mock('@/service/share', async () => {
  const actual = await vi.importActual<typeof import('@/service/share')>('@/service/share')
  return {
    ...actual,
    stopChatMessageResponding: (...args: Parameters<typeof actual.stopChatMessageResponding>) => stopChatMessageRespondingMock(...args),
    stopWorkflowMessage: (...args: Parameters<typeof actual.stopWorkflowMessage>) => stopWorkflowMessageMock(...args),
    updateFeedback: (...args: Parameters<typeof actual.updateFeedback>) => updateFeedbackMock(...args),
  }
})
|
||||
|
||||
describe('useResultRunState', () => {
|
||||
// Clear call history and restore the default resolved (successful) behavior
// of each service mock before every test.
beforeEach(() => {
  vi.clearAllMocks()
  stopChatMessageRespondingMock.mockResolvedValue(undefined)
  stopWorkflowMessageMock.mockResolvedValue(undefined)
  updateFeedbackMock.mockResolvedValue(undefined)
})
|
||||
|
||||
// Non-workflow mode: stopping a run goes through stopChatMessageResponding and
// aborts the in-flight request; run-control changes are reported upward.
it('should expose run control and stop completion requests', async () => {
  const notify = vi.fn()
  const onRunControlChange = vi.fn()
  const { result } = renderHook(() => useResultRunState({
    appId: 'app-1',
    appSourceType: AppSourceType.webApp,
    controlStopResponding: 0,
    isWorkflow: false,
    notify,
    onRunControlChange,
  }))

  const abort = vi.fn()

  act(() => {
    // Inject a fake AbortController so the abort call can be observed.
    result.current.abortControllerRef.current = { abort } as unknown as AbortController
    result.current.setCurrentTaskId('task-1')
    result.current.setRespondingTrue()
  })

  await waitFor(() => {
    expect(onRunControlChange).toHaveBeenLastCalledWith(expect.objectContaining({
      isStopping: false,
    }))
  })

  await act(async () => {
    await result.current.handleStop()
  })

  expect(stopChatMessageRespondingMock).toHaveBeenCalledWith('app-1', 'task-1', AppSourceType.webApp, 'app-1')
  expect(abort).toHaveBeenCalledTimes(1)
})
|
||||
|
||||
// Feedback goes through updateFeedback with the message-scoped URL, and an
// incremented controlStopResponding prop triggers an external abort that
// clears the task id and resets run control.
it('should update feedback and react to external stop control', async () => {
  const notify = vi.fn()
  const onRunControlChange = vi.fn()
  const { result, rerender } = renderHook(({ controlStopResponding }) => useResultRunState({
    appId: 'app-2',
    appSourceType: AppSourceType.installedApp,
    controlStopResponding,
    isWorkflow: true,
    notify,
    onRunControlChange,
  }), {
    initialProps: { controlStopResponding: 0 },
  })

  const abort = vi.fn()
  act(() => {
    result.current.abortControllerRef.current = { abort } as unknown as AbortController
    result.current.setMessageId('message-1')
  })

  await act(async () => {
    await result.current.handleFeedback({
      rating: 'like',
    } satisfies FeedbackType)
  })

  expect(updateFeedbackMock).toHaveBeenCalledWith({
    url: '/messages/message-1/feedbacks',
    body: {
      rating: 'like',
      content: undefined,
    },
  }, AppSourceType.installedApp, 'app-2')
  expect(result.current.feedback).toEqual({
    rating: 'like',
  })

  act(() => {
    result.current.setCurrentTaskId('task-2')
    result.current.setRespondingTrue()
  })

  // Bumping the external counter must force a stop.
  rerender({ controlStopResponding: 1 })

  await waitFor(() => {
    expect(abort).toHaveBeenCalled()
    expect(result.current.currentTaskId).toBeNull()
    expect(onRunControlChange).toHaveBeenLastCalledWith(null)
  })
})
|
||||
|
||||
it('should stop workflow requests through the workflow stop API', async () => {
|
||||
const notify = vi.fn()
|
||||
const { result } = renderHook(() => useResultRunState({
|
||||
appId: 'app-3',
|
||||
appSourceType: AppSourceType.installedApp,
|
||||
controlStopResponding: 0,
|
||||
isWorkflow: true,
|
||||
notify,
|
||||
}))
|
||||
|
||||
act(() => {
|
||||
result.current.setCurrentTaskId('task-3')
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleStop()
|
||||
})
|
||||
|
||||
expect(stopWorkflowMessageMock).toHaveBeenCalledWith('app-3', 'task-3', AppSourceType.installedApp, 'app-3')
|
||||
})
|
||||
|
||||
// handleStop is a no-op without a current task id; with one, a rejected stop
// whose reason is a plain string is surfaced verbatim via notify, and the
// stopping flag ends up cleared.
it('should ignore invalid stops and report non-Error failures', async () => {
  const notify = vi.fn()
  // Reject with a raw string (not an Error) to hit the non-Error branch.
  stopChatMessageRespondingMock.mockRejectedValueOnce('stop failed')

  const { result } = renderHook(() => useResultRunState({
    appSourceType: AppSourceType.webApp,
    controlStopResponding: 0,
    isWorkflow: false,
    notify,
  }))

  // No task id yet: stop must not reach the service.
  await act(async () => {
    await result.current.handleStop()
  })

  expect(stopChatMessageRespondingMock).not.toHaveBeenCalled()

  act(() => {
    result.current.setCurrentTaskId('task-4')
    // Toggle twice to exercise the updater-function form of setIsStopping.
    result.current.setIsStopping(prev => !prev)
    result.current.setIsStopping(prev => !prev)
  })

  await act(async () => {
    await result.current.handleStop()
  })

  // No appId prop: the service receives undefined plus an empty fallback id.
  expect(stopChatMessageRespondingMock).toHaveBeenCalledWith(undefined, 'task-4', AppSourceType.webApp, '')
  expect(notify).toHaveBeenCalledWith({
    type: 'error',
    message: 'stop failed',
  })
  expect(result.current.isStopping).toBe(false)
})
|
||||
|
||||
  // Error instances from the workflow stop call should be reported with their
  // own message, and a missing appId should fall back to undefined / ''.
  it('should report Error instances from workflow stop failures without an app id fallback', async () => {
    const notify = vi.fn()
    stopWorkflowMessageMock.mockRejectedValueOnce(new Error('workflow stop failed'))

    const { result } = renderHook(() => useResultRunState({
      appSourceType: AppSourceType.installedApp,
      controlStopResponding: 0,
      isWorkflow: true,
      notify,
    }))

    act(() => {
      result.current.setCurrentTaskId('task-5')
    })

    await act(async () => {
      await result.current.handleStop()
    })

    expect(stopWorkflowMessageMock).toHaveBeenCalledWith(undefined, 'task-5', AppSourceType.installedApp, '')
    expect(notify).toHaveBeenCalledWith({
      type: 'error',
      message: 'workflow stop failed',
    })
  })
|
||||
})
|
||||
@ -0,0 +1,510 @@
|
||||
import type { ResultInputValue } from '../../result-request'
|
||||
import type { ResultRunStateController } from '../use-result-run-state'
|
||||
import type { PromptConfig } from '@/models/debug'
|
||||
import type { AppSourceType } from '@/service/share'
|
||||
import type { VisionSettings } from '@/types/app'
|
||||
import { act, renderHook, waitFor } from '@testing-library/react'
|
||||
import { AppSourceType as AppSourceTypeEnum } from '@/service/share'
|
||||
import { Resolution, TransferMethod } from '@/types/app'
|
||||
import { useResultSender } from '../use-result-sender'
|
||||
|
||||
// Hoisted mock functions: vi.hoisted lifts them above the vi.mock factories
// below so those factories can close over them safely.
const {
  buildResultRequestDataMock,
  createWorkflowStreamHandlersMock,
  sendCompletionMessageMock,
  sendWorkflowMessageMock,
  sleepMock,
  validateResultRequestMock,
} = vi.hoisted(() => ({
  buildResultRequestDataMock: vi.fn(),
  createWorkflowStreamHandlersMock: vi.fn(),
  sendCompletionMessageMock: vi.fn(),
  sendWorkflowMessageMock: vi.fn(),
  sleepMock: vi.fn(),
  validateResultRequestMock: vi.fn(),
}))

// Partial mock of the share service: keep the real exports (e.g. AppSourceType)
// but route the network-facing senders through the hoisted mocks.
vi.mock('@/service/share', async () => {
  const actual = await vi.importActual<typeof import('@/service/share')>('@/service/share')
  return {
    ...actual,
    sendCompletionMessage: (...args: Parameters<typeof actual.sendCompletionMessage>) => sendCompletionMessageMock(...args),
    sendWorkflowMessage: (...args: Parameters<typeof actual.sendWorkflowMessage>) => sendWorkflowMessageMock(...args),
  }
})

// Replace sleep so tests control the generation-timeout watchdog directly.
vi.mock('@/utils', async () => {
  const actual = await vi.importActual<typeof import('@/utils')>('@/utils')
  return {
    ...actual,
    sleep: (...args: Parameters<typeof actual.sleep>) => sleepMock(...args),
  }
})

// Stub request validation/building so each test can steer handleSend's path.
vi.mock('../../result-request', () => ({
  buildResultRequestData: (...args: unknown[]) => buildResultRequestDataMock(...args),
  validateResultRequest: (...args: unknown[]) => validateResultRequestMock(...args),
}))

// Stub the workflow stream handler factory to observe its wiring.
vi.mock('../../workflow-stream-handlers', () => ({
  createWorkflowStreamHandlers: (...args: unknown[]) => createWorkflowStreamHandlersMock(...args),
}))
|
||||
|
||||
// Harness shape: `runState` is the fake controller handed to useResultSender,
// and `state` is a plain mirror the fake setters write into so tests can read
// back what the hook set without a real useResultRunState instance.
type RunStateHarness = {
  state: {
    completionRes: string
    currentTaskId: string | null
    messageId: string | null
    workflowProcessData: ResultRunStateController['workflowProcessData']
  }
  runState: ResultRunStateController
}

// Streaming callbacks that sendCompletionMessage receives; captured by tests so
// stream events (data, completion, replace, error) can be replayed manually.
type CompletionHandlers = {
  getAbortController: (abortController: AbortController) => void
  onCompleted: () => void
  onData: (chunk: string, isFirstMessage: boolean, info: { messageId: string, taskId?: string }) => void
  onError: () => void
  onMessageReplace: (messageReplace: { answer: string }) => void
}
|
||||
|
||||
// Build a hand-rolled ResultRunStateController whose spy setters mirror every
// write into both the returned `state` snapshot and the controller's own
// fields, mimicking the real hook's state/ref duality.
const createRunStateHarness = (): RunStateHarness => {
  const state: RunStateHarness['state'] = {
    completionRes: '',
    currentTaskId: null,
    messageId: null,
    workflowProcessData: undefined,
  }

  const runState: ResultRunStateController = {
    abortControllerRef: { current: null },
    clearMoreLikeThis: vi.fn(),
    completionRes: '',
    controlClearMoreLikeThis: 0,
    currentTaskId: null,
    feedback: { rating: null },
    getCompletionRes: vi.fn(() => state.completionRes),
    getWorkflowProcessData: vi.fn(() => state.workflowProcessData),
    handleFeedback: vi.fn(),
    handleStop: vi.fn(),
    isResponding: false,
    isStopping: false,
    messageId: null,
    // Clears every per-run field, matching the real hook's prepareForNewRun.
    prepareForNewRun: vi.fn(() => {
      state.completionRes = ''
      state.currentTaskId = null
      state.messageId = null
      state.workflowProcessData = undefined
      runState.completionRes = ''
      runState.currentTaskId = null
      runState.messageId = null
      runState.workflowProcessData = undefined
    }),
    resetRunState: vi.fn(() => {
      state.currentTaskId = null
      runState.currentTaskId = null
      runState.isStopping = false
    }),
    setCompletionRes: vi.fn((value: string) => {
      state.completionRes = value
      runState.completionRes = value
    }),
    // Accepts either a direct value or a React-style updater function.
    setCurrentTaskId: vi.fn((value) => {
      state.currentTaskId = typeof value === 'function' ? value(state.currentTaskId) : value
      runState.currentTaskId = state.currentTaskId
    }),
    setIsStopping: vi.fn((value) => {
      runState.isStopping = typeof value === 'function' ? value(runState.isStopping) : value
    }),
    setMessageId: vi.fn((value) => {
      state.messageId = typeof value === 'function' ? value(state.messageId) : value
      runState.messageId = state.messageId
    }),
    setRespondingFalse: vi.fn(() => {
      runState.isResponding = false
    }),
    setRespondingTrue: vi.fn(() => {
      runState.isResponding = true
    }),
    setWorkflowProcessData: vi.fn((value) => {
      state.workflowProcessData = value
      runState.workflowProcessData = value
    }),
    workflowProcessData: undefined,
  }

  return {
    state,
    runState,
  }
}
|
||||
|
||||
// Minimal prompt config with one required variable so validation has something
// concrete to check against the default inputs.
const promptConfig: PromptConfig = {
  prompt_template: 'template',
  prompt_variables: [
    { key: 'name', name: 'Name', type: 'string', required: true },
  ],
}

// Vision disabled: these tests do not exercise file uploads.
const visionConfig: VisionSettings = {
  enabled: false,
  number_limits: 2,
  detail: Resolution.low,
  transfer_methods: [TransferMethod.local_file],
}
|
||||
|
||||
// Per-test overrides for the hook props; anything omitted falls back to a
// sensible default below.
type RenderSenderOptions = {
  appSourceType?: AppSourceType
  controlRetry?: number
  controlSend?: number
  inputs?: Record<string, ResultInputValue>
  isPC?: boolean
  isWorkflow?: boolean
  runState?: ResultRunStateController
  taskId?: number
}

// Mount useResultSender with spy callbacks. controlRetry/controlSend are
// rerender props so tests can bump them to trigger the hook's send/retry
// effects; the spies are returned alongside the renderHook result.
const renderSender = ({
  appSourceType = AppSourceTypeEnum.webApp,
  controlRetry = 0,
  controlSend = 0,
  inputs = { name: 'Alice' },
  isPC = true,
  isWorkflow = false,
  runState,
  taskId,
}: RenderSenderOptions = {}) => {
  const notify = vi.fn()
  const onCompleted = vi.fn()
  const onRunStart = vi.fn()
  const onShowRes = vi.fn()

  const hook = renderHook((props: { controlRetry: number, controlSend: number }) => useResultSender({
    appId: 'app-1',
    appSourceType,
    completionFiles: [],
    controlRetry: props.controlRetry,
    controlSend: props.controlSend,
    inputs,
    isCallBatchAPI: false,
    isPC,
    isWorkflow,
    notify,
    onCompleted,
    onRunStart,
    onShowRes,
    promptConfig,
    // Tests that need to inspect run state pass their own harness controller.
    runState: runState || createRunStateHarness().runState,
    // Identity translator: assertions match on raw i18n keys.
    t: (key: string) => key,
    taskId,
    visionConfig,
  }), {
    initialProps: {
      controlRetry,
      controlSend,
    },
  })

  return {
    ...hook,
    notify,
    onCompleted,
    onRunStart,
    onShowRes,
  }
}
|
||||
|
||||
describe('useResultSender', () => {
|
||||
  beforeEach(() => {
    vi.clearAllMocks()
    // Happy-path defaults; individual tests override what they need.
    validateResultRequestMock.mockReturnValue({ canSend: true })
    buildResultRequestDataMock.mockReturnValue({ inputs: { name: 'Alice' } })
    createWorkflowStreamHandlersMock.mockReturnValue({ onWorkflowFinished: vi.fn() })
    sendCompletionMessageMock.mockResolvedValue(undefined)
    sendWorkflowMessageMock.mockResolvedValue(undefined)
    // A never-resolving sleep keeps the timeout watchdog inert by default.
    sleepMock.mockImplementation(() => new Promise<void>(() => {}))
  })

  // handleSend must short-circuit with an info toast while a run is active.
  it('should reject sends while a response is already in progress', async () => {
    const { runState } = createRunStateHarness()
    runState.isResponding = true
    const { result, notify } = renderSender({ runState })

    await act(async () => {
      expect(await result.current.handleSend()).toBe(false)
    })

    expect(notify).toHaveBeenCalledWith({
      type: 'info',
      message: 'errorMessage.waitForResponse',
    })
    // The guard fires before validation or any request building.
    expect(validateResultRequestMock).not.toHaveBeenCalled()
    expect(sendCompletionMessageMock).not.toHaveBeenCalled()
  })

  // A failed validation should surface its notification and stop the pipeline
  // before the request payload is ever built.
  it('should surface validation failures without building request payloads', async () => {
    const { runState } = createRunStateHarness()
    validateResultRequestMock.mockReturnValue({
      canSend: false,
      notification: {
        type: 'error',
        message: 'invalid',
      },
    })

    const { result, notify } = renderSender({ runState })

    await act(async () => {
      expect(await result.current.handleSend()).toBe(false)
    })

    expect(notify).toHaveBeenCalledWith({
      type: 'error',
      message: 'invalid',
    })
    expect(buildResultRequestDataMock).not.toHaveBeenCalled()
    expect(sendCompletionMessageMock).not.toHaveBeenCalled()
  })
|
||||
|
||||
  // End-to-end completion flow: a controlSend bump triggers a send, and the
  // captured stream handlers drive data/replace/complete transitions.
  it('should send completion requests when controlSend changes and process callbacks', async () => {
    const harness = createRunStateHarness()
    let completionHandlers: CompletionHandlers | undefined

    // Capture the handler bundle instead of streaming, so the test can invoke
    // each callback manually.
    sendCompletionMessageMock.mockImplementation(async (_data, handlers) => {
      completionHandlers = handlers as CompletionHandlers
    })

    const { rerender, onCompleted, onRunStart, onShowRes } = renderSender({
      controlSend: 0,
      isPC: false,
      runState: harness.runState,
      taskId: 7,
    })

    // Bumping controlSend fires the send effect.
    rerender({
      controlRetry: 0,
      controlSend: 1,
    })

    expect(validateResultRequestMock).toHaveBeenCalledWith(expect.objectContaining({
      inputs: { name: 'Alice' },
      isCallBatchAPI: false,
    }))
    expect(buildResultRequestDataMock).toHaveBeenCalled()
    expect(harness.runState.prepareForNewRun).toHaveBeenCalledTimes(1)
    expect(harness.runState.setRespondingTrue).toHaveBeenCalledTimes(1)
    expect(harness.runState.clearMoreLikeThis).toHaveBeenCalledTimes(1)
    // On mobile (isPC: false) the result panel is revealed on send.
    expect(onShowRes).toHaveBeenCalledTimes(1)
    expect(onRunStart).toHaveBeenCalledTimes(1)
    expect(sendCompletionMessageMock).toHaveBeenCalledWith(
      { inputs: { name: 'Alice' } },
      expect.objectContaining({
        onCompleted: expect.any(Function),
        onData: expect.any(Function),
      }),
      AppSourceTypeEnum.webApp,
      'app-1',
    )

    // The abort controller supplied by the stream must be stored on the ref.
    const abortController = {} as AbortController
    expect(completionHandlers).toBeDefined()
    completionHandlers!.getAbortController(abortController)
    expect(harness.runState.abortControllerRef.current).toBe(abortController)

    await act(async () => {
      completionHandlers!.onData('Hello', false, {
        messageId: 'message-1',
        taskId: 'task-1',
      })
    })

    expect(harness.runState.setCurrentTaskId).toHaveBeenCalled()
    expect(harness.runState.currentTaskId).toBe('task-1')

    await act(async () => {
      // onMessageReplace swaps the accumulated answer wholesale.
      completionHandlers!.onMessageReplace({ answer: 'Replaced' })
      completionHandlers!.onCompleted()
    })

    expect(harness.runState.setCompletionRes).toHaveBeenLastCalledWith('Replaced')
    expect(harness.runState.setRespondingFalse).toHaveBeenCalled()
    expect(harness.runState.resetRunState).toHaveBeenCalled()
    expect(harness.runState.setMessageId).toHaveBeenCalledWith('message-1')
    expect(onCompleted).toHaveBeenCalledWith('Replaced', 7, true)
  })
|
||||
|
||||
  // A controlRetry bump must re-trigger a workflow send; a rejected send must
  // reset responding/run state and surface the Error message.
  it('should trigger workflow sends on retry and report workflow request failures', async () => {
    const harness = createRunStateHarness()
    sendWorkflowMessageMock.mockRejectedValue(new Error('workflow failed'))

    const { rerender, notify } = renderSender({
      controlRetry: 0,
      isWorkflow: true,
      runState: harness.runState,
    })

    rerender({
      controlRetry: 2,
      controlSend: 0,
    })

    await waitFor(() => {
      // The stream handler factory is wired with the harness's accessors.
      expect(createWorkflowStreamHandlersMock).toHaveBeenCalledWith(expect.objectContaining({
        getCompletionRes: harness.runState.getCompletionRes,
        resetRunState: harness.runState.resetRunState,
        setWorkflowProcessData: harness.runState.setWorkflowProcessData,
      }))
      expect(sendWorkflowMessageMock).toHaveBeenCalledWith(
        { inputs: { name: 'Alice' } },
        expect.any(Object),
        AppSourceTypeEnum.webApp,
        'app-1',
      )
    })

    await waitFor(() => {
      expect(harness.runState.setRespondingFalse).toHaveBeenCalled()
      expect(harness.runState.resetRunState).toHaveBeenCalled()
      expect(notify).toHaveBeenCalledWith({
        type: 'error',
        message: 'workflow failed',
      })
    })
    // clearMoreLikeThis belongs to the completion path, not workflows.
    expect(harness.runState.clearMoreLikeThis).not.toHaveBeenCalled()
  })

  // Non-Error rejections must be coerced to a string for the error toast.
  it('should stringify non-Error workflow failures', async () => {
    const harness = createRunStateHarness()
    sendWorkflowMessageMock.mockRejectedValue('workflow failed')

    const { result, notify } = renderSender({
      isWorkflow: true,
      runState: harness.runState,
    })

    await act(async () => {
      await result.current.handleSend()
    })

    await waitFor(() => {
      expect(notify).toHaveBeenCalledWith({
        type: 'error',
        message: 'workflow failed',
      })
    })
  })
|
||||
|
||||
  // When the watchdog sleep resolves before the stream completes, the run is
  // marked failed ('' result, success=false) and state is reset.
  it('should timeout unfinished completion requests', async () => {
    const harness = createRunStateHarness()
    sleepMock.mockResolvedValue(undefined)

    const { result, onCompleted } = renderSender({
      runState: harness.runState,
      taskId: 9,
    })

    await act(async () => {
      expect(await result.current.handleSend()).toBe(true)
    })

    await waitFor(() => {
      expect(harness.runState.setRespondingFalse).toHaveBeenCalled()
      expect(harness.runState.resetRunState).toHaveBeenCalled()
      expect(onCompleted).toHaveBeenCalledWith('', 9, false)
    })
  })

  // A whitespace-only task id must not be stored, and stream callbacks that
  // arrive after the timeout should each raise a timeout warning.
  it('should ignore empty task ids and surface timeout warnings from stream callbacks', async () => {
    const harness = createRunStateHarness()
    let completionHandlers: CompletionHandlers | undefined

    sleepMock.mockResolvedValue(undefined)
    sendCompletionMessageMock.mockImplementation(async (_data, handlers) => {
      completionHandlers = handlers as CompletionHandlers
    })

    const { result, notify, onCompleted } = renderSender({
      runState: harness.runState,
      taskId: 11,
    })

    await act(async () => {
      await result.current.handleSend()
    })

    await act(async () => {
      completionHandlers!.onData('Hello', false, {
        messageId: 'message-2',
        taskId: ' ',
      })
      // Both post-timeout callbacks warn instead of completing the run again.
      completionHandlers!.onCompleted()
      completionHandlers!.onError()
    })

    expect(harness.runState.currentTaskId).toBeNull()
    expect(notify).toHaveBeenNthCalledWith(1, {
      type: 'warning',
      message: 'warningMessage.timeoutExceeded',
    })
    expect(notify).toHaveBeenNthCalledWith(2, {
      type: 'warning',
      message: 'warningMessage.timeoutExceeded',
    })
    expect(onCompleted).toHaveBeenCalledWith('', 11, false)
  })
|
||||
|
||||
  // If the stream completes before the watchdog fires, the later watchdog
  // resolution must not produce a second (failed) onCompleted call.
  it('should avoid timeout fallback after a completion response has already ended', async () => {
    const harness = createRunStateHarness()
    let resolveSleep!: () => void
    let completionHandlers: CompletionHandlers | undefined

    // Hold the watchdog open until the test releases it explicitly.
    sleepMock.mockImplementation(() => new Promise<void>((resolve) => {
      resolveSleep = resolve
    }))
    sendCompletionMessageMock.mockImplementation(async (_data, handlers) => {
      completionHandlers = handlers as CompletionHandlers
    })

    const { result, onCompleted } = renderSender({
      runState: harness.runState,
      taskId: 12,
    })

    await act(async () => {
      await result.current.handleSend()
    })

    await act(async () => {
      harness.runState.setCompletionRes('Done')
      completionHandlers!.onCompleted()
      // Let the watchdog resolve after completion; it must see isEnd and bail.
      resolveSleep()
      await Promise.resolve()
    })

    expect(onCompleted).toHaveBeenCalledWith('Done', 12, true)
    expect(onCompleted).toHaveBeenCalledTimes(1)
  })

  // A stream error before any timeout ends the run as a failure.
  it('should handle non-timeout stream errors as failed completions', async () => {
    const harness = createRunStateHarness()
    let completionHandlers: CompletionHandlers | undefined

    sendCompletionMessageMock.mockImplementation(async (_data, handlers) => {
      completionHandlers = handlers as CompletionHandlers
    })

    const { result, onCompleted } = renderSender({
      runState: harness.runState,
      taskId: 13,
    })

    await act(async () => {
      await result.current.handleSend()
      completionHandlers!.onError()
    })

    expect(harness.runState.setRespondingFalse).toHaveBeenCalled()
    expect(harness.runState.resetRunState).toHaveBeenCalled()
    expect(onCompleted).toHaveBeenCalledWith('', 13, false)
  })
|
||||
})
|
||||
@ -0,0 +1,237 @@
|
||||
import type { Dispatch, MutableRefObject, SetStateAction } from 'react'
|
||||
import type { FeedbackType } from '@/app/components/base/chat/chat/type'
|
||||
import type { WorkflowProcess } from '@/app/components/base/chat/types'
|
||||
import type { AppSourceType } from '@/service/share'
|
||||
import { useBoolean } from 'ahooks'
|
||||
import { useCallback, useEffect, useReducer, useRef, useState } from 'react'
|
||||
import {
|
||||
stopChatMessageResponding,
|
||||
stopWorkflowMessage,
|
||||
updateFeedback,
|
||||
} from '@/service/share'
|
||||
|
||||
// Callback used to surface stop failures as an error toast.
type Notify = (payload: { type: 'error', message: string }) => void

// Reducer slice: id of the in-flight task (null when idle) and whether a stop
// request is currently awaiting the server.
type RunControlState = {
  currentTaskId: string | null
  isStopping: boolean
}

// Reducer actions mirror React's SetStateAction semantics: `value` may be a
// plain value or an updater receiving the previous value.
type RunControlAction
  = | { type: 'reset' }
    | { type: 'setCurrentTaskId', value: SetStateAction<string | null> }
    | { type: 'setIsStopping', value: SetStateAction<boolean> }

type UseResultRunStateOptions = {
  appId?: string
  appSourceType: AppSourceType
  // Counter bumped by the parent; any truthy change aborts the current request.
  controlStopResponding?: number
  // Selects between the workflow and chat-message stop endpoints.
  isWorkflow: boolean
  notify: Notify
  // Receives a stop control while a run is active, null when it ends.
  onRunControlChange?: (control: { onStop: () => Promise<void> | void, isStopping: boolean } | null) => void
}
|
||||
|
||||
// Public surface of useResultRunState, consumed by useResultSender and the
// result component. Getter functions (getCompletionRes, getWorkflowProcessData)
// read refs so stream callbacks always see the latest value.
export type ResultRunStateController = {
  // Holds the stream's AbortController so stops can cancel the request.
  abortControllerRef: MutableRefObject<AbortController | null>
  // Bumps controlClearMoreLikeThis to signal dependents to reset.
  clearMoreLikeThis: () => void
  completionRes: string
  controlClearMoreLikeThis: number
  currentTaskId: string | null
  feedback: FeedbackType
  getCompletionRes: () => string
  getWorkflowProcessData: () => WorkflowProcess | undefined
  // Persists the user's rating/content for the current message.
  handleFeedback: (feedback: FeedbackType) => Promise<void>
  // Requests a server-side stop for the current task, then aborts locally.
  handleStop: () => Promise<void>
  isResponding: boolean
  isStopping: boolean
  messageId: string | null
  // Clears all per-run fields before a new send.
  prepareForNewRun: () => void
  // Clears task id / stopping flag and releases the abort controller.
  resetRunState: () => void
  setCompletionRes: (res: string) => void
  setCurrentTaskId: Dispatch<SetStateAction<string | null>>
  setIsStopping: Dispatch<SetStateAction<boolean>>
  setMessageId: Dispatch<SetStateAction<string | null>>
  setRespondingFalse: () => void
  setRespondingTrue: () => void
  setWorkflowProcessData: (data: WorkflowProcess | undefined) => void
  workflowProcessData: WorkflowProcess | undefined
}
|
||||
|
||||
const runControlReducer = (state: RunControlState, action: RunControlAction): RunControlState => {
|
||||
switch (action.type) {
|
||||
case 'reset':
|
||||
return {
|
||||
currentTaskId: null,
|
||||
isStopping: false,
|
||||
}
|
||||
case 'setCurrentTaskId':
|
||||
return {
|
||||
...state,
|
||||
currentTaskId: typeof action.value === 'function' ? action.value(state.currentTaskId) : action.value,
|
||||
}
|
||||
case 'setIsStopping':
|
||||
return {
|
||||
...state,
|
||||
isStopping: typeof action.value === 'function' ? action.value(state.isStopping) : action.value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Owns the run lifecycle for a completion/workflow result panel: responding
// flag, accumulated answer, workflow process data, feedback, and stop/abort
// control. State that stream callbacks must read synchronously is mirrored
// into refs and exposed via getter functions.
export const useResultRunState = ({
  appId,
  appSourceType,
  controlStopResponding,
  isWorkflow,
  notify,
  onRunControlChange,
}: UseResultRunStateOptions): ResultRunStateController => {
  const [isResponding, { setTrue: setRespondingTrue, setFalse: setRespondingFalse }] = useBoolean(false)
  // State + ref pairs: the ref gives stream callbacks the latest value without
  // waiting for a re-render.
  const [completionResState, setCompletionResState] = useState<string>('')
  const completionResRef = useRef<string>('')
  const [workflowProcessDataState, setWorkflowProcessDataState] = useState<WorkflowProcess>()
  const workflowProcessDataRef = useRef<WorkflowProcess | undefined>(undefined)
  const [messageId, setMessageId] = useState<string | null>(null)
  const [feedback, setFeedback] = useState<FeedbackType>({
    rating: null,
  })
  const [controlClearMoreLikeThis, setControlClearMoreLikeThis] = useState(0)
  const abortControllerRef = useRef<AbortController | null>(null)
  // Task id and stopping flag live in one reducer so they reset atomically.
  const [{ currentTaskId, isStopping }, dispatchRunControl] = useReducer(runControlReducer, {
    currentTaskId: null,
    isStopping: false,
  })

  const setCurrentTaskId = useCallback<Dispatch<SetStateAction<string | null>>>((value) => {
    dispatchRunControl({
      type: 'setCurrentTaskId',
      value,
    })
  }, [])

  const setIsStopping = useCallback<Dispatch<SetStateAction<boolean>>>((value) => {
    dispatchRunControl({
      type: 'setIsStopping',
      value,
    })
  }, [])

  // Write-through setter: keeps the ref and state in sync.
  const setCompletionRes = useCallback((res: string) => {
    completionResRef.current = res
    setCompletionResState(res)
  }, [])

  const getCompletionRes = useCallback(() => completionResRef.current, [])

  const setWorkflowProcessData = useCallback((data: WorkflowProcess | undefined) => {
    workflowProcessDataRef.current = data
    setWorkflowProcessDataState(data)
  }, [])

  const getWorkflowProcessData = useCallback(() => workflowProcessDataRef.current, [])

  // Clear run control, release the abort controller, and tell the parent the
  // run control is gone.
  const resetRunState = useCallback(() => {
    dispatchRunControl({ type: 'reset' })
    abortControllerRef.current = null
    onRunControlChange?.(null)
  }, [onRunControlChange])

  // Wipe all per-run state before starting a new send.
  const prepareForNewRun = useCallback(() => {
    setMessageId(null)
    setFeedback({ rating: null })
    setCompletionRes('')
    setWorkflowProcessData(undefined)
    resetRunState()
  }, [resetRunState, setCompletionRes, setWorkflowProcessData])

  // Persist the rating/content for the current message, then reflect it locally.
  const handleFeedback = useCallback(async (nextFeedback: FeedbackType) => {
    await updateFeedback({
      url: `/messages/${messageId}/feedbacks`,
      body: {
        rating: nextFeedback.rating,
        content: nextFeedback.content,
      },
    }, appSourceType, appId)
    setFeedback(nextFeedback)
  }, [appId, appSourceType, messageId])

  // Ask the server to stop the current task, then abort the local stream.
  // No-op when there is no task or a stop is already in flight.
  const handleStop = useCallback(async () => {
    if (!currentTaskId || isStopping)
      return

    setIsStopping(true)
    try {
      // NOTE: appId may legitimately be undefined here (tests rely on the
      // undefined / '' fallback being forwarded to the service layer).
      if (isWorkflow)
        await stopWorkflowMessage(appId!, currentTaskId, appSourceType, appId || '')
      else
        await stopChatMessageResponding(appId!, currentTaskId, appSourceType, appId || '')

      abortControllerRef.current?.abort()
    }
    catch (error) {
      const message = error instanceof Error ? error.message : String(error)
      notify({ type: 'error', message })
    }
    finally {
      // Always clear the stopping flag, even when the stop request failed.
      setIsStopping(false)
    }
  }, [appId, appSourceType, currentTaskId, isStopping, isWorkflow, notify, setIsStopping])

  // Bump the counter so "more like this" consumers know to reset.
  const clearMoreLikeThis = useCallback(() => {
    setControlClearMoreLikeThis(Date.now())
  }, [])

  // External stop signal: abort, drop the responding flag, reset run control.
  // The same abort also runs on unmount via the effect cleanup.
  useEffect(() => {
    const abortCurrentRequest = () => {
      abortControllerRef.current?.abort()
    }

    if (controlStopResponding) {
      abortCurrentRequest()
      setRespondingFalse()
      resetRunState()
    }

    return abortCurrentRequest
  }, [controlStopResponding, resetRunState, setRespondingFalse])

  // Publish a stop control to the parent only while a run is in flight.
  useEffect(() => {
    if (!onRunControlChange)
      return

    if (isResponding && currentTaskId) {
      onRunControlChange({
        onStop: handleStop,
        isStopping,
      })
      return
    }

    onRunControlChange(null)
  }, [currentTaskId, handleStop, isResponding, isStopping, onRunControlChange])

  return {
    abortControllerRef,
    clearMoreLikeThis,
    completionRes: completionResState,
    controlClearMoreLikeThis,
    currentTaskId,
    feedback,
    getCompletionRes,
    getWorkflowProcessData,
    handleFeedback,
    handleStop,
    isResponding,
    isStopping,
    messageId,
    prepareForNewRun,
    resetRunState,
    setCompletionRes,
    setCurrentTaskId,
    setIsStopping,
    setMessageId,
    setRespondingFalse,
    setRespondingTrue,
    setWorkflowProcessData,
    workflowProcessData: workflowProcessDataState,
  }
}
|
||||
@ -0,0 +1,230 @@
|
||||
import type { ResultInputValue } from '../result-request'
|
||||
import type { ResultRunStateController } from './use-result-run-state'
|
||||
import type { PromptConfig } from '@/models/debug'
|
||||
import type { AppSourceType } from '@/service/share'
|
||||
import type { VisionFile, VisionSettings } from '@/types/app'
|
||||
import { useCallback, useEffect, useRef } from 'react'
|
||||
import { TEXT_GENERATION_TIMEOUT_MS } from '@/config'
|
||||
import {
|
||||
sendCompletionMessage,
|
||||
sendWorkflowMessage,
|
||||
} from '@/service/share'
|
||||
import { sleep } from '@/utils'
|
||||
import { buildResultRequestData, validateResultRequest } from '../result-request'
|
||||
import { createWorkflowStreamHandlers } from '../workflow-stream-handlers'
|
||||
|
||||
// Toast callback covering the sender's error/info/warning notifications.
type Notify = (payload: { type: 'error' | 'info' | 'warning', message: string }) => void
// i18n translate function (injected so the hook stays framework-agnostic).
type Translate = (key: string, options?: Record<string, unknown>) => string

type UseResultSenderOptions = {
  appId?: string
  appSourceType: AppSourceType
  // Files attached to the completion request.
  completionFiles: VisionFile[]
  // Counters bumped by the parent to trigger a retry / a fresh send.
  controlRetry?: number
  controlSend?: number
  inputs: Record<string, ResultInputValue>
  isCallBatchAPI: boolean
  // On mobile (isPC false) the result panel is revealed when a send starts.
  isPC: boolean
  // Selects the workflow streaming path over the completion path.
  isWorkflow: boolean
  notify: Notify
  // Invoked when the run ends: final text, optional batch task id, success flag.
  onCompleted: (completionRes: string, taskId?: number, success?: boolean) => void
  onRunStart: () => void
  onShowRes: () => void
  promptConfig: PromptConfig | null
  // Shared run-state controller (see useResultRunState).
  runState: ResultRunStateController
  t: Translate
  taskId?: number
  visionConfig: VisionSettings
}
|
||||
|
||||
const logRequestError = (notify: Notify, error: unknown) => {
|
||||
const message = error instanceof Error ? error.message : String(error)
|
||||
notify({ type: 'error', message })
|
||||
}
|
||||
|
||||
export const useResultSender = ({
|
||||
appId,
|
||||
appSourceType,
|
||||
completionFiles,
|
||||
controlRetry,
|
||||
controlSend,
|
||||
inputs,
|
||||
isCallBatchAPI,
|
||||
isPC,
|
||||
isWorkflow,
|
||||
notify,
|
||||
onCompleted,
|
||||
onRunStart,
|
||||
onShowRes,
|
||||
promptConfig,
|
||||
runState,
|
||||
t,
|
||||
taskId,
|
||||
visionConfig,
|
||||
}: UseResultSenderOptions) => {
|
||||
const { clearMoreLikeThis } = runState
|
||||
|
||||
const handleSend = useCallback(async () => {
|
||||
if (runState.isResponding) {
|
||||
notify({ type: 'info', message: t('errorMessage.waitForResponse', { ns: 'appDebug' }) })
|
||||
return false
|
||||
}
|
||||
|
||||
const validation = validateResultRequest({
|
||||
completionFiles,
|
||||
inputs,
|
||||
isCallBatchAPI,
|
||||
promptConfig,
|
||||
t,
|
||||
})
|
||||
if (!validation.canSend) {
|
||||
notify(validation.notification!)
|
||||
return false
|
||||
}
|
||||
|
||||
const data = buildResultRequestData({
|
||||
completionFiles,
|
||||
inputs,
|
||||
promptConfig,
|
||||
visionConfig,
|
||||
})
|
||||
|
||||
runState.prepareForNewRun()
|
||||
|
||||
if (!isPC) {
|
||||
onShowRes()
|
||||
onRunStart()
|
||||
}
|
||||
|
||||
runState.setRespondingTrue()
|
||||
|
||||
let isEnd = false
|
||||
let isTimeout = false
|
||||
let completionChunks: string[] = []
|
||||
let tempMessageId = ''
|
||||
|
||||
void (async () => {
|
||||
await sleep(TEXT_GENERATION_TIMEOUT_MS)
|
||||
if (!isEnd) {
|
||||
runState.setRespondingFalse()
|
||||
onCompleted(runState.getCompletionRes(), taskId, false)
|
||||
runState.resetRunState()
|
||||
isTimeout = true
|
||||
}
|
||||
})()
|
||||
|
||||
if (isWorkflow) {
|
||||
const otherOptions = createWorkflowStreamHandlers({
|
||||
getCompletionRes: runState.getCompletionRes,
|
||||
getWorkflowProcessData: runState.getWorkflowProcessData,
|
||||
isTimedOut: () => isTimeout,
|
||||
markEnded: () => {
|
||||
isEnd = true
|
||||
},
|
||||
notify,
|
||||
onCompleted,
|
||||
resetRunState: runState.resetRunState,
|
||||
setCompletionRes: runState.setCompletionRes,
|
||||
setCurrentTaskId: runState.setCurrentTaskId,
|
||||
setIsStopping: runState.setIsStopping,
|
||||
setMessageId: runState.setMessageId,
|
||||
setRespondingFalse: runState.setRespondingFalse,
|
||||
setWorkflowProcessData: runState.setWorkflowProcessData,
|
||||
t,
|
||||
taskId,
|
||||
})
|
||||
|
||||
void sendWorkflowMessage(data, otherOptions, appSourceType, appId).catch((error) => {
|
||||
runState.setRespondingFalse()
|
||||
runState.resetRunState()
|
||||
logRequestError(notify, error)
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
void sendCompletionMessage(data, {
|
||||
onData: (chunk, _isFirstMessage, { messageId, taskId: nextTaskId }) => {
|
||||
tempMessageId = messageId
|
||||
if (nextTaskId && nextTaskId.trim() !== '')
|
||||
runState.setCurrentTaskId(prev => prev ?? nextTaskId)
|
||||
|
||||
completionChunks.push(chunk)
|
||||
runState.setCompletionRes(completionChunks.join(''))
|
||||
},
|
||||
onCompleted: () => {
|
||||
if (isTimeout) {
|
||||
notify({ type: 'warning', message: t('warningMessage.timeoutExceeded', { ns: 'appDebug' }) })
|
||||
return
|
||||
}
|
||||
|
||||
runState.setRespondingFalse()
|
||||
runState.resetRunState()
|
||||
runState.setMessageId(tempMessageId)
|
||||
onCompleted(runState.getCompletionRes(), taskId, true)
|
||||
isEnd = true
|
||||
},
|
||||
onMessageReplace: (messageReplace) => {
|
||||
completionChunks = [messageReplace.answer]
|
||||
runState.setCompletionRes(completionChunks.join(''))
|
||||
},
|
||||
onError: () => {
|
||||
if (isTimeout) {
|
||||
notify({ type: 'warning', message: t('warningMessage.timeoutExceeded', { ns: 'appDebug' }) })
|
||||
return
|
||||
}
|
||||
|
||||
runState.setRespondingFalse()
|
||||
runState.resetRunState()
|
||||
onCompleted(runState.getCompletionRes(), taskId, false)
|
||||
isEnd = true
|
||||
},
|
||||
getAbortController: (abortController) => {
|
||||
runState.abortControllerRef.current = abortController
|
||||
},
|
||||
}, appSourceType, appId)
|
||||
|
||||
return true
|
||||
}, [
|
||||
appId,
|
||||
appSourceType,
|
||||
completionFiles,
|
||||
inputs,
|
||||
isCallBatchAPI,
|
||||
isPC,
|
||||
isWorkflow,
|
||||
notify,
|
||||
onCompleted,
|
||||
onRunStart,
|
||||
onShowRes,
|
||||
promptConfig,
|
||||
runState,
|
||||
t,
|
||||
taskId,
|
||||
visionConfig,
|
||||
])
|
||||
|
||||
const handleSendRef = useRef(handleSend)
|
||||
|
||||
useEffect(() => {
|
||||
handleSendRef.current = handleSend
|
||||
}, [handleSend])
|
||||
|
||||
useEffect(() => {
|
||||
if (!controlSend)
|
||||
return
|
||||
|
||||
void handleSendRef.current()
|
||||
clearMoreLikeThis()
|
||||
}, [clearMoreLikeThis, controlSend])
|
||||
|
||||
useEffect(() => {
|
||||
if (!controlRetry)
|
||||
return
|
||||
|
||||
void handleSendRef.current()
|
||||
}, [controlRetry])
|
||||
|
||||
return {
|
||||
handleSend,
|
||||
}
|
||||
}
|
||||
@ -1,46 +1,18 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import type { FeedbackType } from '@/app/components/base/chat/chat/type'
|
||||
import type { WorkflowProcess } from '@/app/components/base/chat/types'
|
||||
import type { FileEntity } from '@/app/components/base/file-uploader/types'
|
||||
import type { PromptConfig } from '@/models/debug'
|
||||
import type { SiteInfo } from '@/models/share'
|
||||
import type {
|
||||
IOtherOptions,
|
||||
} from '@/service/base'
|
||||
import type { AppSourceType } from '@/service/share'
|
||||
import type { VisionFile, VisionSettings } from '@/types/app'
|
||||
import { RiLoader2Line } from '@remixicon/react'
|
||||
import { useBoolean } from 'ahooks'
|
||||
import { t } from 'i18next'
|
||||
import { produce } from 'immer'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import TextGenerationRes from '@/app/components/app/text-generate/item'
|
||||
import Button from '@/app/components/base/button'
|
||||
import {
|
||||
getFilesInLogs,
|
||||
getProcessedFiles,
|
||||
} from '@/app/components/base/file-uploader/utils'
|
||||
import { StopCircle } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import NoData from '@/app/components/share/text-generation/no-data'
|
||||
import { NodeRunningStatus, WorkflowRunningStatus } from '@/app/components/workflow/types'
|
||||
import { TEXT_GENERATION_TIMEOUT_MS } from '@/config'
|
||||
import {
|
||||
sseGet,
|
||||
} from '@/service/base'
|
||||
import {
|
||||
AppSourceType,
|
||||
sendCompletionMessage,
|
||||
sendWorkflowMessage,
|
||||
stopChatMessageResponding,
|
||||
stopWorkflowMessage,
|
||||
updateFeedback,
|
||||
} from '@/service/share'
|
||||
import { TransferMethod } from '@/types/app'
|
||||
import { sleep } from '@/utils'
|
||||
import { formatBooleanInputs } from '@/utils/model-config'
|
||||
import { useResultRunState } from './hooks/use-result-run-state'
|
||||
import { useResultSender } from './hooks/use-result-sender'
|
||||
|
||||
export type IResultProps = {
|
||||
isWorkflow: boolean
|
||||
@ -95,554 +67,52 @@ const Result: FC<IResultProps> = ({
|
||||
onRunControlChange,
|
||||
hideInlineStopButton = false,
|
||||
}) => {
|
||||
const [isResponding, { setTrue: setRespondingTrue, setFalse: setRespondingFalse }] = useBoolean(false)
|
||||
const [completionRes, doSetCompletionRes] = useState<string>('')
|
||||
const completionResRef = useRef<string>('')
|
||||
const setCompletionRes = (res: string) => {
|
||||
completionResRef.current = res
|
||||
doSetCompletionRes(res)
|
||||
}
|
||||
const getCompletionRes = () => completionResRef.current
|
||||
const [workflowProcessData, doSetWorkflowProcessData] = useState<WorkflowProcess>()
|
||||
const workflowProcessDataRef = useRef<WorkflowProcess | undefined>(undefined)
|
||||
const setWorkflowProcessData = useCallback((data: WorkflowProcess | undefined) => {
|
||||
workflowProcessDataRef.current = data
|
||||
doSetWorkflowProcessData(data)
|
||||
}, [])
|
||||
const getWorkflowProcessData = () => workflowProcessDataRef.current
|
||||
const [currentTaskId, setCurrentTaskId] = useState<string | null>(null)
|
||||
const [isStopping, setIsStopping] = useState(false)
|
||||
const abortControllerRef = useRef<AbortController | null>(null)
|
||||
const resetRunState = useCallback(() => {
|
||||
setCurrentTaskId(null)
|
||||
setIsStopping(false)
|
||||
abortControllerRef.current = null
|
||||
onRunControlChange?.(null)
|
||||
}, [onRunControlChange])
|
||||
|
||||
useEffect(() => {
|
||||
const abortCurrentRequest = () => {
|
||||
abortControllerRef.current?.abort()
|
||||
}
|
||||
|
||||
if (controlStopResponding) {
|
||||
abortCurrentRequest()
|
||||
setRespondingFalse()
|
||||
resetRunState()
|
||||
}
|
||||
|
||||
return abortCurrentRequest
|
||||
}, [controlStopResponding, resetRunState, setRespondingFalse])
|
||||
|
||||
const { notify } = Toast
|
||||
const isNoData = !completionRes
|
||||
|
||||
const [messageId, setMessageId] = useState<string | null>(null)
|
||||
const [feedback, setFeedback] = useState<FeedbackType>({
|
||||
rating: null,
|
||||
const runState = useResultRunState({
|
||||
appId,
|
||||
appSourceType,
|
||||
controlStopResponding,
|
||||
isWorkflow,
|
||||
notify,
|
||||
onRunControlChange,
|
||||
})
|
||||
|
||||
const handleFeedback = async (feedback: FeedbackType) => {
|
||||
await updateFeedback({ url: `/messages/${messageId}/feedbacks`, body: { rating: feedback.rating, content: feedback.content } }, appSourceType, appId)
|
||||
setFeedback(feedback)
|
||||
}
|
||||
const { handleSend } = useResultSender({
|
||||
appId,
|
||||
appSourceType,
|
||||
completionFiles,
|
||||
controlRetry,
|
||||
controlSend,
|
||||
inputs,
|
||||
isCallBatchAPI,
|
||||
isPC,
|
||||
isWorkflow,
|
||||
notify,
|
||||
onCompleted,
|
||||
onRunStart,
|
||||
onShowRes,
|
||||
promptConfig,
|
||||
runState,
|
||||
t,
|
||||
taskId,
|
||||
visionConfig,
|
||||
})
|
||||
|
||||
const logError = (message: string) => {
|
||||
notify({ type: 'error', message })
|
||||
}
|
||||
|
||||
const handleStop = useCallback(async () => {
|
||||
if (!currentTaskId || isStopping)
|
||||
return
|
||||
setIsStopping(true)
|
||||
try {
|
||||
if (isWorkflow)
|
||||
await stopWorkflowMessage(appId!, currentTaskId, appSourceType, appId || '')
|
||||
else
|
||||
await stopChatMessageResponding(appId!, currentTaskId, appSourceType, appId || '')
|
||||
abortControllerRef.current?.abort()
|
||||
}
|
||||
catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error)
|
||||
notify({ type: 'error', message })
|
||||
}
|
||||
finally {
|
||||
setIsStopping(false)
|
||||
}
|
||||
}, [appId, currentTaskId, appSourceType, isStopping, isWorkflow, notify])
|
||||
|
||||
useEffect(() => {
|
||||
if (!onRunControlChange)
|
||||
return
|
||||
if (isResponding && currentTaskId) {
|
||||
onRunControlChange({
|
||||
onStop: handleStop,
|
||||
isStopping,
|
||||
})
|
||||
}
|
||||
else {
|
||||
onRunControlChange(null)
|
||||
}
|
||||
}, [currentTaskId, handleStop, isResponding, isStopping, onRunControlChange])
|
||||
|
||||
const checkCanSend = () => {
|
||||
// batch will check outer
|
||||
if (isCallBatchAPI)
|
||||
return true
|
||||
|
||||
const prompt_variables = promptConfig?.prompt_variables
|
||||
if (!prompt_variables || prompt_variables?.length === 0) {
|
||||
if (completionFiles.some(item => item.transfer_method === TransferMethod.local_file && !item.upload_file_id)) {
|
||||
notify({ type: 'info', message: t('errorMessage.waitForFileUpload', { ns: 'appDebug' }) })
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
let hasEmptyInput = ''
|
||||
const requiredVars = prompt_variables?.filter(({ key, name, required, type }) => {
|
||||
if (type === 'boolean' || type === 'checkbox')
|
||||
return false // boolean/checkbox input is not required
|
||||
const res = (!key || !key.trim()) || (!name || !name.trim()) || (required || required === undefined || required === null)
|
||||
return res
|
||||
}) || [] // compatible with old version
|
||||
requiredVars.forEach(({ key, name }) => {
|
||||
if (hasEmptyInput)
|
||||
return
|
||||
|
||||
if (!inputs[key])
|
||||
hasEmptyInput = name
|
||||
})
|
||||
|
||||
if (hasEmptyInput) {
|
||||
logError(t('errorMessage.valueOfVarRequired', { ns: 'appDebug', key: hasEmptyInput }))
|
||||
return false
|
||||
}
|
||||
|
||||
if (completionFiles.some(item => item.transfer_method === TransferMethod.local_file && !item.upload_file_id)) {
|
||||
notify({ type: 'info', message: t('errorMessage.waitForFileUpload', { ns: 'appDebug' }) })
|
||||
return false
|
||||
}
|
||||
return !hasEmptyInput
|
||||
}
|
||||
|
||||
const handleSend = async () => {
|
||||
if (isResponding) {
|
||||
notify({ type: 'info', message: t('errorMessage.waitForResponse', { ns: 'appDebug' }) })
|
||||
return false
|
||||
}
|
||||
|
||||
if (!checkCanSend())
|
||||
return
|
||||
|
||||
// Process inputs: convert file entities to API format
|
||||
const processedInputs = { ...formatBooleanInputs(promptConfig?.prompt_variables, inputs) }
|
||||
promptConfig?.prompt_variables.forEach((variable) => {
|
||||
const value = processedInputs[variable.key]
|
||||
if (variable.type === 'file' && value && typeof value === 'object' && !Array.isArray(value)) {
|
||||
// Convert single file entity to API format
|
||||
processedInputs[variable.key] = getProcessedFiles([value as FileEntity])[0]
|
||||
}
|
||||
else if (variable.type === 'file-list' && Array.isArray(value) && value.length > 0) {
|
||||
// Convert file entity array to API format
|
||||
processedInputs[variable.key] = getProcessedFiles(value as FileEntity[])
|
||||
}
|
||||
})
|
||||
|
||||
const data: Record<string, any> = {
|
||||
inputs: processedInputs,
|
||||
}
|
||||
if (visionConfig.enabled && completionFiles && completionFiles?.length > 0) {
|
||||
data.files = completionFiles.map((item) => {
|
||||
if (item.transfer_method === TransferMethod.local_file) {
|
||||
return {
|
||||
...item,
|
||||
url: '',
|
||||
}
|
||||
}
|
||||
return item
|
||||
})
|
||||
}
|
||||
|
||||
setMessageId(null)
|
||||
setFeedback({
|
||||
rating: null,
|
||||
})
|
||||
setCompletionRes('')
|
||||
setWorkflowProcessData(undefined)
|
||||
resetRunState()
|
||||
|
||||
let res: string[] = []
|
||||
let tempMessageId = ''
|
||||
|
||||
if (!isPC) {
|
||||
onShowRes()
|
||||
onRunStart()
|
||||
}
|
||||
|
||||
setRespondingTrue()
|
||||
let isEnd = false
|
||||
let isTimeout = false;
|
||||
(async () => {
|
||||
await sleep(TEXT_GENERATION_TIMEOUT_MS)
|
||||
if (!isEnd) {
|
||||
setRespondingFalse()
|
||||
onCompleted(getCompletionRes(), taskId, false)
|
||||
resetRunState()
|
||||
isTimeout = true
|
||||
}
|
||||
})()
|
||||
|
||||
if (isWorkflow) {
|
||||
const otherOptions: IOtherOptions = {
|
||||
isPublicAPI: appSourceType === AppSourceType.webApp,
|
||||
onWorkflowStarted: ({ workflow_run_id, task_id }) => {
|
||||
const workflowProcessData = getWorkflowProcessData()
|
||||
if (workflowProcessData && workflowProcessData.tracing.length > 0) {
|
||||
setWorkflowProcessData(produce(workflowProcessData, (draft) => {
|
||||
draft.expand = true
|
||||
draft.status = WorkflowRunningStatus.Running
|
||||
}))
|
||||
}
|
||||
else {
|
||||
tempMessageId = workflow_run_id
|
||||
setCurrentTaskId(task_id || null)
|
||||
setIsStopping(false)
|
||||
setWorkflowProcessData({
|
||||
status: WorkflowRunningStatus.Running,
|
||||
tracing: [],
|
||||
expand: false,
|
||||
resultText: '',
|
||||
})
|
||||
}
|
||||
},
|
||||
onIterationStart: ({ data }) => {
|
||||
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
|
||||
draft.expand = true
|
||||
draft.tracing!.push({
|
||||
...data,
|
||||
status: NodeRunningStatus.Running,
|
||||
expand: true,
|
||||
})
|
||||
}))
|
||||
},
|
||||
onIterationNext: () => {
|
||||
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
|
||||
draft.expand = true
|
||||
const iterations = draft.tracing.find(item => item.node_id === data.node_id
|
||||
&& (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))!
|
||||
iterations?.details!.push([])
|
||||
}))
|
||||
},
|
||||
onIterationFinish: ({ data }) => {
|
||||
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
|
||||
draft.expand = true
|
||||
const iterationsIndex = draft.tracing.findIndex(item => item.node_id === data.node_id
|
||||
&& (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))!
|
||||
draft.tracing[iterationsIndex] = {
|
||||
...data,
|
||||
expand: !!data.error,
|
||||
}
|
||||
}))
|
||||
},
|
||||
onLoopStart: ({ data }) => {
|
||||
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
|
||||
draft.expand = true
|
||||
draft.tracing!.push({
|
||||
...data,
|
||||
status: NodeRunningStatus.Running,
|
||||
expand: true,
|
||||
})
|
||||
}))
|
||||
},
|
||||
onLoopNext: () => {
|
||||
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
|
||||
draft.expand = true
|
||||
const loops = draft.tracing.find(item => item.node_id === data.node_id
|
||||
&& (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))!
|
||||
loops?.details!.push([])
|
||||
}))
|
||||
},
|
||||
onLoopFinish: ({ data }) => {
|
||||
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
|
||||
draft.expand = true
|
||||
const loopsIndex = draft.tracing.findIndex(item => item.node_id === data.node_id
|
||||
&& (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))!
|
||||
draft.tracing[loopsIndex] = {
|
||||
...data,
|
||||
expand: !!data.error,
|
||||
}
|
||||
}))
|
||||
},
|
||||
onNodeStarted: ({ data }) => {
|
||||
if (data.iteration_id)
|
||||
return
|
||||
|
||||
if (data.loop_id)
|
||||
return
|
||||
const workflowProcessData = getWorkflowProcessData()
|
||||
setWorkflowProcessData(produce(workflowProcessData!, (draft) => {
|
||||
if (draft.tracing.length > 0) {
|
||||
const currentIndex = draft.tracing.findIndex(item => item.node_id === data.node_id)
|
||||
if (currentIndex > -1) {
|
||||
draft.expand = true
|
||||
draft.tracing![currentIndex] = {
|
||||
...data,
|
||||
status: NodeRunningStatus.Running,
|
||||
expand: true,
|
||||
}
|
||||
}
|
||||
else {
|
||||
draft.expand = true
|
||||
draft.tracing.push({
|
||||
...data,
|
||||
status: NodeRunningStatus.Running,
|
||||
expand: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
else {
|
||||
draft.expand = true
|
||||
draft.tracing!.push({
|
||||
...data,
|
||||
status: NodeRunningStatus.Running,
|
||||
expand: true,
|
||||
})
|
||||
}
|
||||
}))
|
||||
},
|
||||
onNodeFinished: ({ data }) => {
|
||||
if (data.iteration_id)
|
||||
return
|
||||
|
||||
if (data.loop_id)
|
||||
return
|
||||
|
||||
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
|
||||
const currentIndex = draft.tracing!.findIndex(trace => trace.node_id === data.node_id
|
||||
&& (trace.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || trace.parallel_id === data.execution_metadata?.parallel_id))
|
||||
if (currentIndex > -1 && draft.tracing) {
|
||||
draft.tracing[currentIndex] = {
|
||||
...(draft.tracing[currentIndex].extras
|
||||
? { extras: draft.tracing[currentIndex].extras }
|
||||
: {}),
|
||||
...data,
|
||||
expand: !!data.error,
|
||||
}
|
||||
}
|
||||
}))
|
||||
},
|
||||
onWorkflowFinished: ({ data }) => {
|
||||
if (isTimeout) {
|
||||
notify({ type: 'warning', message: t('warningMessage.timeoutExceeded', { ns: 'appDebug' }) })
|
||||
return
|
||||
}
|
||||
const workflowStatus = data.status as WorkflowRunningStatus | undefined
|
||||
const markNodesStopped = (traces?: WorkflowProcess['tracing']) => {
|
||||
if (!traces)
|
||||
return
|
||||
const markTrace = (trace: WorkflowProcess['tracing'][number]) => {
|
||||
if ([NodeRunningStatus.Running, NodeRunningStatus.Waiting].includes(trace.status as NodeRunningStatus))
|
||||
trace.status = NodeRunningStatus.Stopped
|
||||
trace.details?.forEach(detailGroup => detailGroup.forEach(markTrace))
|
||||
trace.retryDetail?.forEach(markTrace)
|
||||
trace.parallelDetail?.children?.forEach(markTrace)
|
||||
}
|
||||
traces.forEach(markTrace)
|
||||
}
|
||||
if (workflowStatus === WorkflowRunningStatus.Stopped) {
|
||||
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
|
||||
draft.status = WorkflowRunningStatus.Stopped
|
||||
markNodesStopped(draft.tracing)
|
||||
}))
|
||||
setRespondingFalse()
|
||||
resetRunState()
|
||||
onCompleted(getCompletionRes(), taskId, false)
|
||||
isEnd = true
|
||||
return
|
||||
}
|
||||
if (data.error) {
|
||||
notify({ type: 'error', message: data.error })
|
||||
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
|
||||
draft.status = WorkflowRunningStatus.Failed
|
||||
markNodesStopped(draft.tracing)
|
||||
}))
|
||||
setRespondingFalse()
|
||||
resetRunState()
|
||||
onCompleted(getCompletionRes(), taskId, false)
|
||||
isEnd = true
|
||||
return
|
||||
}
|
||||
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
|
||||
draft.status = WorkflowRunningStatus.Succeeded
|
||||
draft.files = getFilesInLogs(data.outputs || []) as any[]
|
||||
}))
|
||||
if (!data.outputs) {
|
||||
setCompletionRes('')
|
||||
}
|
||||
else {
|
||||
setCompletionRes(data.outputs)
|
||||
const isStringOutput = Object.keys(data.outputs).length === 1 && typeof data.outputs[Object.keys(data.outputs)[0]] === 'string'
|
||||
if (isStringOutput) {
|
||||
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
|
||||
draft.resultText = data.outputs[Object.keys(data.outputs)[0]]
|
||||
}))
|
||||
}
|
||||
}
|
||||
setRespondingFalse()
|
||||
resetRunState()
|
||||
setMessageId(tempMessageId)
|
||||
onCompleted(getCompletionRes(), taskId, true)
|
||||
isEnd = true
|
||||
},
|
||||
onTextChunk: (params) => {
|
||||
const { data: { text } } = params
|
||||
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
|
||||
draft.resultText += text
|
||||
}))
|
||||
},
|
||||
onTextReplace: (params) => {
|
||||
const { data: { text } } = params
|
||||
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
|
||||
draft.resultText = text
|
||||
}))
|
||||
},
|
||||
onHumanInputRequired: ({ data: humanInputRequiredData }) => {
|
||||
const workflowProcessData = getWorkflowProcessData()
|
||||
setWorkflowProcessData(produce(workflowProcessData!, (draft) => {
|
||||
if (!draft.humanInputFormDataList) {
|
||||
draft.humanInputFormDataList = [humanInputRequiredData]
|
||||
}
|
||||
else {
|
||||
const currentFormIndex = draft.humanInputFormDataList.findIndex(item => item.node_id === humanInputRequiredData.node_id)
|
||||
if (currentFormIndex > -1) {
|
||||
draft.humanInputFormDataList[currentFormIndex] = humanInputRequiredData
|
||||
}
|
||||
else {
|
||||
draft.humanInputFormDataList.push(humanInputRequiredData)
|
||||
}
|
||||
}
|
||||
const currentIndex = draft.tracing!.findIndex(item => item.node_id === humanInputRequiredData.node_id)
|
||||
if (currentIndex > -1) {
|
||||
draft.tracing![currentIndex].status = NodeRunningStatus.Paused
|
||||
}
|
||||
}))
|
||||
},
|
||||
onHumanInputFormFilled: ({ data: humanInputFilledFormData }) => {
|
||||
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
|
||||
if (draft.humanInputFormDataList?.length) {
|
||||
const currentFormIndex = draft.humanInputFormDataList.findIndex(item => item.node_id === humanInputFilledFormData.node_id)
|
||||
draft.humanInputFormDataList.splice(currentFormIndex, 1)
|
||||
}
|
||||
if (!draft.humanInputFilledFormDataList) {
|
||||
draft.humanInputFilledFormDataList = [humanInputFilledFormData]
|
||||
}
|
||||
else {
|
||||
draft.humanInputFilledFormDataList.push(humanInputFilledFormData)
|
||||
}
|
||||
}))
|
||||
},
|
||||
onHumanInputFormTimeout: ({ data: humanInputFormTimeoutData }) => {
|
||||
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
|
||||
if (draft.humanInputFormDataList?.length) {
|
||||
const currentFormIndex = draft.humanInputFormDataList.findIndex(item => item.node_id === humanInputFormTimeoutData.node_id)
|
||||
draft.humanInputFormDataList[currentFormIndex].expiration_time = humanInputFormTimeoutData.expiration_time
|
||||
}
|
||||
}))
|
||||
},
|
||||
onWorkflowPaused: ({ data: workflowPausedData }) => {
|
||||
tempMessageId = workflowPausedData.workflow_run_id
|
||||
const url = `/workflow/${workflowPausedData.workflow_run_id}/events`
|
||||
sseGet(
|
||||
url,
|
||||
{},
|
||||
otherOptions,
|
||||
)
|
||||
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
|
||||
draft.expand = false
|
||||
draft.status = WorkflowRunningStatus.Paused
|
||||
}))
|
||||
},
|
||||
}
|
||||
sendWorkflowMessage(
|
||||
data,
|
||||
otherOptions,
|
||||
appSourceType,
|
||||
appId,
|
||||
).catch((error) => {
|
||||
setRespondingFalse()
|
||||
resetRunState()
|
||||
const message = error instanceof Error ? error.message : String(error)
|
||||
notify({ type: 'error', message })
|
||||
})
|
||||
}
|
||||
else {
|
||||
sendCompletionMessage(data, {
|
||||
onData: (data: string, _isFirstMessage: boolean, { messageId, taskId }) => {
|
||||
tempMessageId = messageId
|
||||
if (taskId && typeof taskId === 'string' && taskId.trim() !== '')
|
||||
setCurrentTaskId(prev => prev ?? taskId)
|
||||
res.push(data)
|
||||
setCompletionRes(res.join(''))
|
||||
},
|
||||
onCompleted: () => {
|
||||
if (isTimeout) {
|
||||
notify({ type: 'warning', message: t('warningMessage.timeoutExceeded', { ns: 'appDebug' }) })
|
||||
return
|
||||
}
|
||||
setRespondingFalse()
|
||||
resetRunState()
|
||||
setMessageId(tempMessageId)
|
||||
onCompleted(getCompletionRes(), taskId, true)
|
||||
isEnd = true
|
||||
},
|
||||
onMessageReplace: (messageReplace) => {
|
||||
res = [messageReplace.answer]
|
||||
setCompletionRes(res.join(''))
|
||||
},
|
||||
onError() {
|
||||
if (isTimeout) {
|
||||
notify({ type: 'warning', message: t('warningMessage.timeoutExceeded', { ns: 'appDebug' }) })
|
||||
return
|
||||
}
|
||||
setRespondingFalse()
|
||||
resetRunState()
|
||||
onCompleted(getCompletionRes(), taskId, false)
|
||||
isEnd = true
|
||||
},
|
||||
getAbortController: (abortController) => {
|
||||
abortControllerRef.current = abortController
|
||||
},
|
||||
}, appSourceType, appId)
|
||||
}
|
||||
}
|
||||
|
||||
const [controlClearMoreLikeThis, setControlClearMoreLikeThis] = useState(0)
|
||||
useEffect(() => {
|
||||
if (controlSend) {
|
||||
handleSend()
|
||||
setControlClearMoreLikeThis(Date.now())
|
||||
}
|
||||
}, [controlSend])
|
||||
|
||||
useEffect(() => {
|
||||
if (controlRetry)
|
||||
handleSend()
|
||||
}, [controlRetry])
|
||||
const isNoData = !runState.completionRes
|
||||
|
||||
const renderTextGenerationRes = () => (
|
||||
<>
|
||||
{!hideInlineStopButton && isResponding && currentTaskId && (
|
||||
{!hideInlineStopButton && runState.isResponding && runState.currentTaskId && (
|
||||
<div className={`mb-3 flex ${isPC ? 'justify-end' : 'justify-center'}`}>
|
||||
<Button
|
||||
variant="secondary"
|
||||
disabled={isStopping}
|
||||
onClick={handleStop}
|
||||
disabled={runState.isStopping}
|
||||
onClick={runState.handleStop}
|
||||
>
|
||||
{
|
||||
isStopping
|
||||
? <RiLoader2Line className="mr-[5px] h-3.5 w-3.5 animate-spin" />
|
||||
: <StopCircle className="mr-[5px] h-3.5 w-3.5" />
|
||||
runState.isStopping
|
||||
? <span aria-hidden className="i-ri-loader-2-line mr-[5px] h-3.5 w-3.5 animate-spin" />
|
||||
: <span aria-hidden className="i-ri-stop-circle-fill mr-[5px] h-3.5 w-3.5" />
|
||||
}
|
||||
<span className="text-xs font-normal">{t('operation.stopResponding', { ns: 'appDebug' })}</span>
|
||||
</Button>
|
||||
@ -650,15 +120,15 @@ const Result: FC<IResultProps> = ({
|
||||
)}
|
||||
<TextGenerationRes
|
||||
isWorkflow={isWorkflow}
|
||||
workflowProcessData={workflowProcessData}
|
||||
workflowProcessData={runState.workflowProcessData}
|
||||
isError={isError}
|
||||
onRetry={handleSend}
|
||||
content={completionRes}
|
||||
messageId={messageId}
|
||||
content={runState.completionRes}
|
||||
messageId={runState.messageId}
|
||||
isInWebApp
|
||||
moreLikeThis={moreLikeThisEnabled}
|
||||
onFeedback={handleFeedback}
|
||||
feedback={feedback}
|
||||
onFeedback={runState.handleFeedback}
|
||||
feedback={runState.feedback}
|
||||
onSave={handleSaveMessage}
|
||||
isMobile={isMobile}
|
||||
appSourceType={appSourceType}
|
||||
@ -666,7 +136,7 @@ const Result: FC<IResultProps> = ({
|
||||
// isLoading={isCallBatchAPI ? (!completionRes && isResponding) : false}
|
||||
isLoading={false}
|
||||
taskId={isCallBatchAPI ? ((taskId as number) < 10 ? `0${taskId}` : `${taskId}`) : undefined}
|
||||
controlClearMoreLikeThis={controlClearMoreLikeThis}
|
||||
controlClearMoreLikeThis={runState.controlClearMoreLikeThis}
|
||||
isShowTextToSpeech={isShowTextToSpeech}
|
||||
hideProcessDetail
|
||||
siteInfo={siteInfo}
|
||||
@ -677,7 +147,7 @@ const Result: FC<IResultProps> = ({
|
||||
return (
|
||||
<>
|
||||
{!isCallBatchAPI && !isWorkflow && (
|
||||
(isResponding && !completionRes)
|
||||
(runState.isResponding && !runState.completionRes)
|
||||
? (
|
||||
<div className="flex h-full w-full items-center justify-center">
|
||||
<Loading type="area" />
|
||||
@ -692,13 +162,13 @@ const Result: FC<IResultProps> = ({
|
||||
)
|
||||
)}
|
||||
{!isCallBatchAPI && isWorkflow && (
|
||||
(isResponding && !workflowProcessData)
|
||||
(runState.isResponding && !runState.workflowProcessData)
|
||||
? (
|
||||
<div className="flex h-full w-full items-center justify-center">
|
||||
<Loading type="area" />
|
||||
</div>
|
||||
)
|
||||
: !workflowProcessData
|
||||
: !runState.workflowProcessData
|
||||
? <NoData />
|
||||
: renderTextGenerationRes()
|
||||
)}
|
||||
|
||||
@ -0,0 +1,156 @@
|
||||
import type { FileEntity } from '@/app/components/base/file-uploader/types'
|
||||
import type { PromptConfig } from '@/models/debug'
|
||||
import type { VisionFile, VisionSettings } from '@/types/app'
|
||||
import { getProcessedFiles } from '@/app/components/base/file-uploader/utils'
|
||||
import { TransferMethod } from '@/types/app'
|
||||
import { formatBooleanInputs } from '@/utils/model-config'
|
||||
|
||||
// Value shapes a prompt variable's input may take in the Result form:
// scalar text/number/boolean inputs, multi-select string arrays, JSON
// objects, and single or multiple file uploads. `undefined` means unset.
export type ResultInputValue
  = | string
  | boolean
  | number
  | string[]
  | Record<string, unknown>
  | FileEntity
  | FileEntity[]
  | undefined

// i18next-style translate function; `options` carries the namespace and
// interpolation values (e.g. { ns: 'appDebug', key: ... }).
type Translate = (key: string, options?: Record<string, unknown>) => string

// Outcome of pre-send validation: when `canSend` is false, `notification`
// holds the toast payload explaining why the request was blocked.
type ValidationResult = {
  canSend: boolean
  notification?: {
    type: 'error' | 'info'
    message: string
  }
}

// Everything needed to validate a single text-generation request
// before it is sent.
type ValidateResultRequestParams = {
  completionFiles: VisionFile[]
  inputs: Record<string, ResultInputValue>
  isCallBatchAPI: boolean
  promptConfig: PromptConfig | null
  t: Translate
}

// Everything needed to assemble the request payload
// (processed inputs plus optional vision file attachments).
type BuildResultRequestDataParams = {
  completionFiles: VisionFile[]
  inputs: Record<string, ResultInputValue>
  promptConfig: PromptConfig | null
  visionConfig: VisionSettings
}
|
||||
|
||||
const isMissingRequiredInput = (
|
||||
variable: PromptConfig['prompt_variables'][number],
|
||||
value: ResultInputValue,
|
||||
) => {
|
||||
if (value === undefined || value === null)
|
||||
return true
|
||||
|
||||
if (variable.type === 'file-list')
|
||||
return !Array.isArray(value) || value.length === 0
|
||||
|
||||
if (['string', 'paragraph', 'number', 'json_object', 'select'].includes(variable.type))
|
||||
return typeof value !== 'string' ? false : value.trim() === ''
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
const hasPendingLocalFiles = (completionFiles: VisionFile[]) => {
|
||||
return completionFiles.some(item => item.transfer_method === TransferMethod.local_file && !item.upload_file_id)
|
||||
}
|
||||
|
||||
export const validateResultRequest = ({
|
||||
completionFiles,
|
||||
inputs,
|
||||
isCallBatchAPI,
|
||||
promptConfig,
|
||||
t,
|
||||
}: ValidateResultRequestParams): ValidationResult => {
|
||||
if (isCallBatchAPI)
|
||||
return { canSend: true }
|
||||
|
||||
const promptVariables = promptConfig?.prompt_variables
|
||||
if (!promptVariables?.length) {
|
||||
if (hasPendingLocalFiles(completionFiles)) {
|
||||
return {
|
||||
canSend: false,
|
||||
notification: {
|
||||
type: 'info',
|
||||
message: t('errorMessage.waitForFileUpload', { ns: 'appDebug' }),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return { canSend: true }
|
||||
}
|
||||
|
||||
const requiredVariables = promptVariables.filter(({ key, name, required, type }) => {
|
||||
if (type === 'boolean' || type === 'checkbox')
|
||||
return false
|
||||
|
||||
return (!key || !key.trim()) || (!name || !name.trim()) || required === undefined || required === null || required
|
||||
})
|
||||
|
||||
const missingRequiredVariable = requiredVariables.find(variable => isMissingRequiredInput(variable, inputs[variable.key]))?.name
|
||||
if (missingRequiredVariable) {
|
||||
return {
|
||||
canSend: false,
|
||||
notification: {
|
||||
type: 'error',
|
||||
message: t('errorMessage.valueOfVarRequired', {
|
||||
ns: 'appDebug',
|
||||
key: missingRequiredVariable,
|
||||
}),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if (hasPendingLocalFiles(completionFiles)) {
|
||||
return {
|
||||
canSend: false,
|
||||
notification: {
|
||||
type: 'info',
|
||||
message: t('errorMessage.waitForFileUpload', { ns: 'appDebug' }),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return { canSend: true }
|
||||
}
|
||||
|
||||
export const buildResultRequestData = ({
|
||||
completionFiles,
|
||||
inputs,
|
||||
promptConfig,
|
||||
visionConfig,
|
||||
}: BuildResultRequestDataParams) => {
|
||||
const processedInputs = {
|
||||
...formatBooleanInputs(promptConfig?.prompt_variables, inputs as Record<string, string | number | boolean | object>),
|
||||
}
|
||||
|
||||
promptConfig?.prompt_variables.forEach((variable) => {
|
||||
const value = processedInputs[variable.key]
|
||||
if (variable.type === 'file' && value && typeof value === 'object' && !Array.isArray(value)) {
|
||||
processedInputs[variable.key] = getProcessedFiles([value as FileEntity])[0]
|
||||
return
|
||||
}
|
||||
|
||||
if (variable.type === 'file-list' && Array.isArray(value) && value.length > 0)
|
||||
processedInputs[variable.key] = getProcessedFiles(value as FileEntity[])
|
||||
})
|
||||
|
||||
return {
|
||||
inputs: processedInputs,
|
||||
...(visionConfig.enabled && completionFiles.length > 0
|
||||
? {
|
||||
files: completionFiles.map((item) => {
|
||||
if (item.transfer_method === TransferMethod.local_file)
|
||||
return { ...item, url: '' }
|
||||
|
||||
return item
|
||||
}),
|
||||
}
|
||||
: {}),
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,404 @@
|
||||
import type { Dispatch, SetStateAction } from 'react'
|
||||
import type { WorkflowProcess } from '@/app/components/base/chat/types'
|
||||
import type { IOtherOptions } from '@/service/base'
|
||||
import type { HumanInputFormTimeoutData, NodeTracing, WorkflowFinishedResponse } from '@/types/workflow'
|
||||
import { produce } from 'immer'
|
||||
import { getFilesInLogs } from '@/app/components/base/file-uploader/utils'
|
||||
import { NodeRunningStatus, WorkflowRunningStatus } from '@/app/components/workflow/types'
|
||||
import { sseGet } from '@/service/base'
|
||||
|
||||
type Notify = (payload: { type: 'error' | 'warning', message: string }) => void
|
||||
type Translate = (key: string, options?: Record<string, unknown>) => string
|
||||
|
||||
type CreateWorkflowStreamHandlersParams = {
|
||||
getCompletionRes: () => string
|
||||
getWorkflowProcessData: () => WorkflowProcess | undefined
|
||||
isTimedOut: () => boolean
|
||||
markEnded: () => void
|
||||
notify: Notify
|
||||
onCompleted: (completionRes: string, taskId?: number, success?: boolean) => void
|
||||
resetRunState: () => void
|
||||
setCompletionRes: (res: string) => void
|
||||
setCurrentTaskId: Dispatch<SetStateAction<string | null>>
|
||||
setIsStopping: Dispatch<SetStateAction<boolean>>
|
||||
setMessageId: Dispatch<SetStateAction<string | null>>
|
||||
setRespondingFalse: () => void
|
||||
setWorkflowProcessData: (data: WorkflowProcess | undefined) => void
|
||||
t: Translate
|
||||
taskId?: number
|
||||
}
|
||||
|
||||
const createInitialWorkflowProcess = (): WorkflowProcess => ({
|
||||
status: WorkflowRunningStatus.Running,
|
||||
tracing: [],
|
||||
expand: false,
|
||||
resultText: '',
|
||||
})
|
||||
|
||||
const updateWorkflowProcess = (
|
||||
current: WorkflowProcess | undefined,
|
||||
updater: (draft: WorkflowProcess) => void,
|
||||
) => {
|
||||
return produce(current ?? createInitialWorkflowProcess(), updater)
|
||||
}
|
||||
|
||||
const matchParallelTrace = (trace: WorkflowProcess['tracing'][number], data: NodeTracing) => {
|
||||
return trace.node_id === data.node_id
|
||||
&& (trace.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id
|
||||
|| trace.parallel_id === data.execution_metadata?.parallel_id)
|
||||
}
|
||||
|
||||
const ensureParallelTraceDetails = (details?: NodeTracing['details']) => {
|
||||
return details?.length ? details : [[]]
|
||||
}
|
||||
|
||||
const appendParallelStart = (current: WorkflowProcess | undefined, data: NodeTracing) => {
|
||||
return updateWorkflowProcess(current, (draft) => {
|
||||
draft.expand = true
|
||||
draft.tracing.push({
|
||||
...data,
|
||||
details: ensureParallelTraceDetails(data.details),
|
||||
status: NodeRunningStatus.Running,
|
||||
expand: true,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
const appendParallelNext = (current: WorkflowProcess | undefined, data: NodeTracing) => {
|
||||
return updateWorkflowProcess(current, (draft) => {
|
||||
draft.expand = true
|
||||
const trace = draft.tracing.find(item => matchParallelTrace(item, data))
|
||||
if (!trace)
|
||||
return
|
||||
|
||||
trace.details = ensureParallelTraceDetails(trace.details)
|
||||
trace.details.push([])
|
||||
})
|
||||
}
|
||||
|
||||
const finishParallelTrace = (current: WorkflowProcess | undefined, data: NodeTracing) => {
|
||||
return updateWorkflowProcess(current, (draft) => {
|
||||
draft.expand = true
|
||||
const traceIndex = draft.tracing.findIndex(item => matchParallelTrace(item, data))
|
||||
if (traceIndex > -1) {
|
||||
draft.tracing[traceIndex] = {
|
||||
...data,
|
||||
expand: !!data.error,
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const upsertWorkflowNode = (current: WorkflowProcess | undefined, data: NodeTracing) => {
|
||||
if (data.iteration_id || data.loop_id)
|
||||
return current
|
||||
|
||||
return updateWorkflowProcess(current, (draft) => {
|
||||
draft.expand = true
|
||||
const currentIndex = draft.tracing.findIndex(item => item.node_id === data.node_id)
|
||||
const nextTrace = {
|
||||
...data,
|
||||
status: NodeRunningStatus.Running,
|
||||
expand: true,
|
||||
}
|
||||
|
||||
if (currentIndex > -1)
|
||||
draft.tracing[currentIndex] = nextTrace
|
||||
else
|
||||
draft.tracing.push(nextTrace)
|
||||
})
|
||||
}
|
||||
|
||||
// Marks a top-level node trace as finished by replacing it with the final
// event payload. Nodes belonging to an iteration/loop are handled inside
// their parent trace, so they are skipped and `current` is returned as-is.
const finishWorkflowNode = (current: WorkflowProcess | undefined, data: NodeTracing) => {
  if (data.iteration_id || data.loop_id)
    return current

  return updateWorkflowProcess(current, (draft) => {
    const currentIndex = draft.tracing.findIndex(trace => matchParallelTrace(trace, data))
    if (currentIndex > -1) {
      draft.tracing[currentIndex] = {
        // Spread the previous `extras` first so it survives when the finish
        // payload carries none; `...data` overrides it when it does.
        ...(draft.tracing[currentIndex].extras
          ? { extras: draft.tracing[currentIndex].extras }
          : {}),
        ...data,
        // Keep the node row expanded only when it finished with an error.
        expand: !!data.error,
      }
    }
  })
}
|
||||
|
||||
const markNodesStopped = (traces?: WorkflowProcess['tracing']) => {
|
||||
if (!traces)
|
||||
return
|
||||
|
||||
const markTrace = (trace: WorkflowProcess['tracing'][number]) => {
|
||||
if ([NodeRunningStatus.Running, NodeRunningStatus.Waiting].includes(trace.status as NodeRunningStatus))
|
||||
trace.status = NodeRunningStatus.Stopped
|
||||
|
||||
trace.details?.forEach(detailGroup => detailGroup.forEach(markTrace))
|
||||
trace.retryDetail?.forEach(markTrace)
|
||||
trace.parallelDetail?.children?.forEach(markTrace)
|
||||
}
|
||||
|
||||
traces.forEach(markTrace)
|
||||
}
|
||||
|
||||
const applyWorkflowFinishedState = (
|
||||
current: WorkflowProcess | undefined,
|
||||
status: WorkflowRunningStatus,
|
||||
) => {
|
||||
return updateWorkflowProcess(current, (draft) => {
|
||||
draft.status = status
|
||||
if ([WorkflowRunningStatus.Stopped, WorkflowRunningStatus.Failed].includes(status))
|
||||
markNodesStopped(draft.tracing)
|
||||
})
|
||||
}
|
||||
|
||||
// Marks the run as succeeded and extracts any file entries from the outputs
// payload for display.
const applyWorkflowOutputs = (
  current: WorkflowProcess | undefined,
  outputs: WorkflowFinishedResponse['data']['outputs'],
) => {
  return updateWorkflowProcess(current, (draft) => {
    draft.status = WorkflowRunningStatus.Succeeded
    // getFilesInLogs scans the outputs for file objects; fall back to an
    // empty list when the run produced no outputs.
    draft.files = getFilesInLogs(outputs || []) as unknown as WorkflowProcess['files']
  })
}
|
||||
|
||||
const appendResultText = (current: WorkflowProcess | undefined, text: string) => {
|
||||
return updateWorkflowProcess(current, (draft) => {
|
||||
draft.resultText = `${draft.resultText || ''}${text}`
|
||||
})
|
||||
}
|
||||
|
||||
const replaceResultText = (current: WorkflowProcess | undefined, text: string) => {
|
||||
return updateWorkflowProcess(current, (draft) => {
|
||||
draft.resultText = text
|
||||
})
|
||||
}
|
||||
|
||||
const updateHumanInputRequired = (
|
||||
current: WorkflowProcess | undefined,
|
||||
data: NonNullable<WorkflowProcess['humanInputFormDataList']>[number],
|
||||
) => {
|
||||
return updateWorkflowProcess(current, (draft) => {
|
||||
if (!draft.humanInputFormDataList) {
|
||||
draft.humanInputFormDataList = [data]
|
||||
}
|
||||
else {
|
||||
const currentFormIndex = draft.humanInputFormDataList.findIndex(item => item.node_id === data.node_id)
|
||||
if (currentFormIndex > -1)
|
||||
draft.humanInputFormDataList[currentFormIndex] = data
|
||||
else
|
||||
draft.humanInputFormDataList.push(data)
|
||||
}
|
||||
|
||||
const currentIndex = draft.tracing.findIndex(item => item.node_id === data.node_id)
|
||||
if (currentIndex > -1)
|
||||
draft.tracing[currentIndex].status = NodeRunningStatus.Paused
|
||||
})
|
||||
}
|
||||
|
||||
const updateHumanInputFilled = (
|
||||
current: WorkflowProcess | undefined,
|
||||
data: NonNullable<WorkflowProcess['humanInputFilledFormDataList']>[number],
|
||||
) => {
|
||||
return updateWorkflowProcess(current, (draft) => {
|
||||
if (draft.humanInputFormDataList?.length) {
|
||||
const currentFormIndex = draft.humanInputFormDataList.findIndex(item => item.node_id === data.node_id)
|
||||
if (currentFormIndex > -1)
|
||||
draft.humanInputFormDataList.splice(currentFormIndex, 1)
|
||||
}
|
||||
|
||||
if (!draft.humanInputFilledFormDataList)
|
||||
draft.humanInputFilledFormDataList = [data]
|
||||
else
|
||||
draft.humanInputFilledFormDataList.push(data)
|
||||
})
|
||||
}
|
||||
|
||||
const updateHumanInputTimeout = (
|
||||
current: WorkflowProcess | undefined,
|
||||
data: HumanInputFormTimeoutData,
|
||||
) => {
|
||||
return updateWorkflowProcess(current, (draft) => {
|
||||
if (!draft.humanInputFormDataList?.length)
|
||||
return
|
||||
|
||||
const currentFormIndex = draft.humanInputFormDataList.findIndex(item => item.node_id === data.node_id)
|
||||
if (currentFormIndex > -1)
|
||||
draft.humanInputFormDataList[currentFormIndex].expiration_time = data.expiration_time
|
||||
})
|
||||
}
|
||||
|
||||
const applyWorkflowPaused = (current: WorkflowProcess | undefined) => {
|
||||
return updateWorkflowProcess(current, (draft) => {
|
||||
draft.expand = false
|
||||
draft.status = WorkflowRunningStatus.Paused
|
||||
})
|
||||
}
|
||||
|
||||
const serializeWorkflowOutputs = (outputs: WorkflowFinishedResponse['data']['outputs']) => {
|
||||
if (outputs === undefined || outputs === null)
|
||||
return ''
|
||||
|
||||
if (typeof outputs === 'string')
|
||||
return outputs
|
||||
|
||||
try {
|
||||
return JSON.stringify(outputs) ?? ''
|
||||
}
|
||||
catch {
|
||||
return String(outputs)
|
||||
}
|
||||
}
|
||||
|
||||
// Builds the SSE callback map for a workflow run. All state lives outside
// this module (getters/setters injected via params); the handlers translate
// each stream event into an immutable WorkflowProcess update plus the
// appropriate completion / notification side effects.
export const createWorkflowStreamHandlers = ({
  getCompletionRes,
  getWorkflowProcessData,
  isTimedOut,
  markEnded,
  notify,
  onCompleted,
  resetRunState,
  setCompletionRes,
  setCurrentTaskId,
  setIsStopping,
  setMessageId,
  setRespondingFalse,
  setWorkflowProcessData,
  t,
  taskId,
}: CreateWorkflowStreamHandlersParams): IOtherOptions => {
  // Captured from onWorkflowStarted / onWorkflowPaused; committed via
  // setMessageId only on a successful finish.
  let tempMessageId = ''

  // Shared teardown for stopped/errored runs (success=false).
  const finishWithFailure = () => {
    setRespondingFalse()
    resetRunState()
    onCompleted(getCompletionRes(), taskId, false)
    markEnded()
  }

  // Shared teardown for successful runs; also publishes the run's message id.
  const finishWithSuccess = () => {
    setRespondingFalse()
    resetRunState()
    setMessageId(tempMessageId)
    onCompleted(getCompletionRes(), taskId, true)
    markEnded()
  }

  const otherOptions: IOtherOptions = {
    onWorkflowStarted: ({ workflow_run_id, task_id }) => {
      const workflowProcessData = getWorkflowProcessData()
      // A non-empty trace means we are resuming an existing run (e.g. after a
      // pause): keep the traces, just flip the status back to Running.
      if (workflowProcessData?.tracing.length) {
        setWorkflowProcessData(updateWorkflowProcess(workflowProcessData, (draft) => {
          draft.expand = true
          draft.status = WorkflowRunningStatus.Running
        }))
        return
      }

      tempMessageId = workflow_run_id
      setCurrentTaskId(task_id || null)
      setIsStopping(false)
      setWorkflowProcessData(createInitialWorkflowProcess())
    },
    onIterationStart: ({ data }) => {
      setWorkflowProcessData(appendParallelStart(getWorkflowProcessData(), data))
    },
    onIterationNext: ({ data }) => {
      setWorkflowProcessData(appendParallelNext(getWorkflowProcessData(), data))
    },
    onIterationFinish: ({ data }) => {
      setWorkflowProcessData(finishParallelTrace(getWorkflowProcessData(), data))
    },
    // Loops reuse the iteration trace handling.
    onLoopStart: ({ data }) => {
      setWorkflowProcessData(appendParallelStart(getWorkflowProcessData(), data))
    },
    onLoopNext: ({ data }) => {
      setWorkflowProcessData(appendParallelNext(getWorkflowProcessData(), data))
    },
    onLoopFinish: ({ data }) => {
      setWorkflowProcessData(finishParallelTrace(getWorkflowProcessData(), data))
    },
    onNodeStarted: ({ data }) => {
      setWorkflowProcessData(upsertWorkflowNode(getWorkflowProcessData(), data))
    },
    onNodeFinished: ({ data }) => {
      setWorkflowProcessData(finishWorkflowNode(getWorkflowProcessData(), data))
    },
    onWorkflowFinished: ({ data }) => {
      // A timed-out run has already been torn down locally; only warn.
      if (isTimedOut()) {
        notify({ type: 'warning', message: t('warningMessage.timeoutExceeded', { ns: 'appDebug' }) })
        return
      }

      const workflowStatus = data.status as WorkflowRunningStatus | undefined
      if (workflowStatus === WorkflowRunningStatus.Stopped) {
        setWorkflowProcessData(applyWorkflowFinishedState(getWorkflowProcessData(), WorkflowRunningStatus.Stopped))
        finishWithFailure()
        return
      }

      if (data.error) {
        notify({ type: 'error', message: data.error })
        setWorkflowProcessData(applyWorkflowFinishedState(getWorkflowProcessData(), WorkflowRunningStatus.Failed))
        finishWithFailure()
        return
      }

      // Success path: record outputs/files, publish the serialized result.
      setWorkflowProcessData(applyWorkflowOutputs(getWorkflowProcessData(), data.outputs))
      const serializedOutputs = serializeWorkflowOutputs(data.outputs)
      setCompletionRes(serializedOutputs)
      if (data.outputs) {
        const outputKeys = Object.keys(data.outputs)
        // A single string output replaces any streamed text as the displayed
        // result (avoids showing a JSON-wrapped string).
        const isStringOutput = outputKeys.length === 1 && typeof data.outputs[outputKeys[0]] === 'string'
        if (isStringOutput) {
          setWorkflowProcessData(updateWorkflowProcess(getWorkflowProcessData(), (draft) => {
            draft.resultText = data.outputs[outputKeys[0]]
          }))
        }
      }

      finishWithSuccess()
    },
    onTextChunk: ({ data: { text } }) => {
      setWorkflowProcessData(appendResultText(getWorkflowProcessData(), text))
    },
    onTextReplace: ({ data: { text } }) => {
      setWorkflowProcessData(replaceResultText(getWorkflowProcessData(), text))
    },
    onHumanInputRequired: ({ data }) => {
      setWorkflowProcessData(updateHumanInputRequired(getWorkflowProcessData(), data))
    },
    onHumanInputFormFilled: ({ data }) => {
      setWorkflowProcessData(updateHumanInputFilled(getWorkflowProcessData(), data))
    },
    onHumanInputFormTimeout: ({ data }) => {
      setWorkflowProcessData(updateHumanInputTimeout(getWorkflowProcessData(), data))
    },
    onWorkflowPaused: ({ data }) => {
      tempMessageId = data.workflow_run_id
      // Re-subscribe to the run's event stream with this same handler map so
      // events after the resume keep flowing through these callbacks.
      void sseGet(`/workflow/${data.workflow_run_id}/events`, {}, otherOptions)
      setWorkflowProcessData(applyWorkflowPaused(getWorkflowProcessData()))
    },
  }

  return otherOptions
}
|
||||
|
||||
export {
|
||||
appendParallelNext,
|
||||
appendParallelStart,
|
||||
appendResultText,
|
||||
applyWorkflowFinishedState,
|
||||
applyWorkflowOutputs,
|
||||
applyWorkflowPaused,
|
||||
finishParallelTrace,
|
||||
finishWorkflowNode,
|
||||
markNodesStopped,
|
||||
replaceResultText,
|
||||
updateHumanInputFilled,
|
||||
updateHumanInputRequired,
|
||||
updateHumanInputTimeout,
|
||||
upsertWorkflowNode,
|
||||
}
|
||||
@ -5646,11 +5646,8 @@
|
||||
}
|
||||
},
|
||||
"app/components/share/text-generation/result/index.tsx": {
|
||||
"react-hooks-extra/no-direct-set-state-in-use-effect": {
|
||||
"count": 3
|
||||
},
|
||||
"ts/no-explicit-any": {
|
||||
"count": 3
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/share/text-generation/run-batch/csv-download/index.tsx": {
|
||||
|
||||
256
web/scripts/check-components-diff-coverage-lib.mjs
Normal file
256
web/scripts/check-components-diff-coverage-lib.mjs
Normal file
@ -0,0 +1,256 @@
|
||||
import fs from 'node:fs'
|
||||
import path from 'node:path'
|
||||
|
||||
const DIFF_COVERAGE_IGNORE_LINE_TOKEN = 'diff-coverage-ignore-line:'
|
||||
|
||||
// Parses unified diff text into a Map of file path -> Set of added/changed
// line numbers (new-file side), restricted to tracked component source files.
export function parseChangedLineMap(diff, isTrackedComponentSourceFile) {
  const hunkHeaderPattern = /^@@ -\d+(?:,\d+)? \+(\d+)(?:,(\d+))? @@/
  const lineMap = new Map()
  let currentFile = null

  for (const diffLine of diff.split('\n')) {
    // A "+++ b/<path>" line announces which file subsequent hunks belong to.
    if (diffLine.startsWith('+++ b/')) {
      currentFile = diffLine.slice(6).trim()
      continue
    }

    if (!currentFile)
      continue
    if (!isTrackedComponentSourceFile(currentFile))
      continue

    const header = hunkHeaderPattern.exec(diffLine)
    if (!header)
      continue

    const firstNewLine = Number(header[1])
    // A missing count in "@@ ... +N @@" means a single line.
    const newLineCount = header[2] ? Number(header[2]) : 1
    if (newLineCount === 0)
      continue

    let linesForFile = lineMap.get(currentFile)
    if (!linesForFile) {
      linesForFile = new Set()
      lineMap.set(currentFile, linesForFile)
    }
    for (let lineNumber = firstNewLine; lineNumber < firstNewLine + newLineCount; lineNumber += 1)
      linesForFile.add(lineNumber)
  }

  return lineMap
}
|
||||
|
||||
// Normalizes a coverage/diff file path into the repo-relative 'web/...' form
// using forward slashes, regardless of whether the input is already
// repo-relative, web-relative, or absolute.
export function normalizeToRepoRelative(filePath, {
  appComponentsCoveragePrefix,
  appComponentsPrefix,
  repoRoot,
  sharedTestPrefix,
  webRoot,
}) {
  if (!filePath)
    return ''

  // Already repo-relative (components source or shared test path).
  const alreadyRepoRelative = filePath.startsWith(appComponentsPrefix)
    || filePath.startsWith(sharedTestPrefix)
  if (alreadyRepoRelative)
    return filePath

  // Coverage entries are relative to web/, so just prepend it.
  if (filePath.startsWith(appComponentsCoveragePrefix))
    return `web/${filePath}`

  let absolutePath = filePath
  if (!path.isAbsolute(absolutePath))
    absolutePath = path.resolve(webRoot, absolutePath)

  return path.relative(repoRoot, absolutePath).split(path.sep).join('/')
}
|
||||
|
||||
// Returns a line-number -> hit-count map for an Istanbul coverage entry.
// Uses the precomputed `l` map when present; otherwise derives line hits
// from statementMap/s, keeping the max hit count per line.
export function getLineHits(entry) {
  const explicitLineHits = entry?.l
  if (explicitLineHits && Object.keys(explicitLineHits).length > 0)
    return explicitLineHits

  const derivedHits = {}
  const statements = entry?.statementMap ?? {}
  for (const statementId of Object.keys(statements)) {
    const startLine = statements[statementId]?.start?.line
    if (!startLine)
      continue

    const hitCount = entry?.s?.[statementId] ?? 0
    derivedHits[startLine] = startLine in derivedHits
      ? Math.max(derivedHits[startLine], hitCount)
      : hitCount
  }

  return derivedHits
}
|
||||
|
||||
// Computes statement coverage restricted to the changed lines. With no
// coverage entry every changed line counts as an uncovered statement.
export function getChangedStatementCoverage(entry, changedLines) {
  const sortedChangedLines = [...(changedLines ?? [])].sort((a, b) => a - b)
  if (!entry) {
    return {
      covered: 0,
      total: sortedChangedLines.length,
      uncoveredLines: sortedChangedLines,
    }
  }

  let covered = 0
  let total = 0
  const uncoveredLines = []

  for (const [statementId, statement] of Object.entries(entry.statementMap ?? {})) {
    // Only statements whose range touches a changed line participate.
    if (!rangeIntersectsChangedLines(statement, changedLines))
      continue

    total += 1
    const hitCount = entry.s?.[statementId] ?? 0
    if (hitCount > 0)
      covered += 1
    else
      uncoveredLines.push(statement.start.line)
  }

  uncoveredLines.sort((a, b) => a - b)
  return { covered, total, uncoveredLines }
}
|
||||
|
||||
// Computes branch coverage restricted to the changed lines. Each branch arm
// is counted individually; uncovered arms are reported with their line and
// arm index, sorted by line then arm.
export function getChangedBranchCoverage(entry, changedLines) {
  const result = { covered: 0, total: 0, uncoveredBranches: [] }
  if (!entry)
    return result

  for (const [branchId, branch] of Object.entries(entry.branchMap ?? {})) {
    if (!branchIntersectsChangedLines(branch, changedLines))
      continue

    const rawHits = entry.b?.[branchId]
    const armHits = Array.isArray(rawHits) ? rawHits : []
    const armLocations = getBranchLocations(branch)
    // Locations and hit arrays can disagree in length; count the wider one.
    const armCount = Math.max(armLocations.length, armHits.length)

    for (let armIndex = 0; armIndex < armCount; armIndex += 1) {
      result.total += 1
      if ((armHits[armIndex] ?? 0) > 0) {
        result.covered += 1
        continue
      }

      const armLocation = armLocations[armIndex] ?? branch.loc ?? branch
      result.uncoveredBranches.push({
        armIndex,
        line: getLocationStartLine(armLocation) ?? branch.line ?? 1,
      })
    }
  }

  result.uncoveredBranches.sort((a, b) => a.line - b.line || a.armIndex - b.armIndex)
  return result
}
|
||||
|
||||
// Loads a file and collects its diff-coverage ignore pragmas; a missing file
// yields the empty result (no ignores, all changed lines effective).
export function getIgnoredChangedLinesFromFile(filePath, changedLines) {
  const fileExists = fs.existsSync(filePath)
  if (!fileExists)
    return emptyIgnoreResult(changedLines)

  const sourceCode = fs.readFileSync(filePath, 'utf8')
  return getIgnoredChangedLinesFromSource(sourceCode, changedLines)
}
|
||||
|
||||
// Scans source text for "// diff-coverage-ignore-line: <reason>" pragmas on
// changed lines. Pragmas with a reason mark their line ignored; pragmas with
// no reason are reported as invalid. Returns the changed lines that remain
// effective after ignores, plus the ignore/invalid bookkeeping.
export function getIgnoredChangedLinesFromSource(sourceCode, changedLines) {
  const changedLineSet = new Set(changedLines ?? [])
  const ignoredLines = new Map()
  const invalidPragmas = []

  let lineNumber = 0
  for (const lineText of sourceCode.split('\n')) {
    lineNumber += 1

    // The pragma must live inside a // comment on the line itself.
    const commentStart = lineText.indexOf('//')
    if (commentStart === -1)
      continue

    const tokenStart = lineText.indexOf(DIFF_COVERAGE_IGNORE_LINE_TOKEN, commentStart + 2)
    if (tokenStart === -1)
      continue

    const reason = lineText.slice(tokenStart + DIFF_COVERAGE_IGNORE_LINE_TOKEN.length).trim()
    // Pragmas on unchanged lines are inert.
    if (!changedLineSet.has(lineNumber))
      continue

    if (reason)
      ignoredLines.set(lineNumber, reason)
    else
      invalidPragmas.push({ line: lineNumber, reason: 'missing ignore reason' })
  }

  const effectiveChangedLines = new Set()
  for (const changedLine of changedLineSet) {
    if (!ignoredLines.has(changedLine))
      effectiveChangedLines.add(changedLine)
  }

  return {
    effectiveChangedLines,
    ignoredLines,
    invalidPragmas,
  }
}
|
||||
|
||||
// The ignore result for a file with no readable pragmas: every changed line
// stays effective, nothing ignored, nothing invalid.
function emptyIgnoreResult(changedLines = []) {
  const effectiveChangedLines = new Set(changedLines)
  return {
    effectiveChangedLines,
    ignoredLines: new Map(),
    invalidPragmas: [],
  }
}
|
||||
|
||||
// True when any of the branch's locations (its overall loc, any arm
// location, or its bare line number) falls on a changed line.
function branchIntersectsChangedLines(branch, changedLines) {
  if (!changedLines || changedLines.size === 0)
    return false

  if (rangeIntersectsChangedLines(branch.loc, changedLines))
    return true

  for (const armLocation of getBranchLocations(branch)) {
    if (rangeIntersectsChangedLines(armLocation, changedLines))
      return true
  }

  if (!branch.line)
    return false
  return changedLines.has(branch.line)
}
|
||||
|
||||
// Returns the branch's arm locations with falsy entries removed; a missing
// or malformed locations field yields an empty list.
function getBranchLocations(branch) {
  const locations = branch?.locations
  if (!Array.isArray(locations))
    return []
  return locations.filter(Boolean)
}
|
||||
|
||||
// True when the location's [start, end] line range contains at least one
// changed line. Locations without a resolvable start line never match.
function rangeIntersectsChangedLines(location, changedLines) {
  if (!location)
    return false
  if (!changedLines || changedLines.size === 0)
    return false

  const startLine = getLocationStartLine(location)
  if (!startLine)
    return false

  // A location with no end line is treated as a single-line range.
  const endLine = getLocationEndLine(location) ?? startLine
  if (!endLine)
    return false

  for (const changedLine of changedLines) {
    if (changedLine >= startLine && changedLine <= endLine)
      return true
  }

  return false
}
|
||||
|
||||
// Resolves a location's start line, falling back to its bare `line` field,
// or null when neither is present.
function getLocationStartLine(location) {
  if (location?.start?.line != null)
    return location.start.line
  return location?.line ?? null
}
|
||||
|
||||
// Resolves a location's end line, falling back to its bare `line` field,
// or null when neither is present.
function getLocationEndLine(location) {
  if (location?.end?.line != null)
    return location.end.line
  return location?.line ?? null
}
|
||||
@ -1,6 +1,14 @@
|
||||
import { execFileSync } from 'node:child_process'
|
||||
import fs from 'node:fs'
|
||||
import path from 'node:path'
|
||||
import {
|
||||
getChangedBranchCoverage,
|
||||
getChangedStatementCoverage,
|
||||
getIgnoredChangedLinesFromFile,
|
||||
getLineHits,
|
||||
normalizeToRepoRelative,
|
||||
parseChangedLineMap,
|
||||
} from './check-components-diff-coverage-lib.mjs'
|
||||
import {
|
||||
collectComponentCoverageExcludedFiles,
|
||||
COMPONENT_COVERAGE_EXCLUDE_LABEL,
|
||||
@ -54,7 +62,13 @@ if (changedSourceFiles.length === 0) {
|
||||
|
||||
const coverageEntries = new Map()
|
||||
for (const [file, entry] of Object.entries(coverage)) {
|
||||
const repoRelativePath = normalizeToRepoRelative(entry.path ?? file)
|
||||
const repoRelativePath = normalizeToRepoRelative(entry.path ?? file, {
|
||||
appComponentsCoveragePrefix: APP_COMPONENTS_COVERAGE_PREFIX,
|
||||
appComponentsPrefix: APP_COMPONENTS_PREFIX,
|
||||
repoRoot,
|
||||
sharedTestPrefix: SHARED_TEST_PREFIX,
|
||||
webRoot,
|
||||
})
|
||||
if (!isTrackedComponentSourceFile(repoRelativePath))
|
||||
continue
|
||||
|
||||
@ -74,46 +88,53 @@ for (const [file, entry] of coverageEntries.entries()) {
|
||||
const overallCoverage = sumCoverageStats(fileCoverageRows)
|
||||
const diffChanges = getChangedLineMap(baseSha, headSha)
|
||||
const diffRows = []
|
||||
const ignoredDiffLines = []
|
||||
const invalidIgnorePragmas = []
|
||||
|
||||
for (const [file, changedLines] of diffChanges.entries()) {
|
||||
if (!isTrackedComponentSourceFile(file))
|
||||
continue
|
||||
|
||||
const entry = coverageEntries.get(file)
|
||||
const lineHits = entry ? getLineHits(entry) : {}
|
||||
const executableChangedLines = [...changedLines]
|
||||
.filter(line => !entry || lineHits[line] !== undefined)
|
||||
.sort((a, b) => a - b)
|
||||
|
||||
if (executableChangedLines.length === 0) {
|
||||
diffRows.push({
|
||||
const ignoreInfo = getIgnoredChangedLinesFromFile(path.join(repoRoot, file), changedLines)
|
||||
for (const [line, reason] of ignoreInfo.ignoredLines.entries()) {
|
||||
ignoredDiffLines.push({
|
||||
file,
|
||||
moduleName: getModuleName(file),
|
||||
total: 0,
|
||||
covered: 0,
|
||||
uncoveredLines: [],
|
||||
line,
|
||||
reason,
|
||||
})
|
||||
}
|
||||
for (const invalidPragma of ignoreInfo.invalidPragmas) {
|
||||
invalidIgnorePragmas.push({
|
||||
file,
|
||||
...invalidPragma,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
const uncoveredLines = executableChangedLines.filter(line => (lineHits[line] ?? 0) === 0)
|
||||
const statements = getChangedStatementCoverage(entry, ignoreInfo.effectiveChangedLines)
|
||||
const branches = getChangedBranchCoverage(entry, ignoreInfo.effectiveChangedLines)
|
||||
diffRows.push({
|
||||
branches,
|
||||
file,
|
||||
ignoredLineCount: ignoreInfo.ignoredLines.size,
|
||||
moduleName: getModuleName(file),
|
||||
total: executableChangedLines.length,
|
||||
covered: executableChangedLines.length - uncoveredLines.length,
|
||||
uncoveredLines,
|
||||
statements,
|
||||
})
|
||||
}
|
||||
|
||||
const diffTotals = diffRows.reduce((acc, row) => {
|
||||
acc.total += row.total
|
||||
acc.covered += row.covered
|
||||
acc.statements.total += row.statements.total
|
||||
acc.statements.covered += row.statements.covered
|
||||
acc.branches.total += row.branches.total
|
||||
acc.branches.covered += row.branches.covered
|
||||
return acc
|
||||
}, { total: 0, covered: 0 })
|
||||
}, {
|
||||
branches: { total: 0, covered: 0 },
|
||||
statements: { total: 0, covered: 0 },
|
||||
})
|
||||
|
||||
const diffCoveragePct = percentage(diffTotals.covered, diffTotals.total)
|
||||
const diffFailures = diffRows.filter(row => row.uncoveredLines.length > 0)
|
||||
const diffStatementFailures = diffRows.filter(row => row.statements.uncoveredLines.length > 0)
|
||||
const diffBranchFailures = diffRows.filter(row => row.branches.uncoveredBranches.length > 0)
|
||||
const overallThresholdFailures = getThresholdFailures(overallCoverage, COMPONENTS_GLOBAL_THRESHOLDS)
|
||||
const moduleCoverageRows = [...moduleCoverageMap.entries()]
|
||||
.map(([moduleName, stats]) => ({
|
||||
@ -139,25 +160,38 @@ appendSummary(buildSummary({
|
||||
overallThresholdFailures,
|
||||
moduleCoverageRows,
|
||||
moduleThresholdFailures,
|
||||
diffBranchFailures,
|
||||
diffRows,
|
||||
diffFailures,
|
||||
diffCoveragePct,
|
||||
diffStatementFailures,
|
||||
diffTotals,
|
||||
changedSourceFiles,
|
||||
changedTestFiles,
|
||||
ignoredDiffLines,
|
||||
invalidIgnorePragmas,
|
||||
missingTestTouch,
|
||||
}))
|
||||
|
||||
if (diffFailures.length > 0 && process.env.CI) {
|
||||
for (const failure of diffFailures.slice(0, 20)) {
|
||||
const firstLine = failure.uncoveredLines[0] ?? 1
|
||||
console.log(`::error file=${failure.file},line=${firstLine}::Uncovered changed lines: ${formatLineRanges(failure.uncoveredLines)}`)
|
||||
if (process.env.CI) {
|
||||
for (const failure of diffStatementFailures.slice(0, 20)) {
|
||||
const firstLine = failure.statements.uncoveredLines[0] ?? 1
|
||||
console.log(`::error file=${failure.file},line=${firstLine}::Uncovered changed statements: ${formatLineRanges(failure.statements.uncoveredLines)}`)
|
||||
}
|
||||
for (const failure of diffBranchFailures.slice(0, 20)) {
|
||||
const firstBranch = failure.branches.uncoveredBranches[0]
|
||||
const line = firstBranch?.line ?? 1
|
||||
console.log(`::error file=${failure.file},line=${line}::Uncovered changed branches: ${formatBranchRefs(failure.branches.uncoveredBranches)}`)
|
||||
}
|
||||
for (const invalidPragma of invalidIgnorePragmas.slice(0, 20)) {
|
||||
console.log(`::error file=${invalidPragma.file},line=${invalidPragma.line}::Invalid diff coverage ignore pragma: ${invalidPragma.reason}`)
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
overallThresholdFailures.length > 0
|
||||
|| moduleThresholdFailures.length > 0
|
||||
|| diffFailures.length > 0
|
||||
|| diffStatementFailures.length > 0
|
||||
|| diffBranchFailures.length > 0
|
||||
|| invalidIgnorePragmas.length > 0
|
||||
|| (STRICT_TEST_FILE_TOUCH && missingTestTouch)
|
||||
) {
|
||||
process.exit(1)
|
||||
@ -168,11 +202,14 @@ function buildSummary({
|
||||
overallThresholdFailures,
|
||||
moduleCoverageRows,
|
||||
moduleThresholdFailures,
|
||||
diffBranchFailures,
|
||||
diffRows,
|
||||
diffFailures,
|
||||
diffCoveragePct,
|
||||
diffStatementFailures,
|
||||
diffTotals,
|
||||
changedSourceFiles,
|
||||
changedTestFiles,
|
||||
ignoredDiffLines,
|
||||
invalidIgnorePragmas,
|
||||
missingTestTouch,
|
||||
}) {
|
||||
const lines = [
|
||||
@ -189,7 +226,8 @@ function buildSummary({
|
||||
`| Overall tracked statements | ${formatPercent(overallCoverage.statements)} | ${overallCoverage.statements.covered}/${overallCoverage.statements.total}; threshold ${COMPONENTS_GLOBAL_THRESHOLDS.statements}% |`,
|
||||
`| Overall tracked functions | ${formatPercent(overallCoverage.functions)} | ${overallCoverage.functions.covered}/${overallCoverage.functions.total}; threshold ${COMPONENTS_GLOBAL_THRESHOLDS.functions}% |`,
|
||||
`| Overall tracked branches | ${formatPercent(overallCoverage.branches)} | ${overallCoverage.branches.covered}/${overallCoverage.branches.total}; threshold ${COMPONENTS_GLOBAL_THRESHOLDS.branches}% |`,
|
||||
`| Changed executable lines | ${formatPercent({ covered: diffTotals.covered, total: diffTotals.total })} | ${diffTotals.covered}/${diffTotals.total} |`,
|
||||
`| Changed statements | ${formatDiffPercent(diffTotals.statements)} | ${diffTotals.statements.covered}/${diffTotals.statements.total} |`,
|
||||
`| Changed branches | ${formatDiffPercent(diffTotals.branches)} | ${diffTotals.branches.covered}/${diffTotals.branches.total} |`,
|
||||
'',
|
||||
]
|
||||
|
||||
@ -239,20 +277,19 @@ function buildSummary({
|
||||
lines.push('')
|
||||
|
||||
const changedRows = diffRows
|
||||
.filter(row => row.total > 0)
|
||||
.filter(row => row.statements.total > 0 || row.branches.total > 0)
|
||||
.sort((a, b) => {
|
||||
const aPct = percentage(rowCovered(a), rowTotal(a))
|
||||
const bPct = percentage(rowCovered(b), rowTotal(b))
|
||||
return aPct - bPct || a.file.localeCompare(b.file)
|
||||
const aScore = percentage(a.statements.covered + a.branches.covered, a.statements.total + a.branches.total)
|
||||
const bScore = percentage(b.statements.covered + b.branches.covered, b.statements.total + b.branches.total)
|
||||
return aScore - bScore || a.file.localeCompare(b.file)
|
||||
})
|
||||
|
||||
lines.push('<details><summary>Changed file coverage</summary>')
|
||||
lines.push('')
|
||||
lines.push('| File | Module | Changed executable lines | Coverage | Uncovered lines |')
|
||||
lines.push('|---|---|---:|---:|---|')
|
||||
lines.push('| File | Module | Changed statements | Statement coverage | Uncovered statements | Changed branches | Branch coverage | Uncovered branches | Ignored lines |')
|
||||
lines.push('|---|---|---:|---:|---|---:|---:|---|---:|')
|
||||
for (const row of changedRows) {
|
||||
const rowPct = percentage(row.covered, row.total)
|
||||
lines.push(`| ${row.file.replace('web/', '')} | ${row.moduleName} | ${row.total} | ${rowPct.toFixed(2)}% | ${formatLineRanges(row.uncoveredLines)} |`)
|
||||
lines.push(`| ${row.file.replace('web/', '')} | ${row.moduleName} | ${row.statements.total} | ${formatDiffPercent(row.statements)} | ${formatLineRanges(row.statements.uncoveredLines)} | ${row.branches.total} | ${formatDiffPercent(row.branches)} | ${formatBranchRefs(row.branches.uncoveredBranches)} | ${row.ignoredLineCount} |`)
|
||||
}
|
||||
lines.push('</details>')
|
||||
lines.push('')
|
||||
@ -268,16 +305,41 @@ function buildSummary({
|
||||
lines.push('')
|
||||
}
|
||||
|
||||
if (diffFailures.length > 0) {
|
||||
lines.push('Uncovered changed lines:')
|
||||
for (const row of diffFailures) {
|
||||
lines.push(`- ${row.file.replace('web/', '')}: ${formatLineRanges(row.uncoveredLines)}`)
|
||||
if (diffStatementFailures.length > 0) {
|
||||
lines.push('Uncovered changed statements:')
|
||||
for (const row of diffStatementFailures) {
|
||||
lines.push(`- ${row.file.replace('web/', '')}: ${formatLineRanges(row.statements.uncoveredLines)}`)
|
||||
}
|
||||
lines.push('')
|
||||
}
|
||||
|
||||
if (diffBranchFailures.length > 0) {
|
||||
lines.push('Uncovered changed branches:')
|
||||
for (const row of diffBranchFailures) {
|
||||
lines.push(`- ${row.file.replace('web/', '')}: ${formatBranchRefs(row.branches.uncoveredBranches)}`)
|
||||
}
|
||||
lines.push('')
|
||||
}
|
||||
|
||||
if (ignoredDiffLines.length > 0) {
|
||||
lines.push('Ignored changed lines via pragma:')
|
||||
for (const ignoredLine of ignoredDiffLines) {
|
||||
lines.push(`- ${ignoredLine.file.replace('web/', '')}:${ignoredLine.line} - ${ignoredLine.reason}`)
|
||||
}
|
||||
lines.push('')
|
||||
}
|
||||
|
||||
if (invalidIgnorePragmas.length > 0) {
|
||||
lines.push('Invalid diff coverage ignore pragmas:')
|
||||
for (const invalidPragma of invalidIgnorePragmas) {
|
||||
lines.push(`- ${invalidPragma.file.replace('web/', '')}:${invalidPragma.line} - ${invalidPragma.reason}`)
|
||||
}
|
||||
lines.push('')
|
||||
}
|
||||
|
||||
lines.push(`Changed source files checked: ${changedSourceFiles.length}`)
|
||||
lines.push(`Changed executable line coverage: ${diffCoveragePct.toFixed(2)}%`)
|
||||
lines.push(`Changed statement coverage: ${percentage(diffTotals.statements.covered, diffTotals.statements.total).toFixed(2)}%`)
|
||||
lines.push(`Changed branch coverage: ${percentage(diffTotals.branches.covered, diffTotals.branches.total).toFixed(2)}%`)
|
||||
|
||||
return lines
|
||||
}
|
||||
@ -312,34 +374,7 @@ function getChangedFiles(base, head) {
|
||||
|
||||
function getChangedLineMap(base, head) {
|
||||
const diff = execGit(['diff', '--unified=0', '--no-color', '--diff-filter=ACMR', `${base}...${head}`, '--', 'web/app/components'])
|
||||
const lineMap = new Map()
|
||||
let currentFile = null
|
||||
|
||||
for (const line of diff.split('\n')) {
|
||||
if (line.startsWith('+++ b/')) {
|
||||
currentFile = line.slice(6).trim()
|
||||
continue
|
||||
}
|
||||
|
||||
if (!currentFile || !isTrackedComponentSourceFile(currentFile))
|
||||
continue
|
||||
|
||||
const match = line.match(/^@@ -\d+(?:,\d+)? \+(\d+)(?:,(\d+))? @@/)
|
||||
if (!match)
|
||||
continue
|
||||
|
||||
const start = Number(match[1])
|
||||
const count = match[2] ? Number(match[2]) : 1
|
||||
if (count === 0)
|
||||
continue
|
||||
|
||||
const linesForFile = lineMap.get(currentFile) ?? new Set()
|
||||
for (let offset = 0; offset < count; offset += 1)
|
||||
linesForFile.add(start + offset)
|
||||
lineMap.set(currentFile, linesForFile)
|
||||
}
|
||||
|
||||
return lineMap
|
||||
return parseChangedLineMap(diff, isTrackedComponentSourceFile)
|
||||
}
|
||||
|
||||
function isAnyComponentSourceFile(filePath) {
|
||||
@ -407,24 +442,6 @@ function getCoverageStats(entry) {
|
||||
}
|
||||
}
|
||||
|
||||
function getLineHits(entry) {
|
||||
if (entry.l && Object.keys(entry.l).length > 0)
|
||||
return entry.l
|
||||
|
||||
const lineHits = {}
|
||||
for (const [statementId, statement] of Object.entries(entry.statementMap ?? {})) {
|
||||
const line = statement?.start?.line
|
||||
if (!line)
|
||||
continue
|
||||
|
||||
const hits = entry.s?.[statementId] ?? 0
|
||||
const previous = lineHits[line]
|
||||
lineHits[line] = previous === undefined ? hits : Math.max(previous, hits)
|
||||
}
|
||||
|
||||
return lineHits
|
||||
}
|
||||
|
||||
function sumCoverageStats(rows) {
|
||||
const total = createEmptyCoverageStats()
|
||||
for (const row of rows)
|
||||
@ -479,23 +496,6 @@ function getModuleName(filePath) {
|
||||
return segments.length === 1 ? '(root)' : segments[0]
|
||||
}
|
||||
|
||||
function normalizeToRepoRelative(filePath) {
|
||||
if (!filePath)
|
||||
return ''
|
||||
|
||||
if (filePath.startsWith(APP_COMPONENTS_PREFIX) || filePath.startsWith(SHARED_TEST_PREFIX))
|
||||
return filePath
|
||||
|
||||
if (filePath.startsWith(APP_COMPONENTS_COVERAGE_PREFIX))
|
||||
return `web/${filePath}`
|
||||
|
||||
const absolutePath = path.isAbsolute(filePath)
|
||||
? filePath
|
||||
: path.resolve(webRoot, filePath)
|
||||
|
||||
return path.relative(repoRoot, absolutePath).split(path.sep).join('/')
|
||||
}
|
||||
|
||||
function formatLineRanges(lines) {
|
||||
if (!lines || lines.length === 0)
|
||||
return ''
|
||||
@ -520,6 +520,13 @@ function formatLineRanges(lines) {
|
||||
return ranges.join(', ')
|
||||
}
|
||||
|
||||
function formatBranchRefs(branches) {
|
||||
if (!branches || branches.length === 0)
|
||||
return ''
|
||||
|
||||
return branches.map(branch => `${branch.line}[${branch.armIndex}]`).join(', ')
|
||||
}
|
||||
|
||||
function percentage(covered, total) {
|
||||
if (total === 0)
|
||||
return 100
|
||||
@ -530,6 +537,13 @@ function formatPercent(metric) {
|
||||
return `${percentage(metric.covered, metric.total).toFixed(2)}%`
|
||||
}
|
||||
|
||||
function formatDiffPercent(metric) {
|
||||
if (metric.total === 0)
|
||||
return 'n/a'
|
||||
|
||||
return `${percentage(metric.covered, metric.total).toFixed(2)}%`
|
||||
}
|
||||
|
||||
function appendSummary(lines) {
|
||||
const content = `${lines.join('\n')}\n`
|
||||
if (process.env.GITHUB_STEP_SUMMARY)
|
||||
@ -550,11 +564,3 @@ function repoRootFromCwd() {
|
||||
encoding: 'utf8',
|
||||
}).trim()
|
||||
}
|
||||
|
||||
function rowCovered(row) {
|
||||
return row.covered
|
||||
}
|
||||
|
||||
function rowTotal(row) {
|
||||
return row.total
|
||||
}
|
||||
|
||||
@ -92,10 +92,10 @@ export const COMPONENT_MODULE_THRESHOLDS = {
|
||||
branches: 90,
|
||||
},
|
||||
'share': {
|
||||
lines: 15,
|
||||
statements: 15,
|
||||
functions: 20,
|
||||
branches: 20,
|
||||
lines: 95,
|
||||
statements: 95,
|
||||
functions: 95,
|
||||
branches: 95,
|
||||
},
|
||||
'signin': {
|
||||
lines: 95,
|
||||
|
||||
Reference in New Issue
Block a user