Merge branch 'fix/auto-activate-credential-on-create' into deploy/dev

This commit is contained in:
Yansong Zhang
2026-03-16 15:37:40 +08:00
38 changed files with 4301 additions and 837 deletions

View File

@ -23,7 +23,7 @@ from dify_graph.variables.types import SegmentType
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import login_required
from libs.login import current_user, login_required
from models import App, AppMode
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
@ -100,6 +100,18 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
}
def _ensure_variable_access(
variable: WorkflowDraftVariable | None,
app_id: str,
variable_id: str,
) -> WorkflowDraftVariable:
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_id or variable.user_id != current_user.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
return variable
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
"id": fields.String,
"type": fields.String(attribute=lambda model: model.get_variable_type()),
@ -238,6 +250,7 @@ class WorkflowVariableCollectionApi(Resource):
app_id=app_model.id,
page=args.page,
limit=args.limit,
user_id=current_user.id,
)
return workflow_vars
@ -250,7 +263,7 @@ class WorkflowVariableCollectionApi(Resource):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
draft_var_srv.delete_workflow_variables(app_model.id)
draft_var_srv.delete_user_workflow_variables(app_model.id, user_id=current_user.id)
db.session.commit()
return Response("", 204)
@ -287,7 +300,7 @@ class NodeVariableCollectionApi(Resource):
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
node_vars = draft_var_srv.list_node_variables(app_model.id, node_id)
node_vars = draft_var_srv.list_node_variables(app_model.id, node_id, user_id=current_user.id)
return node_vars
@ -298,7 +311,7 @@ class NodeVariableCollectionApi(Resource):
def delete(self, app_model: App, node_id: str):
validate_node_id(node_id)
srv = WorkflowDraftVariableService(db.session())
srv.delete_node_variables(app_model.id, node_id)
srv.delete_node_variables(app_model.id, node_id, user_id=current_user.id)
db.session.commit()
return Response("", 204)
@ -319,11 +332,11 @@ class VariableApi(Resource):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=app_model.id,
variable_id=variable_id,
)
return variable
@console_ns.doc("update_variable")
@ -360,11 +373,11 @@ class VariableApi(Resource):
)
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=app_model.id,
variable_id=variable_id,
)
new_name = args_model.name
raw_value = args_model.value
@ -397,11 +410,11 @@ class VariableApi(Resource):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=app_model.id,
variable_id=variable_id,
)
draft_var_srv.delete_variable(variable)
db.session.commit()
return Response("", 204)
@ -427,11 +440,11 @@ class VariableResetApi(Resource):
raise NotFoundError(
f"Draft workflow not found, app_id={app_model.id}",
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=app_model.id,
variable_id=variable_id,
)
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
db.session.commit()
@ -447,11 +460,15 @@ def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
session=session,
)
if node_id == CONVERSATION_VARIABLE_NODE_ID:
draft_vars = draft_var_srv.list_conversation_variables(app_model.id)
draft_vars = draft_var_srv.list_conversation_variables(app_model.id, user_id=current_user.id)
elif node_id == SYSTEM_VARIABLE_NODE_ID:
draft_vars = draft_var_srv.list_system_variables(app_model.id)
draft_vars = draft_var_srv.list_system_variables(app_model.id, user_id=current_user.id)
else:
draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id)
draft_vars = draft_var_srv.list_node_variables(
app_id=app_model.id,
node_id=node_id,
user_id=current_user.id,
)
return draft_vars
@ -472,7 +489,7 @@ class ConversationVariableCollectionApi(Resource):
if draft_workflow is None:
raise NotFoundError(description=f"draft workflow not found, id={app_model.id}")
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow, user_id=current_user.id)
db.session.commit()
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)

View File

@ -102,6 +102,7 @@ class RagPipelineVariableCollectionApi(Resource):
app_id=pipeline.id,
page=query.page,
limit=query.limit,
user_id=current_user.id,
)
return workflow_vars
@ -111,7 +112,7 @@ class RagPipelineVariableCollectionApi(Resource):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
draft_var_srv.delete_workflow_variables(pipeline.id)
draft_var_srv.delete_user_workflow_variables(pipeline.id, user_id=current_user.id)
db.session.commit()
return Response("", 204)
@ -144,7 +145,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
node_vars = draft_var_srv.list_node_variables(pipeline.id, node_id)
node_vars = draft_var_srv.list_node_variables(pipeline.id, node_id, user_id=current_user.id)
return node_vars
@ -152,7 +153,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
def delete(self, pipeline: Pipeline, node_id: str):
validate_node_id(node_id)
srv = WorkflowDraftVariableService(db.session())
srv.delete_node_variables(pipeline.id, node_id)
srv.delete_node_variables(pipeline.id, node_id, user_id=current_user.id)
db.session.commit()
return Response("", 204)
@ -283,11 +284,11 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList
session=session,
)
if node_id == CONVERSATION_VARIABLE_NODE_ID:
draft_vars = draft_var_srv.list_conversation_variables(pipeline.id)
draft_vars = draft_var_srv.list_conversation_variables(pipeline.id, user_id=current_user.id)
elif node_id == SYSTEM_VARIABLE_NODE_ID:
draft_vars = draft_var_srv.list_system_variables(pipeline.id)
draft_vars = draft_var_srv.list_system_variables(pipeline.id, user_id=current_user.id)
else:
draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id)
draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id, user_id=current_user.id)
return draft_vars

View File

@ -330,9 +330,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
user_id=user.id,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
return self._generate(
workflow=workflow,
@ -413,9 +414,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
user_id=user.id,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
return self._generate(
workflow=workflow,

View File

@ -419,11 +419,12 @@ class PipelineGenerator(BaseAppGenerator):
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
user_id=user.id,
)
return self._generate(
@ -514,11 +515,12 @@ class PipelineGenerator(BaseAppGenerator):
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
user_id=user.id,
)
return self._generate(

View File

@ -417,11 +417,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
user_id=user.id,
)
return self._generate(
@ -500,11 +501,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
user_id=user.id,
)
return self._generate(
app_model=app_model,

View File

@ -0,0 +1,69 @@
"""add user_id and switch workflow_draft_variables unique key to user scope
Revision ID: 6b5f9f8b1a2c
Revises: 0ec65df55790
Create Date: 2026-03-04 16:00:00.000000
"""
import sqlalchemy as sa
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = "6b5f9f8b1a2c"
down_revision = "0ec65df55790"
branch_labels = None
depends_on = None
def _is_pg(conn) -> bool:
return conn.dialect.name == "postgresql"
def upgrade():
conn = op.get_bind()
table_name = "workflow_draft_variables"
with op.batch_alter_table(table_name, schema=None) as batch_op:
batch_op.add_column(sa.Column("user_id", models.types.StringUUID(), nullable=True))
if _is_pg(conn):
with op.get_context().autocommit_block():
op.create_index(
"workflow_draft_variables_app_id_user_id_key",
"workflow_draft_variables",
["app_id", "user_id", "node_id", "name"],
unique=True,
postgresql_concurrently=True,
)
else:
op.create_index(
"workflow_draft_variables_app_id_user_id_key",
"workflow_draft_variables",
["app_id", "user_id", "node_id", "name"],
unique=True,
)
with op.batch_alter_table(table_name, schema=None) as batch_op:
batch_op.drop_constraint(op.f("workflow_draft_variables_app_id_key"), type_="unique")
def downgrade():
conn = op.get_bind()
with op.batch_alter_table("workflow_draft_variables", schema=None) as batch_op:
batch_op.create_unique_constraint(
op.f("workflow_draft_variables_app_id_key"),
["app_id", "node_id", "name"],
)
if _is_pg(conn):
with op.get_context().autocommit_block():
op.drop_index("workflow_draft_variables_app_id_user_id_key", postgresql_concurrently=True)
else:
op.drop_index("workflow_draft_variables_app_id_user_id_key", table_name="workflow_draft_variables")
with op.batch_alter_table("workflow_draft_variables", schema=None) as batch_op:
batch_op.drop_column("user_id")

View File

@ -1286,16 +1286,17 @@ class WorkflowDraftVariable(Base):
"""
@staticmethod
def unique_app_id_node_id_name() -> list[str]:
def unique_app_id_user_id_node_id_name() -> list[str]:
return [
"app_id",
"user_id",
"node_id",
"name",
]
__tablename__ = "workflow_draft_variables"
__table_args__ = (
UniqueConstraint(*unique_app_id_node_id_name()),
UniqueConstraint(*unique_app_id_user_id_node_id_name()),
Index("workflow_draft_variable_file_id_idx", "file_id"),
)
# Required for instance variable annotation.
@ -1321,6 +1322,11 @@ class WorkflowDraftVariable(Base):
# "`app_id` maps to the `id` field in the `model.App` model."
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# Owner of this draft variable.
#
# This field is nullable during migration and will be migrated to NOT NULL
# in a follow-up release.
user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
# `last_edited_at` records when the value of a given draft variable
# is edited.
@ -1573,6 +1579,7 @@ class WorkflowDraftVariable(Base):
cls,
*,
app_id: str,
user_id: str | None,
node_id: str,
name: str,
value: Segment,
@ -1586,6 +1593,7 @@ class WorkflowDraftVariable(Base):
variable.updated_at = naive_utc_now()
variable.description = description
variable.app_id = app_id
variable.user_id = user_id
variable.node_id = node_id
variable.name = name
variable.set_value(value)
@ -1599,12 +1607,14 @@ class WorkflowDraftVariable(Base):
cls,
*,
app_id: str,
user_id: str | None = None,
name: str,
value: Segment,
description: str = "",
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
user_id=user_id,
node_id=CONVERSATION_VARIABLE_NODE_ID,
name=name,
value=value,
@ -1619,6 +1629,7 @@ class WorkflowDraftVariable(Base):
cls,
*,
app_id: str,
user_id: str | None = None,
name: str,
value: Segment,
node_execution_id: str,
@ -1626,6 +1637,7 @@ class WorkflowDraftVariable(Base):
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
user_id=user_id,
node_id=SYSTEM_VARIABLE_NODE_ID,
name=name,
node_execution_id=node_execution_id,
@ -1639,6 +1651,7 @@ class WorkflowDraftVariable(Base):
cls,
*,
app_id: str,
user_id: str | None = None,
node_id: str,
name: str,
value: Segment,
@ -1649,6 +1662,7 @@ class WorkflowDraftVariable(Base):
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
user_id=user_id,
node_id=node_id,
name=name,
node_execution_id=node_execution_id,

View File

@ -304,7 +304,7 @@ class AppDslService:
)
draft_var_srv = WorkflowDraftVariableService(session=self._session)
draft_var_srv.delete_workflow_variables(app_id=app.id)
draft_var_srv.delete_app_workflow_variables(app_id=app.id)
return Import(
id=import_id,
status=status,

View File

@ -30,7 +30,7 @@ from core.plugin.impl.debugging import PluginDebuggingClient
from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.provider import Provider, ProviderCredential
from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider
from models.provider_ids import GenericProviderID
from services.enterprise.plugin_manager_service import (
PluginManagerService,
@ -569,6 +569,13 @@ class PluginService:
)
)
session.execute(
delete(TenantPreferredModelProvider).where(
TenantPreferredModelProvider.tenant_id == tenant_id,
TenantPreferredModelProvider.provider_name.like(f"{plugin_id}/%"),
)
)
logger.info(
"Completed deleting credentials and cleaning provider associations for plugin: %s",
plugin_id,

View File

@ -472,6 +472,7 @@ class RagPipelineService:
engine=db.engine,
app_id=pipeline.id,
tenant_id=pipeline.tenant_id,
user_id=account.id,
),
),
start_at=start_at,
@ -1237,6 +1238,7 @@ class RagPipelineService:
engine=db.engine,
app_id=pipeline.id,
tenant_id=pipeline.tenant_id,
user_id=current_user.id,
),
),
start_at=start_at,

View File

@ -77,6 +77,7 @@ class DraftVarLoader(VariableLoader):
_engine: Engine
# Application ID for which variables are being loaded.
_app_id: str
_user_id: str
_tenant_id: str
_fallback_variables: Sequence[VariableBase]
@ -85,10 +86,12 @@ class DraftVarLoader(VariableLoader):
engine: Engine,
app_id: str,
tenant_id: str,
user_id: str,
fallback_variables: Sequence[VariableBase] | None = None,
):
self._engine = engine
self._app_id = app_id
self._user_id = user_id
self._tenant_id = tenant_id
self._fallback_variables = fallback_variables or []
@ -104,7 +107,7 @@ class DraftVarLoader(VariableLoader):
with Session(bind=self._engine, expire_on_commit=False) as session:
srv = WorkflowDraftVariableService(session)
draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors)
draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors, user_id=self._user_id)
# Important:
files: list[File] = []
@ -218,6 +221,7 @@ class WorkflowDraftVariableService:
self,
app_id: str,
selectors: Sequence[list[str]],
user_id: str,
) -> list[WorkflowDraftVariable]:
"""
Retrieve WorkflowDraftVariable instances based on app_id and selectors.
@ -238,22 +242,30 @@ class WorkflowDraftVariableService:
# Alternatively, a `SELECT` statement could be constructed for each selector and
# combined using `UNION` to fetch all rows.
# Benchmarking indicates that both approaches yield comparable performance.
variables = (
query = (
self._session.query(WorkflowDraftVariable)
.options(
orm.selectinload(WorkflowDraftVariable.variable_file).selectinload(
WorkflowDraftVariableFile.upload_file
)
)
.where(WorkflowDraftVariable.app_id == app_id, or_(*ors))
.all()
.where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.user_id == user_id,
or_(*ors),
)
)
return variables
return query.all()
def list_variables_without_values(self, app_id: str, page: int, limit: int) -> WorkflowDraftVariableList:
criteria = WorkflowDraftVariable.app_id == app_id
def list_variables_without_values(
self, app_id: str, page: int, limit: int, user_id: str
) -> WorkflowDraftVariableList:
criteria = [
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.user_id == user_id,
]
total = None
query = self._session.query(WorkflowDraftVariable).where(criteria)
query = self._session.query(WorkflowDraftVariable).where(*criteria)
if page == 1:
total = query.count()
variables = (
@ -269,11 +281,12 @@ class WorkflowDraftVariableService:
return WorkflowDraftVariableList(variables=variables, total=total)
def _list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList:
criteria = (
def _list_node_variables(self, app_id: str, node_id: str, user_id: str) -> WorkflowDraftVariableList:
criteria = [
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.node_id == node_id,
)
WorkflowDraftVariable.user_id == user_id,
]
query = self._session.query(WorkflowDraftVariable).where(*criteria)
variables = (
query.options(orm.selectinload(WorkflowDraftVariable.variable_file))
@ -282,36 +295,36 @@ class WorkflowDraftVariableService:
)
return WorkflowDraftVariableList(variables=variables)
def list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList:
return self._list_node_variables(app_id, node_id)
def list_node_variables(self, app_id: str, node_id: str, user_id: str) -> WorkflowDraftVariableList:
return self._list_node_variables(app_id, node_id, user_id=user_id)
def list_conversation_variables(self, app_id: str) -> WorkflowDraftVariableList:
return self._list_node_variables(app_id, CONVERSATION_VARIABLE_NODE_ID)
def list_conversation_variables(self, app_id: str, user_id: str) -> WorkflowDraftVariableList:
return self._list_node_variables(app_id, CONVERSATION_VARIABLE_NODE_ID, user_id=user_id)
def list_system_variables(self, app_id: str) -> WorkflowDraftVariableList:
return self._list_node_variables(app_id, SYSTEM_VARIABLE_NODE_ID)
def list_system_variables(self, app_id: str, user_id: str) -> WorkflowDraftVariableList:
return self._list_node_variables(app_id, SYSTEM_VARIABLE_NODE_ID, user_id=user_id)
def get_conversation_variable(self, app_id: str, name: str) -> WorkflowDraftVariable | None:
return self._get_variable(app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, name=name)
def get_conversation_variable(self, app_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None:
return self._get_variable(app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, name=name, user_id=user_id)
def get_system_variable(self, app_id: str, name: str) -> WorkflowDraftVariable | None:
return self._get_variable(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name)
def get_system_variable(self, app_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None:
return self._get_variable(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name, user_id=user_id)
def get_node_variable(self, app_id: str, node_id: str, name: str) -> WorkflowDraftVariable | None:
return self._get_variable(app_id, node_id, name)
def get_node_variable(self, app_id: str, node_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None:
return self._get_variable(app_id, node_id, name, user_id=user_id)
def _get_variable(self, app_id: str, node_id: str, name: str) -> WorkflowDraftVariable | None:
variable = (
def _get_variable(self, app_id: str, node_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None:
return (
self._session.query(WorkflowDraftVariable)
.options(orm.selectinload(WorkflowDraftVariable.variable_file))
.where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.node_id == node_id,
WorkflowDraftVariable.name == name,
WorkflowDraftVariable.user_id == user_id,
)
.first()
)
return variable
def update_variable(
self,
@ -462,7 +475,17 @@ class WorkflowDraftVariableService:
self._session.delete(upload_file)
self._session.delete(variable)
def delete_workflow_variables(self, app_id: str):
def delete_user_workflow_variables(self, app_id: str, user_id: str):
(
self._session.query(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.user_id == user_id,
)
.delete(synchronize_session=False)
)
def delete_app_workflow_variables(self, app_id: str):
(
self._session.query(WorkflowDraftVariable)
.where(WorkflowDraftVariable.app_id == app_id)
@ -501,28 +524,35 @@ class WorkflowDraftVariableService:
self._session.delete(upload_file)
self._session.delete(variable_file)
def delete_node_variables(self, app_id: str, node_id: str):
return self._delete_node_variables(app_id, node_id)
def delete_node_variables(self, app_id: str, node_id: str, user_id: str):
return self._delete_node_variables(app_id, node_id, user_id=user_id)
def _delete_node_variables(self, app_id: str, node_id: str):
self._session.query(WorkflowDraftVariable).where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.node_id == node_id,
).delete()
def _delete_node_variables(self, app_id: str, node_id: str, user_id: str):
(
self._session.query(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.node_id == node_id,
WorkflowDraftVariable.user_id == user_id,
)
.delete(synchronize_session=False)
)
def _get_conversation_id_from_draft_variable(self, app_id: str) -> str | None:
def _get_conversation_id_from_draft_variable(self, app_id: str, user_id: str) -> str | None:
draft_var = self._get_variable(
app_id=app_id,
node_id=SYSTEM_VARIABLE_NODE_ID,
name=str(SystemVariableKey.CONVERSATION_ID),
user_id=user_id,
)
if draft_var is None:
return None
segment = draft_var.get_value()
if not isinstance(segment, StringSegment):
logger.warning(
"sys.conversation_id variable is not a string: app_id=%s, id=%s",
"sys.conversation_id variable is not a string: app_id=%s, user_id=%s, id=%s",
app_id,
user_id,
draft_var.id,
)
return None
@ -543,7 +573,7 @@ class WorkflowDraftVariableService:
If no such conversation exists, a new conversation is created and its ID is returned.
"""
conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id)
conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id, account_id)
if conv_id is not None:
conversation = (
@ -580,12 +610,13 @@ class WorkflowDraftVariableService:
self._session.flush()
return conversation.id
def prefill_conversation_variable_default_values(self, workflow: Workflow):
def prefill_conversation_variable_default_values(self, workflow: Workflow, user_id: str):
""""""
draft_conv_vars: list[WorkflowDraftVariable] = []
for conv_var in workflow.conversation_variables:
draft_var = WorkflowDraftVariable.new_conversation_variable(
app_id=workflow.app_id,
user_id=user_id,
name=conv_var.name,
value=conv_var,
description=conv_var.description,
@ -635,7 +666,7 @@ def _batch_upsert_draft_variable(
stmt = pg_insert(WorkflowDraftVariable).values([_model_to_insertion_dict(v) for v in draft_vars])
if policy == _UpsertPolicy.OVERWRITE:
stmt = stmt.on_conflict_do_update(
index_elements=WorkflowDraftVariable.unique_app_id_node_id_name(),
index_elements=WorkflowDraftVariable.unique_app_id_user_id_node_id_name(),
set_={
# Refresh creation timestamp to ensure updated variables
# appear first in chronologically sorted result sets.
@ -652,7 +683,9 @@ def _batch_upsert_draft_variable(
},
)
elif policy == _UpsertPolicy.IGNORE:
stmt = stmt.on_conflict_do_nothing(index_elements=WorkflowDraftVariable.unique_app_id_node_id_name())
stmt = stmt.on_conflict_do_nothing(
index_elements=WorkflowDraftVariable.unique_app_id_user_id_node_id_name()
)
else:
stmt = mysql_insert(WorkflowDraftVariable).values([_model_to_insertion_dict(v) for v in draft_vars]) # type: ignore[assignment]
if policy == _UpsertPolicy.OVERWRITE:
@ -682,6 +715,7 @@ def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]:
d: dict[str, Any] = {
"id": model.id,
"app_id": model.app_id,
"user_id": model.user_id,
"last_edited_at": None,
"node_id": model.node_id,
"name": model.name,
@ -807,6 +841,7 @@ class DraftVariableSaver:
def _create_dummy_output_variable(self):
return WorkflowDraftVariable.new_node_variable(
app_id=self._app_id,
user_id=self._user.id,
node_id=self._node_id,
name=self._DUMMY_OUTPUT_IDENTITY,
node_execution_id=self._node_execution_id,
@ -842,6 +877,7 @@ class DraftVariableSaver:
draft_vars.append(
WorkflowDraftVariable.new_conversation_variable(
app_id=self._app_id,
user_id=self._user.id,
name=item.name,
value=segment,
)
@ -862,6 +898,7 @@ class DraftVariableSaver:
draft_vars.append(
WorkflowDraftVariable.new_node_variable(
app_id=self._app_id,
user_id=self._user.id,
node_id=self._node_id,
name=name,
node_execution_id=self._node_execution_id,
@ -884,6 +921,7 @@ class DraftVariableSaver:
draft_vars.append(
WorkflowDraftVariable.new_sys_variable(
app_id=self._app_id,
user_id=self._user.id,
name=name,
node_execution_id=self._node_execution_id,
value=value_seg,
@ -1019,6 +1057,7 @@ class DraftVariableSaver:
# Create the draft variable
draft_var = WorkflowDraftVariable.new_node_variable(
app_id=self._app_id,
user_id=self._user.id,
node_id=self._node_id,
name=name,
node_execution_id=self._node_execution_id,
@ -1032,6 +1071,7 @@ class DraftVariableSaver:
# Create the draft variable
draft_var = WorkflowDraftVariable.new_node_variable(
app_id=self._app_id,
user_id=self._user.id,
node_id=self._node_id,
name=name,
node_execution_id=self._node_execution_id,

View File

@ -697,7 +697,7 @@ class WorkflowService:
with Session(bind=db.engine, expire_on_commit=False) as session, session.begin():
draft_var_srv = WorkflowDraftVariableService(session)
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow, user_id=account.id)
node_config = draft_workflow.get_node_config_by_id(node_id)
node_type = Workflow.get_node_type_from_node_config(node_config)
@ -740,6 +740,7 @@ class WorkflowService:
engine=db.engine,
app_id=app_model.id,
tenant_id=app_model.tenant_id,
user_id=account.id,
)
enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
@ -831,6 +832,7 @@ class WorkflowService:
workflow=draft_workflow,
node_config=node_config,
manual_inputs=inputs or {},
user_id=account.id,
)
node = self._build_human_input_node(
workflow=draft_workflow,
@ -891,6 +893,7 @@ class WorkflowService:
workflow=draft_workflow,
node_config=node_config,
manual_inputs=inputs or {},
user_id=account.id,
)
node = self._build_human_input_node(
workflow=draft_workflow,
@ -967,6 +970,7 @@ class WorkflowService:
workflow=draft_workflow,
node_config=node_config,
manual_inputs=inputs or {},
user_id=account.id,
)
node = self._build_human_input_node(
workflow=draft_workflow,
@ -1102,10 +1106,11 @@ class WorkflowService:
workflow: Workflow,
node_config: NodeConfigDict,
manual_inputs: Mapping[str, Any],
user_id: str,
) -> VariablePool:
with Session(bind=db.engine, expire_on_commit=False) as session, session.begin():
draft_var_srv = WorkflowDraftVariableService(session)
draft_var_srv.prefill_conversation_variable_default_values(workflow)
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user_id)
variable_pool = VariablePool(
system_variables=SystemVariable.default(),
@ -1118,6 +1123,7 @@ class WorkflowService:
engine=db.engine,
app_id=app_model.id,
tenant_id=app_model.tenant_id,
user_id=user_id,
)
variable_mapping = HumanInputNode.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict,

View File

@ -30,6 +30,7 @@ from services.workflow_draft_variable_service import (
class TestWorkflowDraftVariableService(unittest.TestCase):
_test_app_id: str
_session: Session
_test_user_id: str
_node1_id = "test_node_1"
_node2_id = "test_node_2"
_node_exec_id = str(uuid.uuid4())
@ -99,13 +100,13 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
def test_list_variables(self):
srv = self._get_test_srv()
var_list = srv.list_variables_without_values(self._test_app_id, page=1, limit=2)
var_list = srv.list_variables_without_values(self._test_app_id, page=1, limit=2, user_id=self._test_user_id)
assert var_list.total == 5
assert len(var_list.variables) == 2
page1_var_ids = {v.id for v in var_list.variables}
assert page1_var_ids.issubset(self._variable_ids)
var_list_2 = srv.list_variables_without_values(self._test_app_id, page=2, limit=2)
var_list_2 = srv.list_variables_without_values(self._test_app_id, page=2, limit=2, user_id=self._test_user_id)
assert var_list_2.total is None
assert len(var_list_2.variables) == 2
page2_var_ids = {v.id for v in var_list_2.variables}
@ -114,7 +115,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
def test_get_node_variable(self):
srv = self._get_test_srv()
node_var = srv.get_node_variable(self._test_app_id, self._node1_id, "str_var")
node_var = srv.get_node_variable(self._test_app_id, self._node1_id, "str_var", user_id=self._test_user_id)
assert node_var is not None
assert node_var.id == self._node1_str_var_id
assert node_var.name == "str_var"
@ -122,7 +123,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
def test_get_system_variable(self):
srv = self._get_test_srv()
sys_var = srv.get_system_variable(self._test_app_id, "sys_var")
sys_var = srv.get_system_variable(self._test_app_id, "sys_var", user_id=self._test_user_id)
assert sys_var is not None
assert sys_var.id == self._sys_var_id
assert sys_var.name == "sys_var"
@ -130,7 +131,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
def test_get_conversation_variable(self):
srv = self._get_test_srv()
conv_var = srv.get_conversation_variable(self._test_app_id, "conv_var")
conv_var = srv.get_conversation_variable(self._test_app_id, "conv_var", user_id=self._test_user_id)
assert conv_var is not None
assert conv_var.id == self._conv_var_id
assert conv_var.name == "conv_var"
@ -138,7 +139,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
def test_delete_node_variables(self):
srv = self._get_test_srv()
srv.delete_node_variables(self._test_app_id, self._node2_id)
srv.delete_node_variables(self._test_app_id, self._node2_id, user_id=self._test_user_id)
node2_var_count = (
self._session.query(WorkflowDraftVariable)
.where(
@ -162,7 +163,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
def test__list_node_variables(self):
srv = self._get_test_srv()
node_vars = srv._list_node_variables(self._test_app_id, self._node2_id)
node_vars = srv._list_node_variables(self._test_app_id, self._node2_id, user_id=self._test_user_id)
assert len(node_vars.variables) == 2
assert {v.id for v in node_vars.variables} == set(self._node2_var_ids)
@ -173,7 +174,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
[self._node2_id, "str_var"],
[self._node2_id, "int_var"],
]
variables = srv.get_draft_variables_by_selectors(self._test_app_id, selectors)
variables = srv.get_draft_variables_by_selectors(self._test_app_id, selectors, user_id=self._test_user_id)
assert len(variables) == 3
assert {v.id for v in variables} == {self._node1_str_var_id} | set(self._node2_var_ids)
@ -206,19 +207,23 @@ class TestDraftVariableLoader(unittest.TestCase):
def setUp(self):
self._test_app_id = str(uuid.uuid4())
self._test_tenant_id = str(uuid.uuid4())
self._test_user_id = str(uuid.uuid4())
sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=self._test_app_id,
user_id=self._test_user_id,
name="sys_var",
value=build_segment("sys_value"),
node_execution_id=self._node_exec_id,
)
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=self._test_app_id,
user_id=self._test_user_id,
name="conv_var",
value=build_segment("conv_value"),
)
node_var = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
user_id=self._test_user_id,
node_id=self._node1_id,
name="str_var",
value=build_segment("str_value"),
@ -248,12 +253,22 @@ class TestDraftVariableLoader(unittest.TestCase):
session.commit()
def test_variable_loader_with_empty_selector(self):
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=self._test_app_id,
tenant_id=self._test_tenant_id,
user_id=self._test_user_id,
)
variables = var_loader.load_variables([])
assert len(variables) == 0
def test_variable_loader_with_non_empty_selector(self):
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=self._test_app_id,
tenant_id=self._test_tenant_id,
user_id=self._test_user_id,
)
variables = var_loader.load_variables(
[
[SYSTEM_VARIABLE_NODE_ID, "sys_var"],
@ -296,7 +311,12 @@ class TestDraftVariableLoader(unittest.TestCase):
session.commit()
# Now test loading using DraftVarLoader
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=self._test_app_id,
tenant_id=self._test_tenant_id,
user_id=setup_account.id,
)
# Load the variable using the standard workflow
variables = var_loader.load_variables([["test_offload_node", "offloaded_string_var"]])
@ -313,7 +333,7 @@ class TestDraftVariableLoader(unittest.TestCase):
# Clean up - delete all draft variables for this app
with Session(bind=db.engine) as session:
service = WorkflowDraftVariableService(session)
service.delete_workflow_variables(self._test_app_id)
service.delete_app_workflow_variables(self._test_app_id)
session.commit()
def test_load_offloaded_variable_object_type_integration(self):
@ -364,6 +384,7 @@ class TestDraftVariableLoader(unittest.TestCase):
# Now create the offloaded draft variable with the correct file_id
offloaded_var = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
user_id=self._test_user_id,
node_id="test_offload_node",
name="offloaded_object_var",
value=build_segment({"truncated": True}),
@ -379,7 +400,9 @@ class TestDraftVariableLoader(unittest.TestCase):
# Use the service method that properly preloads relationships
service = WorkflowDraftVariableService(session)
draft_vars = service.get_draft_variables_by_selectors(
self._test_app_id, [["test_offload_node", "offloaded_object_var"]]
self._test_app_id,
[["test_offload_node", "offloaded_object_var"]],
user_id=self._test_user_id,
)
assert len(draft_vars) == 1
@ -387,7 +410,12 @@ class TestDraftVariableLoader(unittest.TestCase):
assert loaded_var.is_truncated()
# Create DraftVarLoader and test loading
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=self._test_app_id,
tenant_id=self._test_tenant_id,
user_id=self._test_user_id,
)
# Test the _load_offloaded_variable method
selector_tuple, variable = var_loader._load_offloaded_variable(loaded_var)
@ -459,6 +487,7 @@ class TestDraftVariableLoader(unittest.TestCase):
# Now create the offloaded draft variable with the correct file_id
offloaded_var = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
user_id=self._test_user_id,
node_id="test_integration_node",
name="offloaded_integration_var",
value=build_segment("truncated"),
@ -473,7 +502,12 @@ class TestDraftVariableLoader(unittest.TestCase):
# Test load_variables with both regular and offloaded variables
# This method should handle the relationship preloading internally
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=self._test_app_id,
tenant_id=self._test_tenant_id,
user_id=self._test_user_id,
)
variables = var_loader.load_variables(
[
@ -572,6 +606,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
# Create test variables
self._node_var_with_exec = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
user_id=self._test_user_id,
node_id=self._node_id,
name="test_var",
value=build_segment("old_value"),
@ -581,6 +616,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
self._node_var_without_exec = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
user_id=self._test_user_id,
node_id=self._node_id,
name="no_exec_var",
value=build_segment("some_value"),
@ -591,6 +627,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
self._node_var_missing_exec = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
user_id=self._test_user_id,
node_id=self._node_id,
name="missing_exec_var",
value=build_segment("some_value"),
@ -599,6 +636,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
self._conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=self._test_app_id,
user_id=self._test_user_id,
name="conv_var_1",
value=build_segment("old_conv_value"),
)
@ -764,6 +802,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
# Create a system variable
sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=self._test_app_id,
user_id=self._test_user_id,
name="sys_var",
value=build_segment("sys_value"),
node_execution_id=self._node_exec_id,

View File

@ -122,6 +122,7 @@ class TestWorkflowDraftVariableService:
name,
value,
variable_type: DraftVariableType = DraftVariableType.CONVERSATION,
user_id: str | None = None,
fake=None,
):
"""
@ -144,10 +145,15 @@ class TestWorkflowDraftVariableService:
WorkflowDraftVariable: Created test variable instance with proper type configuration
"""
fake = fake or Faker()
if user_id is None:
app = db_session_with_containers.query(App).filter_by(id=app_id).first()
assert app is not None
user_id = app.created_by
if variable_type == "conversation":
# Create conversation variable using the appropriate factory method
variable = WorkflowDraftVariable.new_conversation_variable(
app_id=app_id,
user_id=user_id,
name=name,
value=value,
description=fake.text(max_nb_chars=20),
@ -156,6 +162,7 @@ class TestWorkflowDraftVariableService:
# Create system variable with editable flag and execution context
variable = WorkflowDraftVariable.new_sys_variable(
app_id=app_id,
user_id=user_id,
name=name,
value=value,
node_execution_id=fake.uuid4(),
@ -165,6 +172,7 @@ class TestWorkflowDraftVariableService:
# Create node variable with visibility and editability settings
variable = WorkflowDraftVariable.new_node_variable(
app_id=app_id,
user_id=user_id,
node_id=node_id,
name=name,
value=value,
@ -189,7 +197,13 @@ class TestWorkflowDraftVariableService:
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
test_value = StringSegment(value=fake.word())
variable = self._create_test_variable(
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake
db_session_with_containers,
app.id,
CONVERSATION_VARIABLE_NODE_ID,
"test_var",
test_value,
user_id=app.created_by,
fake=fake,
)
service = WorkflowDraftVariableService(db_session_with_containers)
retrieved_variable = service.get_variable(variable.id)
@ -250,7 +264,7 @@ class TestWorkflowDraftVariableService:
["test_node_1", "var3"],
]
service = WorkflowDraftVariableService(db_session_with_containers)
retrieved_variables = service.get_draft_variables_by_selectors(app.id, selectors)
retrieved_variables = service.get_draft_variables_by_selectors(app.id, selectors, user_id=app.created_by)
assert len(retrieved_variables) == 3
var_names = [var.name for var in retrieved_variables]
assert "var1" in var_names
@ -288,7 +302,7 @@ class TestWorkflowDraftVariableService:
fake=fake,
)
service = WorkflowDraftVariableService(db_session_with_containers)
result = service.list_variables_without_values(app.id, page=1, limit=3)
result = service.list_variables_without_values(app.id, page=1, limit=3, user_id=app.created_by)
assert result.total == 5
assert len(result.variables) == 3
assert result.variables[0].created_at >= result.variables[1].created_at
@ -339,7 +353,7 @@ class TestWorkflowDraftVariableService:
fake=fake,
)
service = WorkflowDraftVariableService(db_session_with_containers)
result = service.list_node_variables(app.id, node_id)
result = service.list_node_variables(app.id, node_id, user_id=app.created_by)
assert len(result.variables) == 2
for var in result.variables:
assert var.node_id == node_id
@ -381,7 +395,7 @@ class TestWorkflowDraftVariableService:
fake=fake,
)
service = WorkflowDraftVariableService(db_session_with_containers)
result = service.list_conversation_variables(app.id)
result = service.list_conversation_variables(app.id, user_id=app.created_by)
assert len(result.variables) == 2
for var in result.variables:
assert var.node_id == CONVERSATION_VARIABLE_NODE_ID
@ -559,7 +573,7 @@ class TestWorkflowDraftVariableService:
assert len(app_variables) == 3
assert len(other_app_variables) == 1
service = WorkflowDraftVariableService(db_session_with_containers)
service.delete_workflow_variables(app.id)
service.delete_user_workflow_variables(app.id, user_id=app.created_by)
app_variables_after = db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
other_app_variables_after = (
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
@ -567,6 +581,69 @@ class TestWorkflowDraftVariableService:
assert len(app_variables_after) == 0
assert len(other_app_variables_after) == 1
def test_draft_variables_are_isolated_between_users(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test draft variable isolation for different users in the same app.
This test verifies that:
1. Query APIs return only variables owned by the target user.
2. User-scoped deletion only removes variables for that user and keeps
other users' variables in the same app untouched.
"""
fake = Faker()
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
user_a = app.created_by
user_b = fake.uuid4()
# Use identical variable names on purpose to verify uniqueness scope includes user_id.
self._create_test_variable(
db_session_with_containers,
app.id,
CONVERSATION_VARIABLE_NODE_ID,
"shared_name",
StringSegment(value="value_a"),
user_id=user_a,
fake=fake,
)
self._create_test_variable(
db_session_with_containers,
app.id,
CONVERSATION_VARIABLE_NODE_ID,
"shared_name",
StringSegment(value="value_b"),
user_id=user_b,
fake=fake,
)
self._create_test_variable(
db_session_with_containers,
app.id,
CONVERSATION_VARIABLE_NODE_ID,
"only_a",
StringSegment(value="only_a"),
user_id=user_a,
fake=fake,
)
service = WorkflowDraftVariableService(db_session_with_containers)
user_a_vars = service.list_conversation_variables(app.id, user_id=user_a)
user_b_vars = service.list_conversation_variables(app.id, user_id=user_b)
assert {v.name for v in user_a_vars.variables} == {"shared_name", "only_a"}
assert {v.name for v in user_b_vars.variables} == {"shared_name"}
service.delete_user_workflow_variables(app.id, user_id=user_a)
user_a_remaining = (
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, user_id=user_a).count()
)
user_b_remaining = (
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, user_id=user_b).count()
)
assert user_a_remaining == 0
assert user_b_remaining == 1
def test_delete_node_variables_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
@ -627,7 +704,7 @@ class TestWorkflowDraftVariableService:
assert len(other_node_variables) == 1
assert len(conv_variables) == 1
service = WorkflowDraftVariableService(db_session_with_containers)
service.delete_node_variables(app.id, node_id)
service.delete_node_variables(app.id, node_id, user_id=app.created_by)
target_node_variables_after = (
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
)
@ -675,7 +752,7 @@ class TestWorkflowDraftVariableService:
db_session_with_containers.commit()
service = WorkflowDraftVariableService(db_session_with_containers)
service.prefill_conversation_variable_default_values(workflow)
service.prefill_conversation_variable_default_values(workflow, user_id="00000000-0000-0000-0000-000000000001")
draft_variables = (
db_session_with_containers.query(WorkflowDraftVariable)
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
@ -715,7 +792,7 @@ class TestWorkflowDraftVariableService:
fake=fake,
)
service = WorkflowDraftVariableService(db_session_with_containers)
retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id)
retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id, app.created_by)
assert retrieved_conv_id == conversation_id
def test_get_conversation_id_from_draft_variable_not_found(
@ -731,7 +808,7 @@ class TestWorkflowDraftVariableService:
fake = Faker()
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
service = WorkflowDraftVariableService(db_session_with_containers)
retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id)
retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id, app.created_by)
assert retrieved_conv_id is None
def test_list_system_variables_success(
@ -772,7 +849,7 @@ class TestWorkflowDraftVariableService:
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var", conv_var_value, fake=fake
)
service = WorkflowDraftVariableService(db_session_with_containers)
result = service.list_system_variables(app.id)
result = service.list_system_variables(app.id, user_id=app.created_by)
assert len(result.variables) == 2
for var in result.variables:
assert var.node_id == SYSTEM_VARIABLE_NODE_ID
@ -819,15 +896,15 @@ class TestWorkflowDraftVariableService:
fake=fake,
)
service = WorkflowDraftVariableService(db_session_with_containers)
retrieved_conv_var = service.get_conversation_variable(app.id, "test_conv_var")
retrieved_conv_var = service.get_conversation_variable(app.id, "test_conv_var", user_id=app.created_by)
assert retrieved_conv_var is not None
assert retrieved_conv_var.name == "test_conv_var"
assert retrieved_conv_var.node_id == CONVERSATION_VARIABLE_NODE_ID
retrieved_sys_var = service.get_system_variable(app.id, "test_sys_var")
retrieved_sys_var = service.get_system_variable(app.id, "test_sys_var", user_id=app.created_by)
assert retrieved_sys_var is not None
assert retrieved_sys_var.name == "test_sys_var"
assert retrieved_sys_var.node_id == SYSTEM_VARIABLE_NODE_ID
retrieved_node_var = service.get_node_variable(app.id, "test_node", "test_node_var")
retrieved_node_var = service.get_node_variable(app.id, "test_node", "test_node_var", user_id=app.created_by)
assert retrieved_node_var is not None
assert retrieved_node_var.name == "test_node_var"
assert retrieved_node_var.node_id == "test_node"
@ -845,9 +922,14 @@ class TestWorkflowDraftVariableService:
fake = Faker()
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
service = WorkflowDraftVariableService(db_session_with_containers)
retrieved_conv_var = service.get_conversation_variable(app.id, "non_existent_conv_var")
retrieved_conv_var = service.get_conversation_variable(app.id, "non_existent_conv_var", user_id=app.created_by)
assert retrieved_conv_var is None
retrieved_sys_var = service.get_system_variable(app.id, "non_existent_sys_var")
retrieved_sys_var = service.get_system_variable(app.id, "non_existent_sys_var", user_id=app.created_by)
assert retrieved_sys_var is None
retrieved_node_var = service.get_node_variable(app.id, "test_node", "non_existent_node_var")
retrieved_node_var = service.get_node_variable(
app.id,
"test_node",
"non_existent_node_var",
user_id=app.created_by,
)
assert retrieved_node_var is None

View File

@ -398,6 +398,7 @@ class TestWorkflowDraftVariableEndpoints:
method = _unwrap(api.get)
monkeypatch.setattr(workflow_draft_variable_module, "db", SimpleNamespace(engine=MagicMock()))
monkeypatch.setattr(workflow_draft_variable_module, "current_user", SimpleNamespace(id="user-1"))
class DummySession:
def __enter__(self):

View File

@ -234,6 +234,7 @@ class TestAdvancedChatAppGeneratorInternals:
captured: dict[str, object] = {}
prefill_calls: list[object] = []
var_loader = SimpleNamespace(loader="draft")
workflow = SimpleNamespace(id="workflow-id")
monkeypatch.setattr(
"core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config",
@ -260,8 +261,8 @@ class TestAdvancedChatAppGeneratorInternals:
def __init__(self, session):
_ = session
def prefill_conversation_variable_default_values(self, workflow):
prefill_calls.append(workflow)
def prefill_conversation_variable_default_values(self, workflow, user_id):
prefill_calls.append((workflow, user_id))
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.WorkflowDraftVariableService", _DraftVarService)
@ -273,7 +274,7 @@ class TestAdvancedChatAppGeneratorInternals:
result = generator.single_iteration_generate(
app_model=SimpleNamespace(id="app", tenant_id="tenant"),
workflow=SimpleNamespace(id="workflow-id"),
workflow=workflow,
node_id="node-1",
user=SimpleNamespace(id="user-id"),
args={"inputs": {"foo": "bar"}},
@ -281,7 +282,7 @@ class TestAdvancedChatAppGeneratorInternals:
)
assert result == {"ok": True}
assert prefill_calls
assert prefill_calls == [(workflow, "user-id")]
assert captured["variable_loader"] is var_loader
assert captured["application_generate_entity"].single_iteration_run.node_id == "node-1"
@ -291,6 +292,7 @@ class TestAdvancedChatAppGeneratorInternals:
captured: dict[str, object] = {}
prefill_calls: list[object] = []
var_loader = SimpleNamespace(loader="draft")
workflow = SimpleNamespace(id="workflow-id")
monkeypatch.setattr(
"core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config",
@ -317,8 +319,8 @@ class TestAdvancedChatAppGeneratorInternals:
def __init__(self, session):
_ = session
def prefill_conversation_variable_default_values(self, workflow):
prefill_calls.append(workflow)
def prefill_conversation_variable_default_values(self, workflow, user_id):
prefill_calls.append((workflow, user_id))
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.WorkflowDraftVariableService", _DraftVarService)
@ -330,7 +332,7 @@ class TestAdvancedChatAppGeneratorInternals:
result = generator.single_loop_generate(
app_model=SimpleNamespace(id="app", tenant_id="tenant"),
workflow=SimpleNamespace(id="workflow-id"),
workflow=workflow,
node_id="node-2",
user=SimpleNamespace(id="user-id"),
args=SimpleNamespace(inputs={"foo": "bar"}),
@ -338,7 +340,7 @@ class TestAdvancedChatAppGeneratorInternals:
)
assert result == {"ok": True}
assert prefill_calls
assert prefill_calls == [(workflow, "user-id")]
assert captured["variable_loader"] is var_loader
assert captured["application_generate_entity"].single_loop_run.node_id == "node-2"

View File

@ -11,6 +11,7 @@ import unittest
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from core.rag.datasource.vdb.weaviate import weaviate_vector as weaviate_vector_module
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
from core.rag.models.document import Document
@ -19,6 +20,7 @@ class TestWeaviateVector(unittest.TestCase):
"""Tests for WeaviateVector class with focus on doc_type metadata handling."""
def setUp(self):
weaviate_vector_module._weaviate_client = None
self.config = WeaviateConfig(
endpoint="http://localhost:8080",
api_key="test-key",
@ -27,6 +29,9 @@ class TestWeaviateVector(unittest.TestCase):
self.collection_name = "Test_Collection_Node"
self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
def tearDown(self):
weaviate_vector_module._weaviate_client = None
@patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate")
def _create_weaviate_vector(self, mock_weaviate_module):
"""Helper to create a WeaviateVector instance with mocked client."""

View File

@ -263,7 +263,7 @@ def test_import_app_completed_uses_declared_dependencies(monkeypatch):
assert result.status == ImportStatus.COMPLETED
assert result.app_id == "app-new"
draft_var_service.delete_workflow_variables.assert_called_once_with(app_id="app-new")
draft_var_service.delete_app_workflow_variables.assert_called_once_with(app_id="app-new")
@pytest.mark.parametrize("has_workflow", [True, False])
@ -305,7 +305,7 @@ def test_import_app_legacy_versions_extract_dependencies(monkeypatch, has_workfl
account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content=_yaml_dump(data)
)
assert result.status == ImportStatus.COMPLETED_WITH_WARNINGS
draft_var_service.delete_workflow_variables.assert_called_once_with(app_id="app-legacy")
draft_var_service.delete_app_workflow_variables.assert_called_once_with(app_id="app-legacy")
def test_import_app_yaml_error_returns_failed(monkeypatch):

View File

@ -24,7 +24,11 @@ class TestDraftVarLoaderSimple:
def draft_var_loader(self, mock_engine):
"""Create DraftVarLoader instance for testing."""
return DraftVarLoader(
engine=mock_engine, app_id="test-app-id", tenant_id="test-tenant-id", fallback_variables=[]
engine=mock_engine,
app_id="test-app-id",
tenant_id="test-tenant-id",
user_id="test-user-id",
fallback_variables=[],
)
def test_load_offloaded_variable_string_type_unit(self, draft_var_loader):
@ -323,7 +327,9 @@ class TestDraftVarLoaderSimple:
# Verify service method was called
mock_service.get_draft_variables_by_selectors.assert_called_once_with(
draft_var_loader._app_id, selectors
draft_var_loader._app_id,
selectors,
user_id=draft_var_loader._user_id,
)
# Verify offloaded variable loading was called

View File

@ -8,7 +8,7 @@ from sqlalchemy import Engine
from sqlalchemy.orm import Session
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey
from dify_graph.variables.segments import StringSegment
from dify_graph.variables.types import SegmentType
from libs.uuid_utils import uuidv7
@ -182,6 +182,42 @@ class TestDraftVariableSaver:
draft_vars = mock_batch_upsert.call_args[0][1]
assert len(draft_vars) == 2
@patch("services.workflow_draft_variable_service._batch_upsert_draft_variable", autospec=True)
def test_start_node_save_persists_sys_timestamp_and_workflow_run_id(self, mock_batch_upsert):
"""Start node should persist common `sys.*` variables, not only `sys.files`."""
mock_session = MagicMock(spec=Session)
mock_user = MagicMock(spec=Account)
mock_user.id = "test-user-id"
mock_user.tenant_id = "test-tenant-id"
saver = DraftVariableSaver(
session=mock_session,
app_id="test-app-id",
node_id="start-node-id",
node_type=BuiltinNodeTypes.START,
node_execution_id="exec-id",
user=mock_user,
)
outputs = {
f"{SYSTEM_VARIABLE_NODE_ID}.{SystemVariableKey.TIMESTAMP}": 1700000000,
f"{SYSTEM_VARIABLE_NODE_ID}.{SystemVariableKey.WORKFLOW_EXECUTION_ID}": "run-id-123",
}
saver.save(outputs=outputs)
mock_batch_upsert.assert_called_once()
draft_vars = mock_batch_upsert.call_args[0][1]
# plus one dummy output because there are no non-sys Start inputs
assert len(draft_vars) == 3
sys_vars = [v for v in draft_vars if v.node_id == SYSTEM_VARIABLE_NODE_ID]
assert {v.name for v in sys_vars} == {
str(SystemVariableKey.TIMESTAMP),
str(SystemVariableKey.WORKFLOW_EXECUTION_ID),
}
class TestWorkflowDraftVariableService:
def _get_test_app_id(self):

View File

@ -245,6 +245,7 @@ class TestWorkflowService:
workflow=workflow,
node_config=node_config,
manual_inputs={"#node-0.result#": "LLM output"},
user_id="account-1",
)
node.render_form_content_with_outputs.assert_called_once()

View File

@ -0,0 +1,139 @@
import {
getChangedBranchCoverage,
getChangedStatementCoverage,
getIgnoredChangedLinesFromSource,
normalizeToRepoRelative,
parseChangedLineMap,
} from '../scripts/check-components-diff-coverage-lib.mjs'
describe('check-components-diff-coverage helpers', () => {
it('should parse changed line maps from unified diffs', () => {
const diff = [
'diff --git a/web/app/components/share/a.ts b/web/app/components/share/a.ts',
'+++ b/web/app/components/share/a.ts',
'@@ -10,0 +11,2 @@',
'+const a = 1',
'+const b = 2',
'diff --git a/web/app/components/base/b.ts b/web/app/components/base/b.ts',
'+++ b/web/app/components/base/b.ts',
'@@ -20 +21 @@',
'+const c = 3',
'diff --git a/web/README.md b/web/README.md',
'+++ b/web/README.md',
'@@ -1 +1 @@',
'+ignore me',
].join('\n')
const lineMap = parseChangedLineMap(diff, (filePath: string) => filePath.startsWith('web/app/components/'))
expect([...lineMap.entries()]).toEqual([
['web/app/components/share/a.ts', new Set([11, 12])],
['web/app/components/base/b.ts', new Set([21])],
])
})
it('should normalize coverage and absolute paths to repo-relative paths', () => {
const repoRoot = '/repo'
const webRoot = '/repo/web'
expect(normalizeToRepoRelative('web/app/components/share/a.ts', {
appComponentsCoveragePrefix: 'app/components/',
appComponentsPrefix: 'web/app/components/',
repoRoot,
sharedTestPrefix: 'web/__tests__/',
webRoot,
})).toBe('web/app/components/share/a.ts')
expect(normalizeToRepoRelative('app/components/share/a.ts', {
appComponentsCoveragePrefix: 'app/components/',
appComponentsPrefix: 'web/app/components/',
repoRoot,
sharedTestPrefix: 'web/__tests__/',
webRoot,
})).toBe('web/app/components/share/a.ts')
expect(normalizeToRepoRelative('/repo/web/app/components/share/a.ts', {
appComponentsCoveragePrefix: 'app/components/',
appComponentsPrefix: 'web/app/components/',
repoRoot,
sharedTestPrefix: 'web/__tests__/',
webRoot,
})).toBe('web/app/components/share/a.ts')
})
it('should calculate changed statement coverage from changed lines', () => {
const entry = {
s: { 0: 1, 1: 0 },
statementMap: {
0: { start: { line: 10 }, end: { line: 10 } },
1: { start: { line: 12 }, end: { line: 13 } },
},
}
const coverage = getChangedStatementCoverage(entry, new Set([10, 12]))
expect(coverage).toEqual({
covered: 1,
total: 2,
uncoveredLines: [12],
})
})
it('should fail changed lines when a source file has no coverage entry', () => {
const coverage = getChangedStatementCoverage(undefined, new Set([42, 43]))
expect(coverage).toEqual({
covered: 0,
total: 2,
uncoveredLines: [42, 43],
})
})
it('should calculate changed branch coverage using changed branch definitions', () => {
const entry = {
b: {
0: [1, 0],
},
branchMap: {
0: {
line: 20,
loc: { start: { line: 20 }, end: { line: 20 } },
locations: [
{ start: { line: 20 }, end: { line: 20 } },
{ start: { line: 21 }, end: { line: 21 } },
],
type: 'if',
},
},
}
const coverage = getChangedBranchCoverage(entry, new Set([20]))
expect(coverage).toEqual({
covered: 1,
total: 2,
uncoveredBranches: [
{ armIndex: 1, line: 21 },
],
})
})
it('should ignore changed lines with valid pragma reasons and report invalid pragmas', () => {
const sourceCode = [
'const a = 1',
'const b = 2 // diff-coverage-ignore-line: defensive fallback',
'const c = 3 // diff-coverage-ignore-line:',
'const d = 4 // diff-coverage-ignore-line: not changed',
].join('\n')
const result = getIgnoredChangedLinesFromSource(sourceCode, new Set([2, 3]))
expect([...result.effectiveChangedLines]).toEqual([3])
expect([...result.ignoredLines.entries()]).toEqual([
[2, 'defensive fallback'],
])
expect(result.invalidPragmas).toEqual([
{ line: 3, reason: 'missing ignore reason' },
])
})
})

View File

@ -275,7 +275,7 @@ describe('useTextGenerationBatch', () => {
})
act(() => {
result.current.handleCompleted({ answer: 'failed' } as unknown as string, 1, false)
result.current.handleCompleted('{"answer":"failed"}', 1, false)
})
expect(result.current.allFailedTaskList).toEqual([
@ -291,7 +291,7 @@ describe('useTextGenerationBatch', () => {
{
'Name': 'Alice',
'Score': '',
'generation.completionResult': JSON.stringify({ answer: 'failed' }),
'generation.completionResult': '{"answer":"failed"}',
},
])

View File

@ -241,10 +241,7 @@ export const useTextGenerationBatch = ({
result[variable.name] = String(task.params.inputs[variable.key] ?? '')
})
let completionValue = batchCompletionMap[String(task.id)]
if (typeof completionValue === 'object')
completionValue = JSON.stringify(completionValue)
const completionValue = batchCompletionMap[String(task.id)] ?? ''
result[t('generation.completionResult', { ns: 'share' })] = completionValue
return result
})

View File

@ -0,0 +1,334 @@
import type { PromptConfig } from '@/models/debug'
import type { SiteInfo } from '@/models/share'
import type { IOtherOptions } from '@/service/base'
import type { VisionSettings } from '@/types/app'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import { AppSourceType } from '@/service/share'
import { Resolution, TransferMethod } from '@/types/app'
import Result from '../index'
const {
notifyMock,
sendCompletionMessageMock,
sendWorkflowMessageMock,
stopChatMessageRespondingMock,
textGenerationResPropsSpy,
} = vi.hoisted(() => ({
notifyMock: vi.fn(),
sendCompletionMessageMock: vi.fn(),
sendWorkflowMessageMock: vi.fn(),
stopChatMessageRespondingMock: vi.fn(),
textGenerationResPropsSpy: vi.fn(),
}))
vi.mock('i18next', () => ({
t: (key: string) => key,
}))
vi.mock('@/app/components/base/toast', () => ({
default: {
notify: notifyMock,
},
}))
vi.mock('@/utils', async () => {
const actual = await vi.importActual<typeof import('@/utils')>('@/utils')
return {
...actual,
sleep: () => new Promise<void>(() => {}),
}
})
vi.mock('@/service/share', async () => {
const actual = await vi.importActual<typeof import('@/service/share')>('@/service/share')
return {
...actual,
sendCompletionMessage: (...args: Parameters<typeof actual.sendCompletionMessage>) => sendCompletionMessageMock(...args),
sendWorkflowMessage: (...args: Parameters<typeof actual.sendWorkflowMessage>) => sendWorkflowMessageMock(...args),
stopChatMessageResponding: (...args: Parameters<typeof actual.stopChatMessageResponding>) => stopChatMessageRespondingMock(...args),
}
})
vi.mock('@/app/components/app/text-generate/item', () => ({
default: (props: Record<string, unknown>) => {
textGenerationResPropsSpy(props)
return (
<div data-testid="text-generation-res">
{typeof props.content === 'string' ? props.content : JSON.stringify(props.content ?? null)}
</div>
)
},
}))
vi.mock('@/app/components/share/text-generation/no-data', () => ({
default: () => <div data-testid="no-data">No data</div>,
}))
const promptConfig: PromptConfig = {
prompt_template: 'template',
prompt_variables: [
{ key: 'name', name: 'Name', type: 'string', required: true },
],
}
const siteInfo: SiteInfo = {
title: 'Share title',
description: 'Share description',
icon_type: 'emoji',
icon: 'robot',
}
const visionConfig: VisionSettings = {
enabled: false,
number_limits: 2,
detail: Resolution.low,
transfer_methods: [TransferMethod.local_file],
}
const baseProps = {
appId: 'app-1',
appSourceType: AppSourceType.webApp,
completionFiles: [],
controlRetry: 0,
controlSend: 0,
controlStopResponding: 0,
handleSaveMessage: vi.fn(),
inputs: { name: 'Alice' },
isCallBatchAPI: false,
isError: false,
isMobile: false,
isPC: true,
isShowTextToSpeech: true,
isWorkflow: false,
moreLikeThisEnabled: true,
onCompleted: vi.fn(),
onRunControlChange: vi.fn(),
onRunStart: vi.fn(),
onShowRes: vi.fn(),
promptConfig,
siteInfo,
visionConfig,
}
describe('Result', () => {
beforeEach(() => {
vi.clearAllMocks()
stopChatMessageRespondingMock.mockResolvedValue(undefined)
})
it('should render no data before the first execution', () => {
render(<Result {...baseProps} />)
expect(screen.getByTestId('no-data')).toBeTruthy()
expect(screen.queryByTestId('text-generation-res')).toBeNull()
})
it('should stream completion results and stop the current task', async () => {
let completionHandlers: {
onCompleted: () => void
onData: (chunk: string, isFirstMessage: boolean, info: { messageId: string, taskId?: string }) => void
onError: () => void
onMessageReplace: (messageReplace: { answer: string }) => void
} | null = null
sendCompletionMessageMock.mockImplementation(async (_data, handlers) => {
completionHandlers = handlers
})
const onCompleted = vi.fn()
const onRunControlChange = vi.fn()
const { rerender } = render(
<Result
{...baseProps}
onCompleted={onCompleted}
onRunControlChange={onRunControlChange}
/>,
)
rerender(
<Result
{...baseProps}
controlSend={1}
onCompleted={onCompleted}
onRunControlChange={onRunControlChange}
/>,
)
expect(sendCompletionMessageMock).toHaveBeenCalledTimes(1)
expect(screen.getByRole('status', { name: 'appApi.loading' })).toBeTruthy()
await act(async () => {
completionHandlers?.onData('Hello', false, {
messageId: 'message-1',
taskId: 'task-1',
})
})
expect(screen.getByTestId('text-generation-res').textContent).toContain('Hello')
await waitFor(() => {
expect(onRunControlChange).toHaveBeenLastCalledWith(expect.objectContaining({
isStopping: false,
}))
})
fireEvent.click(screen.getByRole('button', { name: 'operation.stopResponding' }))
await waitFor(() => {
expect(stopChatMessageRespondingMock).toHaveBeenCalledWith('app-1', 'task-1', AppSourceType.webApp, 'app-1')
})
await act(async () => {
completionHandlers?.onCompleted()
})
expect(onCompleted).toHaveBeenCalledWith('Hello', undefined, true)
expect(textGenerationResPropsSpy).toHaveBeenLastCalledWith(expect.objectContaining({
messageId: 'message-1',
}))
})
it('should render workflow results after workflow completion', async () => {
let workflowHandlers: IOtherOptions | null = null
sendWorkflowMessageMock.mockImplementation(async (_data, handlers) => {
workflowHandlers = handlers
})
const onCompleted = vi.fn()
const { rerender } = render(
<Result
{...baseProps}
isWorkflow
onCompleted={onCompleted}
/>,
)
rerender(
<Result
{...baseProps}
isWorkflow
controlSend={1}
onCompleted={onCompleted}
/>,
)
await act(async () => {
workflowHandlers?.onWorkflowStarted?.({
workflow_run_id: 'run-1',
task_id: 'task-1',
event: 'workflow_started',
data: {
id: 'run-1',
workflow_id: 'wf-1',
created_at: 0,
},
})
workflowHandlers?.onTextChunk?.({
task_id: 'task-1',
workflow_run_id: 'run-1',
event: 'text_chunk',
data: {
text: 'Hello',
},
})
workflowHandlers?.onWorkflowFinished?.({
task_id: 'task-1',
workflow_run_id: 'run-1',
event: 'workflow_finished',
data: {
id: 'run-1',
workflow_id: 'wf-1',
status: 'succeeded',
outputs: {
answer: 'Hello',
},
error: '',
elapsed_time: 0,
total_tokens: 0,
total_steps: 0,
created_at: 0,
created_by: {
id: 'user-1',
name: 'User',
email: 'user@example.com',
},
finished_at: 0,
},
})
})
expect(screen.getByTestId('text-generation-res').textContent).toContain('{"answer":"Hello"}')
expect(textGenerationResPropsSpy).toHaveBeenLastCalledWith(expect.objectContaining({
workflowProcessData: expect.objectContaining({
resultText: 'Hello',
status: 'succeeded',
}),
}))
expect(onCompleted).toHaveBeenCalledWith('{"answer":"Hello"}', undefined, true)
})
it('should render batch task ids for both short and long indexes', () => {
const { rerender } = render(
<Result
{...baseProps}
isCallBatchAPI
taskId={3}
/>,
)
expect(textGenerationResPropsSpy).toHaveBeenLastCalledWith(expect.objectContaining({
taskId: '03',
}))
rerender(
<Result
{...baseProps}
isCallBatchAPI
taskId={12}
/>,
)
expect(textGenerationResPropsSpy).toHaveBeenLastCalledWith(expect.objectContaining({
taskId: '12',
}))
})
it('should render the mobile stop button layout while a batch run is responding', async () => {
let completionHandlers: {
onData: (chunk: string, isFirstMessage: boolean, info: { messageId: string, taskId?: string }) => void
} | null = null
sendCompletionMessageMock.mockImplementation(async (_data, handlers) => {
completionHandlers = handlers
})
const { rerender } = render(
<Result
{...baseProps}
isCallBatchAPI
isMobile
isPC={false}
taskId={2}
/>,
)
rerender(
<Result
{...baseProps}
controlSend={1}
isCallBatchAPI
isMobile
isPC={false}
taskId={2}
/>,
)
await act(async () => {
completionHandlers?.onData('Hello', false, {
messageId: 'message-batch',
taskId: 'task-batch',
})
})
expect(screen.getByRole('button', { name: 'operation.stopResponding' }).parentElement?.className).toContain('justify-center')
})
})

View File

@ -0,0 +1,293 @@
import type { FileEntity } from '@/app/components/base/file-uploader/types'
import type { PromptConfig } from '@/models/debug'
import type { VisionFile, VisionSettings } from '@/types/app'
import { Resolution, TransferMethod } from '@/types/app'
import { buildResultRequestData, validateResultRequest } from '../result-request'
const createTranslator = () => vi.fn((key: string) => key)
const createFileEntity = (overrides: Partial<FileEntity> = {}): FileEntity => ({
id: 'file-1',
name: 'example.txt',
size: 128,
type: 'text/plain',
progress: 100,
transferMethod: TransferMethod.local_file,
supportFileType: 'document',
uploadedId: 'uploaded-1',
url: 'https://example.com/file.txt',
...overrides,
})
const createVisionFile = (overrides: Partial<VisionFile> = {}): VisionFile => ({
type: 'image',
transfer_method: TransferMethod.local_file,
upload_file_id: 'upload-1',
url: 'https://example.com/image.png',
...overrides,
})
const promptConfig: PromptConfig = {
prompt_template: 'template',
prompt_variables: [
{ key: 'name', name: 'Name', type: 'string', required: true },
{ key: 'enabled', name: 'Enabled', type: 'boolean', required: true },
{ key: 'file', name: 'File', type: 'file', required: false },
{ key: 'files', name: 'Files', type: 'file-list', required: false },
],
}
const visionConfig: VisionSettings = {
enabled: true,
number_limits: 2,
detail: Resolution.low,
transfer_methods: [TransferMethod.local_file],
}
describe('result-request', () => {
it('should reject missing required non-boolean inputs', () => {
const t = createTranslator()
const result = validateResultRequest({
completionFiles: [],
inputs: {
enabled: false,
},
isCallBatchAPI: false,
promptConfig,
t,
})
expect(result).toEqual({
canSend: false,
notification: {
type: 'error',
message: 'errorMessage.valueOfVarRequired',
},
})
})
it('should allow required number inputs with a value of zero', () => {
const result = validateResultRequest({
completionFiles: [],
inputs: {
count: 0,
},
isCallBatchAPI: false,
promptConfig: {
prompt_template: 'template',
prompt_variables: [
{ key: 'count', name: 'Count', type: 'number', required: true },
],
},
t: createTranslator(),
})
expect(result).toEqual({ canSend: true })
})
it('should reject required text inputs that only contain whitespace', () => {
const result = validateResultRequest({
completionFiles: [],
inputs: {
name: ' ',
},
isCallBatchAPI: false,
promptConfig: {
prompt_template: 'template',
prompt_variables: [
{ key: 'name', name: 'Name', type: 'string', required: true },
],
},
t: createTranslator(),
})
expect(result).toEqual({
canSend: false,
notification: {
type: 'error',
message: 'errorMessage.valueOfVarRequired',
},
})
})
it('should reject required file lists when no files are selected', () => {
const result = validateResultRequest({
completionFiles: [],
inputs: {
files: [],
},
isCallBatchAPI: false,
promptConfig: {
prompt_template: 'template',
prompt_variables: [
{ key: 'files', name: 'Files', type: 'file-list', required: true },
],
},
t: createTranslator(),
})
expect(result).toEqual({
canSend: false,
notification: {
type: 'error',
message: 'errorMessage.valueOfVarRequired',
},
})
})
it('should allow required file inputs when a file is selected', () => {
const result = validateResultRequest({
completionFiles: [],
inputs: {
file: createFileEntity(),
},
isCallBatchAPI: false,
promptConfig: {
prompt_template: 'template',
prompt_variables: [
{ key: 'file', name: 'File', type: 'file', required: true },
],
},
t: createTranslator(),
})
expect(result).toEqual({ canSend: true })
})
it('should reject pending local uploads outside batch mode', () => {
const t = createTranslator()
const result = validateResultRequest({
completionFiles: [
createVisionFile({ upload_file_id: '' }),
],
inputs: {
name: 'Alice',
},
isCallBatchAPI: false,
promptConfig,
t,
})
expect(result).toEqual({
canSend: false,
notification: {
type: 'info',
message: 'errorMessage.waitForFileUpload',
},
})
})
it('should handle missing prompt metadata with and without pending uploads', () => {
const t = createTranslator()
const blocked = validateResultRequest({
completionFiles: [
createVisionFile({ upload_file_id: '' }),
],
inputs: {},
isCallBatchAPI: false,
promptConfig: null,
t,
})
const allowed = validateResultRequest({
completionFiles: [],
inputs: {},
isCallBatchAPI: false,
promptConfig: null,
t,
})
expect(blocked).toEqual({
canSend: false,
notification: {
type: 'info',
message: 'errorMessage.waitForFileUpload',
},
})
expect(allowed).toEqual({ canSend: true })
})
it('should skip validation in batch mode', () => {
const result = validateResultRequest({
completionFiles: [
createVisionFile({ upload_file_id: '' }),
],
inputs: {},
isCallBatchAPI: true,
promptConfig,
t: createTranslator(),
})
expect(result).toEqual({ canSend: true })
})
it('should build request data for single and list file inputs', () => {
const file = createFileEntity()
const secondFile = createFileEntity({
id: 'file-2',
name: 'second.txt',
uploadedId: 'uploaded-2',
url: 'https://example.com/second.txt',
})
const result = buildResultRequestData({
completionFiles: [
createVisionFile(),
createVisionFile({
transfer_method: TransferMethod.remote_url,
upload_file_id: '',
url: 'https://example.com/remote.png',
}),
],
inputs: {
enabled: true,
file,
files: [file, secondFile],
name: 'Alice',
},
promptConfig,
visionConfig,
})
expect(result).toEqual({
files: [
expect.objectContaining({
transfer_method: TransferMethod.local_file,
upload_file_id: 'upload-1',
url: '',
}),
expect.objectContaining({
transfer_method: TransferMethod.remote_url,
url: 'https://example.com/remote.png',
}),
],
inputs: {
enabled: true,
file: {
type: 'document',
transfer_method: TransferMethod.local_file,
upload_file_id: 'uploaded-1',
url: 'https://example.com/file.txt',
},
files: [
{
type: 'document',
transfer_method: TransferMethod.local_file,
upload_file_id: 'uploaded-1',
url: 'https://example.com/file.txt',
},
{
type: 'document',
transfer_method: TransferMethod.local_file,
upload_file_id: 'uploaded-2',
url: 'https://example.com/second.txt',
},
],
name: 'Alice',
},
})
})
})

View File

@ -0,0 +1,901 @@
import type { WorkflowProcess } from '@/app/components/base/chat/types'
import type { IOtherOptions } from '@/service/base'
import type { HumanInputFormData, HumanInputFormTimeoutData, NodeTracing } from '@/types/workflow'
import { act } from '@testing-library/react'
import { BlockEnum, NodeRunningStatus, WorkflowRunningStatus } from '@/app/components/workflow/types'
import {
appendParallelNext,
appendParallelStart,
appendResultText,
applyWorkflowFinishedState,
applyWorkflowOutputs,
applyWorkflowPaused,
createWorkflowStreamHandlers,
finishParallelTrace,
finishWorkflowNode,
markNodesStopped,
replaceResultText,
updateHumanInputFilled,
updateHumanInputRequired,
updateHumanInputTimeout,
upsertWorkflowNode,
} from '../workflow-stream-handlers'
const sseGetMock = vi.fn()
type TraceOverrides = Omit<Partial<NodeTracing>, 'execution_metadata'> & {
execution_metadata?: Partial<NonNullable<NodeTracing['execution_metadata']>>
}
vi.mock('@/service/base', async () => {
const actual = await vi.importActual<typeof import('@/service/base')>('@/service/base')
return {
...actual,
sseGet: (...args: Parameters<typeof actual.sseGet>) => sseGetMock(...args),
}
})
const createTrace = (overrides: TraceOverrides = {}): NodeTracing => {
const { execution_metadata, ...restOverrides } = overrides
return {
id: 'trace-1',
index: 0,
predecessor_node_id: '',
node_id: 'node-1',
node_type: BlockEnum.LLM,
title: 'Node',
inputs: {},
inputs_truncated: false,
process_data: {},
process_data_truncated: false,
outputs: {},
outputs_truncated: false,
status: NodeRunningStatus.Running,
elapsed_time: 0,
metadata: {
iterator_length: 0,
iterator_index: 0,
loop_length: 0,
loop_index: 0,
},
created_at: 0,
created_by: {
id: 'user-1',
name: 'User',
email: 'user@example.com',
},
finished_at: 0,
details: [[]],
execution_metadata: {
total_tokens: 0,
total_price: 0,
currency: 'USD',
...execution_metadata,
},
...restOverrides,
}
}
const createWorkflowProcess = (): WorkflowProcess => ({
status: WorkflowRunningStatus.Running,
tracing: [],
expand: false,
resultText: '',
})
const createHumanInput = (overrides: Partial<HumanInputFormData> = {}): HumanInputFormData => ({
form_id: 'form-1',
node_id: 'node-1',
node_title: 'Node',
form_content: 'content',
inputs: [],
actions: [],
form_token: 'token-1',
resolved_default_values: {},
display_in_ui: true,
expiration_time: 100,
...overrides,
})
describe('workflow-stream-handlers helpers', () => {
it('should update tracing, result text, and human input state', () => {
const parallelTrace = createTrace({
node_id: 'parallel-node',
execution_metadata: { parallel_id: 'parallel-1' },
details: [[]],
})
let workflowProcessData = appendParallelStart(undefined, parallelTrace)
workflowProcessData = appendParallelNext(workflowProcessData, parallelTrace)
workflowProcessData = finishParallelTrace(workflowProcessData, createTrace({
node_id: 'parallel-node',
execution_metadata: { parallel_id: 'parallel-1' },
error: 'failed',
}))
workflowProcessData = upsertWorkflowNode(workflowProcessData, createTrace({
node_id: 'node-1',
execution_metadata: { parallel_id: 'parallel-2' },
}))!
workflowProcessData = appendResultText(workflowProcessData, 'Hello ')
workflowProcessData = replaceResultText(workflowProcessData, 'Hello world')
workflowProcessData = updateHumanInputRequired(workflowProcessData, createHumanInput())
workflowProcessData = updateHumanInputFilled(workflowProcessData, {
action_id: 'action-1',
action_text: 'Submit',
node_id: 'node-1',
node_title: 'Node',
rendered_content: 'Done',
})
workflowProcessData = updateHumanInputTimeout(workflowProcessData, {
node_id: 'node-1',
node_title: 'Node',
expiration_time: 200,
} satisfies HumanInputFormTimeoutData)
workflowProcessData = applyWorkflowPaused(workflowProcessData)
expect(workflowProcessData.expand).toBe(false)
expect(workflowProcessData.resultText).toBe('Hello world')
expect(workflowProcessData.humanInputFilledFormDataList).toEqual([
expect.objectContaining({
action_text: 'Submit',
}),
])
expect(workflowProcessData.tracing[0]).toEqual(expect.objectContaining({
node_id: 'parallel-node',
expand: true,
}))
})
it('should initialize missing parallel details on start and next events', () => {
const parallelTrace = createTrace({
node_id: 'parallel-node',
execution_metadata: { parallel_id: 'parallel-1' },
})
const startedProcess = appendParallelStart(undefined, parallelTrace)
const nextProcess = appendParallelNext(startedProcess, parallelTrace)
expect(startedProcess.tracing[0]?.details).toEqual([[]])
expect(nextProcess.tracing[0]?.details).toEqual([[], []])
})
it('should leave tracing unchanged when a parallel next event has no matching trace', () => {
const process = createWorkflowProcess()
process.tracing = [
createTrace({
node_id: 'parallel-node',
execution_metadata: { parallel_id: 'parallel-1' },
details: [[]],
}),
]
const nextProcess = appendParallelNext(process, createTrace({
node_id: 'missing-node',
execution_metadata: { parallel_id: 'parallel-2' },
}))
expect(nextProcess.tracing).toEqual(process.tracing)
expect(nextProcess.expand).toBe(true)
})
it('should mark running nodes as stopped recursively', () => {
const workflowProcessData = createWorkflowProcess()
workflowProcessData.tracing = [
createTrace({
status: NodeRunningStatus.Running,
details: [[createTrace({ status: NodeRunningStatus.Waiting })]],
}),
]
const stoppedWorkflow = applyWorkflowFinishedState(workflowProcessData, WorkflowRunningStatus.Stopped)
markNodesStopped(stoppedWorkflow.tracing)
expect(stoppedWorkflow.status).toBe(WorkflowRunningStatus.Stopped)
expect(stoppedWorkflow.tracing[0].status).toBe(NodeRunningStatus.Stopped)
expect(stoppedWorkflow.tracing[0].details?.[0][0].status).toBe(NodeRunningStatus.Stopped)
})
it('should cover unmatched and replacement helper branches', () => {
const process = createWorkflowProcess()
process.tracing = [
createTrace({
node_id: 'node-1',
parallel_id: 'parallel-1',
extras: {
source: 'extra',
},
status: NodeRunningStatus.Succeeded,
}),
]
process.humanInputFormDataList = [
createHumanInput({ node_id: 'node-1' }),
]
process.humanInputFilledFormDataList = [
{
action_id: 'action-0',
action_text: 'Existing',
node_id: 'node-0',
node_title: 'Node 0',
rendered_content: 'Existing',
},
]
const parallelMatched = appendParallelNext(process, createTrace({
node_id: 'node-1',
execution_metadata: {
parallel_id: 'parallel-1',
},
}))
const notFinished = finishParallelTrace(process, createTrace({
node_id: 'missing',
execution_metadata: {
parallel_id: 'parallel-missing',
},
}))
const ignoredIteration = upsertWorkflowNode(process, createTrace({
iteration_id: 'iteration-1',
}))
const replacedNode = upsertWorkflowNode(process, createTrace({
node_id: 'node-1',
}))
const ignoredFinish = finishWorkflowNode(process, createTrace({
loop_id: 'loop-1',
}))
const unmatchedFinish = finishWorkflowNode(process, createTrace({
node_id: 'missing',
execution_metadata: {
parallel_id: 'missing',
},
}))
const finishedWithExtras = finishWorkflowNode(process, createTrace({
node_id: 'node-1',
execution_metadata: {
parallel_id: 'parallel-1',
},
error: 'failed',
}))
const succeededWorkflow = applyWorkflowFinishedState(process, WorkflowRunningStatus.Succeeded)
const outputlessWorkflow = applyWorkflowOutputs(undefined, null)
const updatedHumanInput = updateHumanInputRequired(process, createHumanInput({
node_id: 'node-1',
expiration_time: 300,
}))
const appendedHumanInput = updateHumanInputRequired(process, createHumanInput({
node_id: 'node-2',
}))
const noListFilled = updateHumanInputFilled(undefined, {
action_id: 'action-1',
action_text: 'Submit',
node_id: 'node-1',
node_title: 'Node',
rendered_content: 'Done',
})
const appendedFilled = updateHumanInputFilled(process, {
action_id: 'action-2',
action_text: 'Append',
node_id: 'node-2',
node_title: 'Node 2',
rendered_content: 'More',
})
const timeoutWithoutList = updateHumanInputTimeout(undefined, {
node_id: 'node-1',
node_title: 'Node',
expiration_time: 200,
})
const timeoutWithMatch = updateHumanInputTimeout(process, {
node_id: 'node-1',
node_title: 'Node',
expiration_time: 400,
})
markNodesStopped(undefined)
expect(parallelMatched.tracing[0].details).toHaveLength(2)
expect(notFinished).toEqual(expect.objectContaining({
expand: true,
tracing: process.tracing,
}))
expect(ignoredIteration).toEqual(process)
expect(replacedNode?.tracing[0]).toEqual(expect.objectContaining({
node_id: 'node-1',
status: NodeRunningStatus.Running,
}))
expect(ignoredFinish).toEqual(process)
expect(unmatchedFinish).toEqual(process)
expect(finishedWithExtras?.tracing[0]).toEqual(expect.objectContaining({
extras: {
source: 'extra',
},
error: 'failed',
}))
expect(succeededWorkflow.status).toBe(WorkflowRunningStatus.Succeeded)
expect(outputlessWorkflow.files).toEqual([])
expect(updatedHumanInput.humanInputFormDataList?.[0].expiration_time).toBe(300)
expect(appendedHumanInput.humanInputFormDataList).toHaveLength(2)
expect(noListFilled.humanInputFilledFormDataList).toHaveLength(1)
expect(appendedFilled.humanInputFilledFormDataList).toHaveLength(2)
expect(timeoutWithoutList).toEqual(expect.objectContaining({
status: WorkflowRunningStatus.Running,
tracing: [],
}))
expect(timeoutWithMatch.humanInputFormDataList?.[0].expiration_time).toBe(400)
})
})
describe('createWorkflowStreamHandlers', () => {
beforeEach(() => {
vi.clearAllMocks()
})
const setupHandlers = (overrides: { isTimedOut?: () => boolean } = {}) => {
let completionRes = ''
let currentTaskId: string | null = null
let isStopping = false
let messageId: string | null = null
let workflowProcessData: WorkflowProcess | undefined
const setCurrentTaskId = vi.fn((value: string | null | ((prev: string | null) => string | null)) => {
currentTaskId = typeof value === 'function' ? value(currentTaskId) : value
})
const setIsStopping = vi.fn((value: boolean | ((prev: boolean) => boolean)) => {
isStopping = typeof value === 'function' ? value(isStopping) : value
})
const setMessageId = vi.fn((value: string | null | ((prev: string | null) => string | null)) => {
messageId = typeof value === 'function' ? value(messageId) : value
})
const setWorkflowProcessData = vi.fn((value: WorkflowProcess | undefined) => {
workflowProcessData = value
})
const setCompletionRes = vi.fn((value: string) => {
completionRes = value
})
const notify = vi.fn()
const onCompleted = vi.fn()
const resetRunState = vi.fn()
const setRespondingFalse = vi.fn()
const markEnded = vi.fn()
const handlers = createWorkflowStreamHandlers({
getCompletionRes: () => completionRes,
getWorkflowProcessData: () => workflowProcessData,
isTimedOut: overrides.isTimedOut ?? (() => false),
markEnded,
notify,
onCompleted,
resetRunState,
setCompletionRes,
setCurrentTaskId,
setIsStopping,
setMessageId,
setRespondingFalse,
setWorkflowProcessData,
t: (key: string) => key,
taskId: 3,
})
return {
currentTaskId: () => currentTaskId,
handlers,
isStopping: () => isStopping,
messageId: () => messageId,
notify,
onCompleted,
resetRunState,
setCompletionRes,
setCurrentTaskId,
setMessageId,
setRespondingFalse,
workflowProcessData: () => workflowProcessData,
}
}
it('should process workflow success and paused events', () => {
const setup = setupHandlers()
const handlers = setup.handlers as Required<Pick<IOtherOptions, 'onWorkflowStarted' | 'onTextChunk' | 'onHumanInputRequired' | 'onHumanInputFormFilled' | 'onHumanInputFormTimeout' | 'onWorkflowPaused' | 'onWorkflowFinished' | 'onNodeStarted' | 'onNodeFinished' | 'onIterationStart' | 'onIterationNext' | 'onIterationFinish' | 'onLoopStart' | 'onLoopNext' | 'onLoopFinish'>>
act(() => {
handlers.onWorkflowStarted({
workflow_run_id: 'run-1',
task_id: 'task-1',
event: 'workflow_started',
data: { id: 'run-1', workflow_id: 'wf-1', created_at: 0 },
})
handlers.onNodeStarted({
task_id: 'task-1',
workflow_run_id: 'run-1',
event: 'node_started',
data: createTrace({ node_id: 'node-1' }),
})
handlers.onNodeFinished({
task_id: 'task-1',
workflow_run_id: 'run-1',
event: 'node_finished',
data: createTrace({ node_id: 'node-1', error: '' }),
})
handlers.onIterationStart({
task_id: 'task-1',
workflow_run_id: 'run-1',
event: 'iteration_start',
data: createTrace({
node_id: 'iter-1',
execution_metadata: { parallel_id: 'parallel-1' },
details: [[]],
}),
})
handlers.onIterationNext({
task_id: 'task-1',
workflow_run_id: 'run-1',
event: 'iteration_next',
data: createTrace({
node_id: 'iter-1',
execution_metadata: { parallel_id: 'parallel-1' },
details: [[]],
}),
})
handlers.onIterationFinish({
task_id: 'task-1',
workflow_run_id: 'run-1',
event: 'iteration_finish',
data: createTrace({
node_id: 'iter-1',
execution_metadata: { parallel_id: 'parallel-1' },
}),
})
handlers.onLoopStart({
task_id: 'task-1',
workflow_run_id: 'run-1',
event: 'loop_start',
data: createTrace({
node_id: 'loop-1',
execution_metadata: { parallel_id: 'parallel-2' },
details: [[]],
}),
})
handlers.onLoopNext({
task_id: 'task-1',
workflow_run_id: 'run-1',
event: 'loop_next',
data: createTrace({
node_id: 'loop-1',
execution_metadata: { parallel_id: 'parallel-2' },
details: [[]],
}),
})
handlers.onLoopFinish({
task_id: 'task-1',
workflow_run_id: 'run-1',
event: 'loop_finish',
data: createTrace({
node_id: 'loop-1',
execution_metadata: { parallel_id: 'parallel-2' },
}),
})
handlers.onTextChunk({
task_id: 'task-1',
workflow_run_id: 'run-1',
event: 'text_chunk',
data: { text: 'Hello' },
})
handlers.onHumanInputRequired({
task_id: 'task-1',
workflow_run_id: 'run-1',
event: 'human_input_required',
data: createHumanInput({ node_id: 'node-1' }),
})
handlers.onHumanInputFormFilled({
task_id: 'task-1',
workflow_run_id: 'run-1',
event: 'human_input_form_filled',
data: {
node_id: 'node-1',
node_title: 'Node',
rendered_content: 'Done',
action_id: 'action-1',
action_text: 'Submit',
},
})
handlers.onHumanInputFormTimeout({
task_id: 'task-1',
workflow_run_id: 'run-1',
event: 'human_input_form_timeout',
data: {
node_id: 'node-1',
node_title: 'Node',
expiration_time: 200,
},
})
handlers.onWorkflowPaused({
task_id: 'task-1',
workflow_run_id: 'run-1',
event: 'workflow_paused',
data: {
outputs: {},
paused_nodes: [],
reasons: [],
workflow_run_id: 'run-1',
},
})
handlers.onWorkflowFinished({
task_id: 'task-1',
workflow_run_id: 'run-1',
event: 'workflow_finished',
data: {
id: 'run-1',
workflow_id: 'wf-1',
status: WorkflowRunningStatus.Succeeded,
outputs: { answer: 'Hello' },
error: '',
elapsed_time: 0,
total_tokens: 0,
total_steps: 0,
created_at: 0,
created_by: {
id: 'user-1',
name: 'User',
email: 'user@example.com',
},
finished_at: 0,
},
})
})
expect(setup.currentTaskId()).toBe('task-1')
expect(setup.isStopping()).toBe(false)
expect(setup.workflowProcessData()).toEqual(expect.objectContaining({
resultText: 'Hello',
status: WorkflowRunningStatus.Succeeded,
}))
expect(sseGetMock).toHaveBeenCalledWith('/workflow/run-1/events', {}, expect.any(Object))
expect(setup.messageId()).toBe('run-1')
expect(setup.onCompleted).toHaveBeenCalledWith('{"answer":"Hello"}', 3, true)
expect(setup.setRespondingFalse).toHaveBeenCalled()
expect(setup.resetRunState).toHaveBeenCalled()
})
it('should handle timeout and workflow failures', () => {
const timeoutSetup = setupHandlers({
isTimedOut: () => true,
})
const timeoutHandlers = timeoutSetup.handlers as Required<Pick<IOtherOptions, 'onWorkflowFinished'>>
act(() => {
timeoutHandlers.onWorkflowFinished({
task_id: 'task-1',
workflow_run_id: 'run-1',
event: 'workflow_finished',
data: {
id: 'run-1',
workflow_id: 'wf-1',
status: WorkflowRunningStatus.Succeeded,
outputs: null,
error: '',
elapsed_time: 0,
total_tokens: 0,
total_steps: 0,
created_at: 0,
created_by: {
id: 'user-1',
name: 'User',
email: 'user@example.com',
},
finished_at: 0,
},
})
})
expect(timeoutSetup.notify).toHaveBeenCalledWith({
type: 'warning',
message: 'warningMessage.timeoutExceeded',
})
const failureSetup = setupHandlers()
const failureHandlers = failureSetup.handlers as Required<Pick<IOtherOptions, 'onWorkflowStarted' | 'onWorkflowFinished'>>
act(() => {
failureHandlers.onWorkflowStarted({
workflow_run_id: 'run-2',
task_id: 'task-2',
event: 'workflow_started',
data: { id: 'run-2', workflow_id: 'wf-2', created_at: 0 },
})
failureHandlers.onWorkflowFinished({
task_id: 'task-2',
workflow_run_id: 'run-2',
event: 'workflow_finished',
data: {
id: 'run-2',
workflow_id: 'wf-2',
status: WorkflowRunningStatus.Failed,
outputs: null,
error: 'failed',
elapsed_time: 0,
total_tokens: 0,
total_steps: 0,
created_at: 0,
created_by: {
id: 'user-1',
name: 'User',
email: 'user@example.com',
},
finished_at: 0,
},
})
})
expect(failureSetup.notify).toHaveBeenCalledWith({
type: 'error',
message: 'failed',
})
expect(failureSetup.onCompleted).toHaveBeenCalledWith('', 3, false)
})
it('should cover existing workflow starts, stopped runs, and non-string outputs', () => {
const setup = setupHandlers()
let existingProcess: WorkflowProcess = {
status: WorkflowRunningStatus.Paused,
tracing: [
createTrace({
node_id: 'existing-node',
status: NodeRunningStatus.Waiting,
}),
],
expand: false,
resultText: '',
}
const handlers = createWorkflowStreamHandlers({
getCompletionRes: () => '',
getWorkflowProcessData: () => existingProcess,
isTimedOut: () => false,
markEnded: vi.fn(),
notify: setup.notify,
onCompleted: setup.onCompleted,
resetRunState: setup.resetRunState,
setCompletionRes: setup.setCompletionRes,
setCurrentTaskId: setup.setCurrentTaskId,
setIsStopping: vi.fn(),
setMessageId: setup.setMessageId,
setRespondingFalse: setup.setRespondingFalse,
setWorkflowProcessData: (value) => {
existingProcess = value!
},
t: (key: string) => key,
taskId: 5,
}) as Required<Pick<IOtherOptions, 'onWorkflowStarted' | 'onWorkflowFinished' | 'onTextReplace'>>
act(() => {
handlers.onWorkflowStarted({
workflow_run_id: 'run-existing',
task_id: '',
event: 'workflow_started',
data: { id: 'run-existing', workflow_id: 'wf-1', created_at: 0 },
})
handlers.onTextReplace({
task_id: 'task-existing',
workflow_run_id: 'run-existing',
event: 'text_replace',
data: { text: 'Replaced text' },
})
})
expect(existingProcess).toEqual(expect.objectContaining({
expand: true,
status: WorkflowRunningStatus.Running,
resultText: 'Replaced text',
}))
act(() => {
handlers.onWorkflowFinished({
task_id: 'task-existing',
workflow_run_id: 'run-existing',
event: 'workflow_finished',
data: {
id: 'run-existing',
workflow_id: 'wf-1',
status: WorkflowRunningStatus.Stopped,
outputs: null,
error: '',
elapsed_time: 0,
total_tokens: 0,
total_steps: 0,
created_at: 0,
created_by: {
id: 'user-1',
name: 'User',
email: 'user@example.com',
},
finished_at: 0,
},
})
})
expect(existingProcess.status).toBe(WorkflowRunningStatus.Stopped)
expect(existingProcess.tracing[0].status).toBe(NodeRunningStatus.Stopped)
expect(setup.onCompleted).toHaveBeenCalledWith('', 5, false)
const noOutputSetup = setupHandlers()
const noOutputHandlers = noOutputSetup.handlers as Required<Pick<IOtherOptions, 'onWorkflowStarted' | 'onWorkflowFinished' | 'onTextReplace'>>
act(() => {
noOutputHandlers.onWorkflowStarted({
workflow_run_id: 'run-no-output',
task_id: '',
event: 'workflow_started',
data: { id: 'run-no-output', workflow_id: 'wf-2', created_at: 0 },
})
noOutputHandlers.onTextReplace({
task_id: 'task-no-output',
workflow_run_id: 'run-no-output',
event: 'text_replace',
data: { text: 'Draft' },
})
noOutputHandlers.onWorkflowFinished({
task_id: 'task-no-output',
workflow_run_id: 'run-no-output',
event: 'workflow_finished',
data: {
id: 'run-no-output',
workflow_id: 'wf-2',
status: WorkflowRunningStatus.Succeeded,
outputs: null,
error: '',
elapsed_time: 0,
total_tokens: 0,
total_steps: 0,
created_at: 0,
created_by: {
id: 'user-1',
name: 'User',
email: 'user@example.com',
},
finished_at: 0,
},
})
})
expect(noOutputSetup.setCompletionRes).toHaveBeenCalledWith('')
const objectOutputSetup = setupHandlers()
const objectOutputHandlers = objectOutputSetup.handlers as Required<Pick<IOtherOptions, 'onWorkflowStarted' | 'onWorkflowFinished'>>
act(() => {
objectOutputHandlers.onWorkflowStarted({
workflow_run_id: 'run-object',
task_id: undefined as unknown as string,
event: 'workflow_started',
data: { id: 'run-object', workflow_id: 'wf-3', created_at: 0 },
})
objectOutputHandlers.onWorkflowFinished({
task_id: 'task-object',
workflow_run_id: 'run-object',
event: 'workflow_finished',
data: {
id: 'run-object',
workflow_id: 'wf-3',
status: WorkflowRunningStatus.Succeeded,
outputs: {
answer: 'Hello',
meta: {
mode: 'object',
},
},
error: '',
elapsed_time: 0,
total_tokens: 0,
total_steps: 0,
created_at: 0,
created_by: {
id: 'user-1',
name: 'User',
email: 'user@example.com',
},
finished_at: 0,
},
})
})
expect(objectOutputSetup.currentTaskId()).toBeNull()
expect(objectOutputSetup.setCompletionRes).toHaveBeenCalledWith('{"answer":"Hello","meta":{"mode":"object"}}')
expect(objectOutputSetup.workflowProcessData()).toEqual(expect.objectContaining({
status: WorkflowRunningStatus.Succeeded,
resultText: '',
}))
})
it('should serialize empty, string, and circular workflow outputs', () => {
const noOutputSetup = setupHandlers()
const noOutputHandlers = noOutputSetup.handlers as Required<Pick<IOtherOptions, 'onWorkflowFinished'>>
act(() => {
noOutputHandlers.onWorkflowFinished({
task_id: 'task-empty',
workflow_run_id: 'run-empty',
event: 'workflow_finished',
data: {
id: 'run-empty',
workflow_id: 'wf-empty',
status: WorkflowRunningStatus.Succeeded,
outputs: null,
error: '',
elapsed_time: 0,
total_tokens: 0,
total_steps: 0,
created_at: 0,
created_by: {
id: 'user-1',
name: 'User',
email: 'user@example.com',
},
finished_at: 0,
},
})
})
expect(noOutputSetup.setCompletionRes).toHaveBeenCalledWith('')
const stringOutputSetup = setupHandlers()
const stringOutputHandlers = stringOutputSetup.handlers as Required<Pick<IOtherOptions, 'onWorkflowFinished'>>
act(() => {
stringOutputHandlers.onWorkflowFinished({
task_id: 'task-string',
workflow_run_id: 'run-string',
event: 'workflow_finished',
data: {
id: 'run-string',
workflow_id: 'wf-string',
status: WorkflowRunningStatus.Succeeded,
outputs: 'plain text output',
error: '',
elapsed_time: 0,
total_tokens: 0,
total_steps: 0,
created_at: 0,
created_by: {
id: 'user-1',
name: 'User',
email: 'user@example.com',
},
finished_at: 0,
},
})
})
expect(stringOutputSetup.setCompletionRes).toHaveBeenCalledWith('plain text output')
const circularOutputSetup = setupHandlers()
const circularOutputHandlers = circularOutputSetup.handlers as Required<Pick<IOtherOptions, 'onWorkflowFinished'>>
const circularOutputs: Record<string, unknown> = {
answer: 'Hello',
}
circularOutputs.self = circularOutputs
act(() => {
circularOutputHandlers.onWorkflowFinished({
task_id: 'task-circular',
workflow_run_id: 'run-circular',
event: 'workflow_finished',
data: {
id: 'run-circular',
workflow_id: 'wf-circular',
status: WorkflowRunningStatus.Succeeded,
outputs: circularOutputs,
error: '',
elapsed_time: 0,
total_tokens: 0,
total_steps: 0,
created_at: 0,
created_by: {
id: 'user-1',
name: 'User',
email: 'user@example.com',
},
finished_at: 0,
},
})
})
expect(circularOutputSetup.setCompletionRes).toHaveBeenCalledWith('[object Object]')
})
})

View File

@ -0,0 +1,200 @@
import type { FeedbackType } from '@/app/components/base/chat/chat/type'
import { act, renderHook, waitFor } from '@testing-library/react'
import { AppSourceType } from '@/service/share'
import { useResultRunState } from '../use-result-run-state'
const {
stopChatMessageRespondingMock,
stopWorkflowMessageMock,
updateFeedbackMock,
} = vi.hoisted(() => ({
stopChatMessageRespondingMock: vi.fn(),
stopWorkflowMessageMock: vi.fn(),
updateFeedbackMock: vi.fn(),
}))
vi.mock('@/service/share', async () => {
const actual = await vi.importActual<typeof import('@/service/share')>('@/service/share')
return {
...actual,
stopChatMessageResponding: (...args: Parameters<typeof actual.stopChatMessageResponding>) => stopChatMessageRespondingMock(...args),
stopWorkflowMessage: (...args: Parameters<typeof actual.stopWorkflowMessage>) => stopWorkflowMessageMock(...args),
updateFeedback: (...args: Parameters<typeof actual.updateFeedback>) => updateFeedbackMock(...args),
}
})
describe('useResultRunState', () => {
beforeEach(() => {
vi.clearAllMocks()
stopChatMessageRespondingMock.mockResolvedValue(undefined)
stopWorkflowMessageMock.mockResolvedValue(undefined)
updateFeedbackMock.mockResolvedValue(undefined)
})
it('should expose run control and stop completion requests', async () => {
const notify = vi.fn()
const onRunControlChange = vi.fn()
const { result } = renderHook(() => useResultRunState({
appId: 'app-1',
appSourceType: AppSourceType.webApp,
controlStopResponding: 0,
isWorkflow: false,
notify,
onRunControlChange,
}))
const abort = vi.fn()
act(() => {
result.current.abortControllerRef.current = { abort } as unknown as AbortController
result.current.setCurrentTaskId('task-1')
result.current.setRespondingTrue()
})
await waitFor(() => {
expect(onRunControlChange).toHaveBeenLastCalledWith(expect.objectContaining({
isStopping: false,
}))
})
await act(async () => {
await result.current.handleStop()
})
expect(stopChatMessageRespondingMock).toHaveBeenCalledWith('app-1', 'task-1', AppSourceType.webApp, 'app-1')
expect(abort).toHaveBeenCalledTimes(1)
})
it('should update feedback and react to external stop control', async () => {
const notify = vi.fn()
const onRunControlChange = vi.fn()
const { result, rerender } = renderHook(({ controlStopResponding }) => useResultRunState({
appId: 'app-2',
appSourceType: AppSourceType.installedApp,
controlStopResponding,
isWorkflow: true,
notify,
onRunControlChange,
}), {
initialProps: { controlStopResponding: 0 },
})
const abort = vi.fn()
act(() => {
result.current.abortControllerRef.current = { abort } as unknown as AbortController
result.current.setMessageId('message-1')
})
await act(async () => {
await result.current.handleFeedback({
rating: 'like',
} satisfies FeedbackType)
})
expect(updateFeedbackMock).toHaveBeenCalledWith({
url: '/messages/message-1/feedbacks',
body: {
rating: 'like',
content: undefined,
},
}, AppSourceType.installedApp, 'app-2')
expect(result.current.feedback).toEqual({
rating: 'like',
})
act(() => {
result.current.setCurrentTaskId('task-2')
result.current.setRespondingTrue()
})
rerender({ controlStopResponding: 1 })
await waitFor(() => {
expect(abort).toHaveBeenCalled()
expect(result.current.currentTaskId).toBeNull()
expect(onRunControlChange).toHaveBeenLastCalledWith(null)
})
})
it('should stop workflow requests through the workflow stop API', async () => {
const notify = vi.fn()
const { result } = renderHook(() => useResultRunState({
appId: 'app-3',
appSourceType: AppSourceType.installedApp,
controlStopResponding: 0,
isWorkflow: true,
notify,
}))
act(() => {
result.current.setCurrentTaskId('task-3')
})
await act(async () => {
await result.current.handleStop()
})
expect(stopWorkflowMessageMock).toHaveBeenCalledWith('app-3', 'task-3', AppSourceType.installedApp, 'app-3')
})
it('should ignore invalid stops and report non-Error failures', async () => {
const notify = vi.fn()
stopChatMessageRespondingMock.mockRejectedValueOnce('stop failed')
const { result } = renderHook(() => useResultRunState({
appSourceType: AppSourceType.webApp,
controlStopResponding: 0,
isWorkflow: false,
notify,
}))
await act(async () => {
await result.current.handleStop()
})
expect(stopChatMessageRespondingMock).not.toHaveBeenCalled()
act(() => {
result.current.setCurrentTaskId('task-4')
result.current.setIsStopping(prev => !prev)
result.current.setIsStopping(prev => !prev)
})
await act(async () => {
await result.current.handleStop()
})
expect(stopChatMessageRespondingMock).toHaveBeenCalledWith(undefined, 'task-4', AppSourceType.webApp, '')
expect(notify).toHaveBeenCalledWith({
type: 'error',
message: 'stop failed',
})
expect(result.current.isStopping).toBe(false)
})
it('should report Error instances from workflow stop failures without an app id fallback', async () => {
const notify = vi.fn()
stopWorkflowMessageMock.mockRejectedValueOnce(new Error('workflow stop failed'))
const { result } = renderHook(() => useResultRunState({
appSourceType: AppSourceType.installedApp,
controlStopResponding: 0,
isWorkflow: true,
notify,
}))
act(() => {
result.current.setCurrentTaskId('task-5')
})
await act(async () => {
await result.current.handleStop()
})
expect(stopWorkflowMessageMock).toHaveBeenCalledWith(undefined, 'task-5', AppSourceType.installedApp, '')
expect(notify).toHaveBeenCalledWith({
type: 'error',
message: 'workflow stop failed',
})
})
})

View File

@ -0,0 +1,510 @@
import type { ResultInputValue } from '../../result-request'
import type { ResultRunStateController } from '../use-result-run-state'
import type { PromptConfig } from '@/models/debug'
import type { AppSourceType } from '@/service/share'
import type { VisionSettings } from '@/types/app'
import { act, renderHook, waitFor } from '@testing-library/react'
import { AppSourceType as AppSourceTypeEnum } from '@/service/share'
import { Resolution, TransferMethod } from '@/types/app'
import { useResultSender } from '../use-result-sender'
const {
buildResultRequestDataMock,
createWorkflowStreamHandlersMock,
sendCompletionMessageMock,
sendWorkflowMessageMock,
sleepMock,
validateResultRequestMock,
} = vi.hoisted(() => ({
buildResultRequestDataMock: vi.fn(),
createWorkflowStreamHandlersMock: vi.fn(),
sendCompletionMessageMock: vi.fn(),
sendWorkflowMessageMock: vi.fn(),
sleepMock: vi.fn(),
validateResultRequestMock: vi.fn(),
}))
vi.mock('@/service/share', async () => {
const actual = await vi.importActual<typeof import('@/service/share')>('@/service/share')
return {
...actual,
sendCompletionMessage: (...args: Parameters<typeof actual.sendCompletionMessage>) => sendCompletionMessageMock(...args),
sendWorkflowMessage: (...args: Parameters<typeof actual.sendWorkflowMessage>) => sendWorkflowMessageMock(...args),
}
})
vi.mock('@/utils', async () => {
const actual = await vi.importActual<typeof import('@/utils')>('@/utils')
return {
...actual,
sleep: (...args: Parameters<typeof actual.sleep>) => sleepMock(...args),
}
})
vi.mock('../../result-request', () => ({
buildResultRequestData: (...args: unknown[]) => buildResultRequestDataMock(...args),
validateResultRequest: (...args: unknown[]) => validateResultRequestMock(...args),
}))
vi.mock('../../workflow-stream-handlers', () => ({
createWorkflowStreamHandlers: (...args: unknown[]) => createWorkflowStreamHandlersMock(...args),
}))
type RunStateHarness = {
state: {
completionRes: string
currentTaskId: string | null
messageId: string | null
workflowProcessData: ResultRunStateController['workflowProcessData']
}
runState: ResultRunStateController
}
type CompletionHandlers = {
getAbortController: (abortController: AbortController) => void
onCompleted: () => void
onData: (chunk: string, isFirstMessage: boolean, info: { messageId: string, taskId?: string }) => void
onError: () => void
onMessageReplace: (messageReplace: { answer: string }) => void
}
const createRunStateHarness = (): RunStateHarness => {
const state: RunStateHarness['state'] = {
completionRes: '',
currentTaskId: null,
messageId: null,
workflowProcessData: undefined,
}
const runState: ResultRunStateController = {
abortControllerRef: { current: null },
clearMoreLikeThis: vi.fn(),
completionRes: '',
controlClearMoreLikeThis: 0,
currentTaskId: null,
feedback: { rating: null },
getCompletionRes: vi.fn(() => state.completionRes),
getWorkflowProcessData: vi.fn(() => state.workflowProcessData),
handleFeedback: vi.fn(),
handleStop: vi.fn(),
isResponding: false,
isStopping: false,
messageId: null,
prepareForNewRun: vi.fn(() => {
state.completionRes = ''
state.currentTaskId = null
state.messageId = null
state.workflowProcessData = undefined
runState.completionRes = ''
runState.currentTaskId = null
runState.messageId = null
runState.workflowProcessData = undefined
}),
resetRunState: vi.fn(() => {
state.currentTaskId = null
runState.currentTaskId = null
runState.isStopping = false
}),
setCompletionRes: vi.fn((value: string) => {
state.completionRes = value
runState.completionRes = value
}),
setCurrentTaskId: vi.fn((value) => {
state.currentTaskId = typeof value === 'function' ? value(state.currentTaskId) : value
runState.currentTaskId = state.currentTaskId
}),
setIsStopping: vi.fn((value) => {
runState.isStopping = typeof value === 'function' ? value(runState.isStopping) : value
}),
setMessageId: vi.fn((value) => {
state.messageId = typeof value === 'function' ? value(state.messageId) : value
runState.messageId = state.messageId
}),
setRespondingFalse: vi.fn(() => {
runState.isResponding = false
}),
setRespondingTrue: vi.fn(() => {
runState.isResponding = true
}),
setWorkflowProcessData: vi.fn((value) => {
state.workflowProcessData = value
runState.workflowProcessData = value
}),
workflowProcessData: undefined,
}
return {
state,
runState,
}
}
const promptConfig: PromptConfig = {
prompt_template: 'template',
prompt_variables: [
{ key: 'name', name: 'Name', type: 'string', required: true },
],
}
const visionConfig: VisionSettings = {
enabled: false,
number_limits: 2,
detail: Resolution.low,
transfer_methods: [TransferMethod.local_file],
}
type RenderSenderOptions = {
appSourceType?: AppSourceType
controlRetry?: number
controlSend?: number
inputs?: Record<string, ResultInputValue>
isPC?: boolean
isWorkflow?: boolean
runState?: ResultRunStateController
taskId?: number
}
const renderSender = ({
appSourceType = AppSourceTypeEnum.webApp,
controlRetry = 0,
controlSend = 0,
inputs = { name: 'Alice' },
isPC = true,
isWorkflow = false,
runState,
taskId,
}: RenderSenderOptions = {}) => {
const notify = vi.fn()
const onCompleted = vi.fn()
const onRunStart = vi.fn()
const onShowRes = vi.fn()
const hook = renderHook((props: { controlRetry: number, controlSend: number }) => useResultSender({
appId: 'app-1',
appSourceType,
completionFiles: [],
controlRetry: props.controlRetry,
controlSend: props.controlSend,
inputs,
isCallBatchAPI: false,
isPC,
isWorkflow,
notify,
onCompleted,
onRunStart,
onShowRes,
promptConfig,
runState: runState || createRunStateHarness().runState,
t: (key: string) => key,
taskId,
visionConfig,
}), {
initialProps: {
controlRetry,
controlSend,
},
})
return {
...hook,
notify,
onCompleted,
onRunStart,
onShowRes,
}
}
describe('useResultSender', () => {
beforeEach(() => {
vi.clearAllMocks()
validateResultRequestMock.mockReturnValue({ canSend: true })
buildResultRequestDataMock.mockReturnValue({ inputs: { name: 'Alice' } })
createWorkflowStreamHandlersMock.mockReturnValue({ onWorkflowFinished: vi.fn() })
sendCompletionMessageMock.mockResolvedValue(undefined)
sendWorkflowMessageMock.mockResolvedValue(undefined)
sleepMock.mockImplementation(() => new Promise<void>(() => {}))
})
it('should reject sends while a response is already in progress', async () => {
const { runState } = createRunStateHarness()
runState.isResponding = true
const { result, notify } = renderSender({ runState })
await act(async () => {
expect(await result.current.handleSend()).toBe(false)
})
expect(notify).toHaveBeenCalledWith({
type: 'info',
message: 'errorMessage.waitForResponse',
})
expect(validateResultRequestMock).not.toHaveBeenCalled()
expect(sendCompletionMessageMock).not.toHaveBeenCalled()
})
it('should surface validation failures without building request payloads', async () => {
const { runState } = createRunStateHarness()
validateResultRequestMock.mockReturnValue({
canSend: false,
notification: {
type: 'error',
message: 'invalid',
},
})
const { result, notify } = renderSender({ runState })
await act(async () => {
expect(await result.current.handleSend()).toBe(false)
})
expect(notify).toHaveBeenCalledWith({
type: 'error',
message: 'invalid',
})
expect(buildResultRequestDataMock).not.toHaveBeenCalled()
expect(sendCompletionMessageMock).not.toHaveBeenCalled()
})
it('should send completion requests when controlSend changes and process callbacks', async () => {
const harness = createRunStateHarness()
let completionHandlers: CompletionHandlers | undefined
sendCompletionMessageMock.mockImplementation(async (_data, handlers) => {
completionHandlers = handlers as CompletionHandlers
})
const { rerender, onCompleted, onRunStart, onShowRes } = renderSender({
controlSend: 0,
isPC: false,
runState: harness.runState,
taskId: 7,
})
rerender({
controlRetry: 0,
controlSend: 1,
})
expect(validateResultRequestMock).toHaveBeenCalledWith(expect.objectContaining({
inputs: { name: 'Alice' },
isCallBatchAPI: false,
}))
expect(buildResultRequestDataMock).toHaveBeenCalled()
expect(harness.runState.prepareForNewRun).toHaveBeenCalledTimes(1)
expect(harness.runState.setRespondingTrue).toHaveBeenCalledTimes(1)
expect(harness.runState.clearMoreLikeThis).toHaveBeenCalledTimes(1)
expect(onShowRes).toHaveBeenCalledTimes(1)
expect(onRunStart).toHaveBeenCalledTimes(1)
expect(sendCompletionMessageMock).toHaveBeenCalledWith(
{ inputs: { name: 'Alice' } },
expect.objectContaining({
onCompleted: expect.any(Function),
onData: expect.any(Function),
}),
AppSourceTypeEnum.webApp,
'app-1',
)
const abortController = {} as AbortController
expect(completionHandlers).toBeDefined()
completionHandlers!.getAbortController(abortController)
expect(harness.runState.abortControllerRef.current).toBe(abortController)
await act(async () => {
completionHandlers!.onData('Hello', false, {
messageId: 'message-1',
taskId: 'task-1',
})
})
expect(harness.runState.setCurrentTaskId).toHaveBeenCalled()
expect(harness.runState.currentTaskId).toBe('task-1')
await act(async () => {
completionHandlers!.onMessageReplace({ answer: 'Replaced' })
completionHandlers!.onCompleted()
})
expect(harness.runState.setCompletionRes).toHaveBeenLastCalledWith('Replaced')
expect(harness.runState.setRespondingFalse).toHaveBeenCalled()
expect(harness.runState.resetRunState).toHaveBeenCalled()
expect(harness.runState.setMessageId).toHaveBeenCalledWith('message-1')
expect(onCompleted).toHaveBeenCalledWith('Replaced', 7, true)
})
it('should trigger workflow sends on retry and report workflow request failures', async () => {
const harness = createRunStateHarness()
sendWorkflowMessageMock.mockRejectedValue(new Error('workflow failed'))
const { rerender, notify } = renderSender({
controlRetry: 0,
isWorkflow: true,
runState: harness.runState,
})
rerender({
controlRetry: 2,
controlSend: 0,
})
await waitFor(() => {
expect(createWorkflowStreamHandlersMock).toHaveBeenCalledWith(expect.objectContaining({
getCompletionRes: harness.runState.getCompletionRes,
resetRunState: harness.runState.resetRunState,
setWorkflowProcessData: harness.runState.setWorkflowProcessData,
}))
expect(sendWorkflowMessageMock).toHaveBeenCalledWith(
{ inputs: { name: 'Alice' } },
expect.any(Object),
AppSourceTypeEnum.webApp,
'app-1',
)
})
await waitFor(() => {
expect(harness.runState.setRespondingFalse).toHaveBeenCalled()
expect(harness.runState.resetRunState).toHaveBeenCalled()
expect(notify).toHaveBeenCalledWith({
type: 'error',
message: 'workflow failed',
})
})
expect(harness.runState.clearMoreLikeThis).not.toHaveBeenCalled()
})
it('should stringify non-Error workflow failures', async () => {
const harness = createRunStateHarness()
sendWorkflowMessageMock.mockRejectedValue('workflow failed')
const { result, notify } = renderSender({
isWorkflow: true,
runState: harness.runState,
})
await act(async () => {
await result.current.handleSend()
})
await waitFor(() => {
expect(notify).toHaveBeenCalledWith({
type: 'error',
message: 'workflow failed',
})
})
})
it('should timeout unfinished completion requests', async () => {
const harness = createRunStateHarness()
sleepMock.mockResolvedValue(undefined)
const { result, onCompleted } = renderSender({
runState: harness.runState,
taskId: 9,
})
await act(async () => {
expect(await result.current.handleSend()).toBe(true)
})
await waitFor(() => {
expect(harness.runState.setRespondingFalse).toHaveBeenCalled()
expect(harness.runState.resetRunState).toHaveBeenCalled()
expect(onCompleted).toHaveBeenCalledWith('', 9, false)
})
})
it('should ignore empty task ids and surface timeout warnings from stream callbacks', async () => {
const harness = createRunStateHarness()
let completionHandlers: CompletionHandlers | undefined
sleepMock.mockResolvedValue(undefined)
sendCompletionMessageMock.mockImplementation(async (_data, handlers) => {
completionHandlers = handlers as CompletionHandlers
})
const { result, notify, onCompleted } = renderSender({
runState: harness.runState,
taskId: 11,
})
await act(async () => {
await result.current.handleSend()
})
await act(async () => {
completionHandlers!.onData('Hello', false, {
messageId: 'message-2',
taskId: ' ',
})
completionHandlers!.onCompleted()
completionHandlers!.onError()
})
expect(harness.runState.currentTaskId).toBeNull()
expect(notify).toHaveBeenNthCalledWith(1, {
type: 'warning',
message: 'warningMessage.timeoutExceeded',
})
expect(notify).toHaveBeenNthCalledWith(2, {
type: 'warning',
message: 'warningMessage.timeoutExceeded',
})
expect(onCompleted).toHaveBeenCalledWith('', 11, false)
})
it('should avoid timeout fallback after a completion response has already ended', async () => {
const harness = createRunStateHarness()
let resolveSleep!: () => void
let completionHandlers: CompletionHandlers | undefined
sleepMock.mockImplementation(() => new Promise<void>((resolve) => {
resolveSleep = resolve
}))
sendCompletionMessageMock.mockImplementation(async (_data, handlers) => {
completionHandlers = handlers as CompletionHandlers
})
const { result, onCompleted } = renderSender({
runState: harness.runState,
taskId: 12,
})
await act(async () => {
await result.current.handleSend()
})
await act(async () => {
harness.runState.setCompletionRes('Done')
completionHandlers!.onCompleted()
resolveSleep()
await Promise.resolve()
})
expect(onCompleted).toHaveBeenCalledWith('Done', 12, true)
expect(onCompleted).toHaveBeenCalledTimes(1)
})
it('should handle non-timeout stream errors as failed completions', async () => {
const harness = createRunStateHarness()
let completionHandlers: CompletionHandlers | undefined
sendCompletionMessageMock.mockImplementation(async (_data, handlers) => {
completionHandlers = handlers as CompletionHandlers
})
const { result, onCompleted } = renderSender({
runState: harness.runState,
taskId: 13,
})
await act(async () => {
await result.current.handleSend()
completionHandlers!.onError()
})
expect(harness.runState.setRespondingFalse).toHaveBeenCalled()
expect(harness.runState.resetRunState).toHaveBeenCalled()
expect(onCompleted).toHaveBeenCalledWith('', 13, false)
})
})

View File

@ -0,0 +1,237 @@
import type { Dispatch, MutableRefObject, SetStateAction } from 'react'
import type { FeedbackType } from '@/app/components/base/chat/chat/type'
import type { WorkflowProcess } from '@/app/components/base/chat/types'
import type { AppSourceType } from '@/service/share'
import { useBoolean } from 'ahooks'
import { useCallback, useEffect, useReducer, useRef, useState } from 'react'
import {
stopChatMessageResponding,
stopWorkflowMessage,
updateFeedback,
} from '@/service/share'
type Notify = (payload: { type: 'error', message: string }) => void
type RunControlState = {
currentTaskId: string | null
isStopping: boolean
}
type RunControlAction
= | { type: 'reset' }
| { type: 'setCurrentTaskId', value: SetStateAction<string | null> }
| { type: 'setIsStopping', value: SetStateAction<boolean> }
type UseResultRunStateOptions = {
appId?: string
appSourceType: AppSourceType
controlStopResponding?: number
isWorkflow: boolean
notify: Notify
onRunControlChange?: (control: { onStop: () => Promise<void> | void, isStopping: boolean } | null) => void
}
export type ResultRunStateController = {
abortControllerRef: MutableRefObject<AbortController | null>
clearMoreLikeThis: () => void
completionRes: string
controlClearMoreLikeThis: number
currentTaskId: string | null
feedback: FeedbackType
getCompletionRes: () => string
getWorkflowProcessData: () => WorkflowProcess | undefined
handleFeedback: (feedback: FeedbackType) => Promise<void>
handleStop: () => Promise<void>
isResponding: boolean
isStopping: boolean
messageId: string | null
prepareForNewRun: () => void
resetRunState: () => void
setCompletionRes: (res: string) => void
setCurrentTaskId: Dispatch<SetStateAction<string | null>>
setIsStopping: Dispatch<SetStateAction<boolean>>
setMessageId: Dispatch<SetStateAction<string | null>>
setRespondingFalse: () => void
setRespondingTrue: () => void
setWorkflowProcessData: (data: WorkflowProcess | undefined) => void
workflowProcessData: WorkflowProcess | undefined
}
const runControlReducer = (state: RunControlState, action: RunControlAction): RunControlState => {
switch (action.type) {
case 'reset':
return {
currentTaskId: null,
isStopping: false,
}
case 'setCurrentTaskId':
return {
...state,
currentTaskId: typeof action.value === 'function' ? action.value(state.currentTaskId) : action.value,
}
case 'setIsStopping':
return {
...state,
isStopping: typeof action.value === 'function' ? action.value(state.isStopping) : action.value,
}
}
}
export const useResultRunState = ({
appId,
appSourceType,
controlStopResponding,
isWorkflow,
notify,
onRunControlChange,
}: UseResultRunStateOptions): ResultRunStateController => {
const [isResponding, { setTrue: setRespondingTrue, setFalse: setRespondingFalse }] = useBoolean(false)
const [completionResState, setCompletionResState] = useState<string>('')
const completionResRef = useRef<string>('')
const [workflowProcessDataState, setWorkflowProcessDataState] = useState<WorkflowProcess>()
const workflowProcessDataRef = useRef<WorkflowProcess | undefined>(undefined)
const [messageId, setMessageId] = useState<string | null>(null)
const [feedback, setFeedback] = useState<FeedbackType>({
rating: null,
})
const [controlClearMoreLikeThis, setControlClearMoreLikeThis] = useState(0)
const abortControllerRef = useRef<AbortController | null>(null)
const [{ currentTaskId, isStopping }, dispatchRunControl] = useReducer(runControlReducer, {
currentTaskId: null,
isStopping: false,
})
const setCurrentTaskId = useCallback<Dispatch<SetStateAction<string | null>>>((value) => {
dispatchRunControl({
type: 'setCurrentTaskId',
value,
})
}, [])
const setIsStopping = useCallback<Dispatch<SetStateAction<boolean>>>((value) => {
dispatchRunControl({
type: 'setIsStopping',
value,
})
}, [])
const setCompletionRes = useCallback((res: string) => {
completionResRef.current = res
setCompletionResState(res)
}, [])
const getCompletionRes = useCallback(() => completionResRef.current, [])
const setWorkflowProcessData = useCallback((data: WorkflowProcess | undefined) => {
workflowProcessDataRef.current = data
setWorkflowProcessDataState(data)
}, [])
const getWorkflowProcessData = useCallback(() => workflowProcessDataRef.current, [])
const resetRunState = useCallback(() => {
dispatchRunControl({ type: 'reset' })
abortControllerRef.current = null
onRunControlChange?.(null)
}, [onRunControlChange])
const prepareForNewRun = useCallback(() => {
setMessageId(null)
setFeedback({ rating: null })
setCompletionRes('')
setWorkflowProcessData(undefined)
resetRunState()
}, [resetRunState, setCompletionRes, setWorkflowProcessData])
const handleFeedback = useCallback(async (nextFeedback: FeedbackType) => {
await updateFeedback({
url: `/messages/${messageId}/feedbacks`,
body: {
rating: nextFeedback.rating,
content: nextFeedback.content,
},
}, appSourceType, appId)
setFeedback(nextFeedback)
}, [appId, appSourceType, messageId])
const handleStop = useCallback(async () => {
if (!currentTaskId || isStopping)
return
setIsStopping(true)
try {
if (isWorkflow)
await stopWorkflowMessage(appId!, currentTaskId, appSourceType, appId || '')
else
await stopChatMessageResponding(appId!, currentTaskId, appSourceType, appId || '')
abortControllerRef.current?.abort()
}
catch (error) {
const message = error instanceof Error ? error.message : String(error)
notify({ type: 'error', message })
}
finally {
setIsStopping(false)
}
}, [appId, appSourceType, currentTaskId, isStopping, isWorkflow, notify, setIsStopping])
const clearMoreLikeThis = useCallback(() => {
setControlClearMoreLikeThis(Date.now())
}, [])
useEffect(() => {
const abortCurrentRequest = () => {
abortControllerRef.current?.abort()
}
if (controlStopResponding) {
abortCurrentRequest()
setRespondingFalse()
resetRunState()
}
return abortCurrentRequest
}, [controlStopResponding, resetRunState, setRespondingFalse])
useEffect(() => {
if (!onRunControlChange)
return
if (isResponding && currentTaskId) {
onRunControlChange({
onStop: handleStop,
isStopping,
})
return
}
onRunControlChange(null)
}, [currentTaskId, handleStop, isResponding, isStopping, onRunControlChange])
return {
abortControllerRef,
clearMoreLikeThis,
completionRes: completionResState,
controlClearMoreLikeThis,
currentTaskId,
feedback,
getCompletionRes,
getWorkflowProcessData,
handleFeedback,
handleStop,
isResponding,
isStopping,
messageId,
prepareForNewRun,
resetRunState,
setCompletionRes,
setCurrentTaskId,
setIsStopping,
setMessageId,
setRespondingFalse,
setRespondingTrue,
setWorkflowProcessData,
workflowProcessData: workflowProcessDataState,
}
}

View File

@ -0,0 +1,230 @@
import type { ResultInputValue } from '../result-request'
import type { ResultRunStateController } from './use-result-run-state'
import type { PromptConfig } from '@/models/debug'
import type { AppSourceType } from '@/service/share'
import type { VisionFile, VisionSettings } from '@/types/app'
import { useCallback, useEffect, useRef } from 'react'
import { TEXT_GENERATION_TIMEOUT_MS } from '@/config'
import {
sendCompletionMessage,
sendWorkflowMessage,
} from '@/service/share'
import { sleep } from '@/utils'
import { buildResultRequestData, validateResultRequest } from '../result-request'
import { createWorkflowStreamHandlers } from '../workflow-stream-handlers'
type Notify = (payload: { type: 'error' | 'info' | 'warning', message: string }) => void
type Translate = (key: string, options?: Record<string, unknown>) => string
type UseResultSenderOptions = {
appId?: string
appSourceType: AppSourceType
completionFiles: VisionFile[]
controlRetry?: number
controlSend?: number
inputs: Record<string, ResultInputValue>
isCallBatchAPI: boolean
isPC: boolean
isWorkflow: boolean
notify: Notify
onCompleted: (completionRes: string, taskId?: number, success?: boolean) => void
onRunStart: () => void
onShowRes: () => void
promptConfig: PromptConfig | null
runState: ResultRunStateController
t: Translate
taskId?: number
visionConfig: VisionSettings
}
const logRequestError = (notify: Notify, error: unknown) => {
const message = error instanceof Error ? error.message : String(error)
notify({ type: 'error', message })
}
export const useResultSender = ({
appId,
appSourceType,
completionFiles,
controlRetry,
controlSend,
inputs,
isCallBatchAPI,
isPC,
isWorkflow,
notify,
onCompleted,
onRunStart,
onShowRes,
promptConfig,
runState,
t,
taskId,
visionConfig,
}: UseResultSenderOptions) => {
const { clearMoreLikeThis } = runState
const handleSend = useCallback(async () => {
if (runState.isResponding) {
notify({ type: 'info', message: t('errorMessage.waitForResponse', { ns: 'appDebug' }) })
return false
}
const validation = validateResultRequest({
completionFiles,
inputs,
isCallBatchAPI,
promptConfig,
t,
})
if (!validation.canSend) {
notify(validation.notification!)
return false
}
const data = buildResultRequestData({
completionFiles,
inputs,
promptConfig,
visionConfig,
})
runState.prepareForNewRun()
if (!isPC) {
onShowRes()
onRunStart()
}
runState.setRespondingTrue()
let isEnd = false
let isTimeout = false
let completionChunks: string[] = []
let tempMessageId = ''
void (async () => {
await sleep(TEXT_GENERATION_TIMEOUT_MS)
if (!isEnd) {
runState.setRespondingFalse()
onCompleted(runState.getCompletionRes(), taskId, false)
runState.resetRunState()
isTimeout = true
}
})()
if (isWorkflow) {
const otherOptions = createWorkflowStreamHandlers({
getCompletionRes: runState.getCompletionRes,
getWorkflowProcessData: runState.getWorkflowProcessData,
isTimedOut: () => isTimeout,
markEnded: () => {
isEnd = true
},
notify,
onCompleted,
resetRunState: runState.resetRunState,
setCompletionRes: runState.setCompletionRes,
setCurrentTaskId: runState.setCurrentTaskId,
setIsStopping: runState.setIsStopping,
setMessageId: runState.setMessageId,
setRespondingFalse: runState.setRespondingFalse,
setWorkflowProcessData: runState.setWorkflowProcessData,
t,
taskId,
})
void sendWorkflowMessage(data, otherOptions, appSourceType, appId).catch((error) => {
runState.setRespondingFalse()
runState.resetRunState()
logRequestError(notify, error)
})
return true
}
void sendCompletionMessage(data, {
onData: (chunk, _isFirstMessage, { messageId, taskId: nextTaskId }) => {
tempMessageId = messageId
if (nextTaskId && nextTaskId.trim() !== '')
runState.setCurrentTaskId(prev => prev ?? nextTaskId)
completionChunks.push(chunk)
runState.setCompletionRes(completionChunks.join(''))
},
onCompleted: () => {
if (isTimeout) {
notify({ type: 'warning', message: t('warningMessage.timeoutExceeded', { ns: 'appDebug' }) })
return
}
runState.setRespondingFalse()
runState.resetRunState()
runState.setMessageId(tempMessageId)
onCompleted(runState.getCompletionRes(), taskId, true)
isEnd = true
},
onMessageReplace: (messageReplace) => {
completionChunks = [messageReplace.answer]
runState.setCompletionRes(completionChunks.join(''))
},
onError: () => {
if (isTimeout) {
notify({ type: 'warning', message: t('warningMessage.timeoutExceeded', { ns: 'appDebug' }) })
return
}
runState.setRespondingFalse()
runState.resetRunState()
onCompleted(runState.getCompletionRes(), taskId, false)
isEnd = true
},
getAbortController: (abortController) => {
runState.abortControllerRef.current = abortController
},
}, appSourceType, appId)
return true
}, [
appId,
appSourceType,
completionFiles,
inputs,
isCallBatchAPI,
isPC,
isWorkflow,
notify,
onCompleted,
onRunStart,
onShowRes,
promptConfig,
runState,
t,
taskId,
visionConfig,
])
const handleSendRef = useRef(handleSend)
useEffect(() => {
handleSendRef.current = handleSend
}, [handleSend])
useEffect(() => {
if (!controlSend)
return
void handleSendRef.current()
clearMoreLikeThis()
}, [clearMoreLikeThis, controlSend])
useEffect(() => {
if (!controlRetry)
return
void handleSendRef.current()
}, [controlRetry])
return {
handleSend,
}
}

View File

@ -1,46 +1,18 @@
'use client'
import type { FC } from 'react'
import type { FeedbackType } from '@/app/components/base/chat/chat/type'
import type { WorkflowProcess } from '@/app/components/base/chat/types'
import type { FileEntity } from '@/app/components/base/file-uploader/types'
import type { PromptConfig } from '@/models/debug'
import type { SiteInfo } from '@/models/share'
import type {
IOtherOptions,
} from '@/service/base'
import type { AppSourceType } from '@/service/share'
import type { VisionFile, VisionSettings } from '@/types/app'
import { RiLoader2Line } from '@remixicon/react'
import { useBoolean } from 'ahooks'
import { t } from 'i18next'
import { produce } from 'immer'
import * as React from 'react'
import { useCallback, useEffect, useRef, useState } from 'react'
import TextGenerationRes from '@/app/components/app/text-generate/item'
import Button from '@/app/components/base/button'
import {
getFilesInLogs,
getProcessedFiles,
} from '@/app/components/base/file-uploader/utils'
import { StopCircle } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices'
import Loading from '@/app/components/base/loading'
import Toast from '@/app/components/base/toast'
import NoData from '@/app/components/share/text-generation/no-data'
import { NodeRunningStatus, WorkflowRunningStatus } from '@/app/components/workflow/types'
import { TEXT_GENERATION_TIMEOUT_MS } from '@/config'
import {
sseGet,
} from '@/service/base'
import {
AppSourceType,
sendCompletionMessage,
sendWorkflowMessage,
stopChatMessageResponding,
stopWorkflowMessage,
updateFeedback,
} from '@/service/share'
import { TransferMethod } from '@/types/app'
import { sleep } from '@/utils'
import { formatBooleanInputs } from '@/utils/model-config'
import { useResultRunState } from './hooks/use-result-run-state'
import { useResultSender } from './hooks/use-result-sender'
export type IResultProps = {
isWorkflow: boolean
@ -95,554 +67,52 @@ const Result: FC<IResultProps> = ({
onRunControlChange,
hideInlineStopButton = false,
}) => {
const [isResponding, { setTrue: setRespondingTrue, setFalse: setRespondingFalse }] = useBoolean(false)
const [completionRes, doSetCompletionRes] = useState<string>('')
const completionResRef = useRef<string>('')
const setCompletionRes = (res: string) => {
completionResRef.current = res
doSetCompletionRes(res)
}
const getCompletionRes = () => completionResRef.current
const [workflowProcessData, doSetWorkflowProcessData] = useState<WorkflowProcess>()
const workflowProcessDataRef = useRef<WorkflowProcess | undefined>(undefined)
const setWorkflowProcessData = useCallback((data: WorkflowProcess | undefined) => {
workflowProcessDataRef.current = data
doSetWorkflowProcessData(data)
}, [])
const getWorkflowProcessData = () => workflowProcessDataRef.current
const [currentTaskId, setCurrentTaskId] = useState<string | null>(null)
const [isStopping, setIsStopping] = useState(false)
const abortControllerRef = useRef<AbortController | null>(null)
const resetRunState = useCallback(() => {
setCurrentTaskId(null)
setIsStopping(false)
abortControllerRef.current = null
onRunControlChange?.(null)
}, [onRunControlChange])
useEffect(() => {
const abortCurrentRequest = () => {
abortControllerRef.current?.abort()
}
if (controlStopResponding) {
abortCurrentRequest()
setRespondingFalse()
resetRunState()
}
return abortCurrentRequest
}, [controlStopResponding, resetRunState, setRespondingFalse])
const { notify } = Toast
const isNoData = !completionRes
const [messageId, setMessageId] = useState<string | null>(null)
const [feedback, setFeedback] = useState<FeedbackType>({
rating: null,
const runState = useResultRunState({
appId,
appSourceType,
controlStopResponding,
isWorkflow,
notify,
onRunControlChange,
})
const handleFeedback = async (feedback: FeedbackType) => {
await updateFeedback({ url: `/messages/${messageId}/feedbacks`, body: { rating: feedback.rating, content: feedback.content } }, appSourceType, appId)
setFeedback(feedback)
}
const { handleSend } = useResultSender({
appId,
appSourceType,
completionFiles,
controlRetry,
controlSend,
inputs,
isCallBatchAPI,
isPC,
isWorkflow,
notify,
onCompleted,
onRunStart,
onShowRes,
promptConfig,
runState,
t,
taskId,
visionConfig,
})
const logError = (message: string) => {
notify({ type: 'error', message })
}
const handleStop = useCallback(async () => {
if (!currentTaskId || isStopping)
return
setIsStopping(true)
try {
if (isWorkflow)
await stopWorkflowMessage(appId!, currentTaskId, appSourceType, appId || '')
else
await stopChatMessageResponding(appId!, currentTaskId, appSourceType, appId || '')
abortControllerRef.current?.abort()
}
catch (error) {
const message = error instanceof Error ? error.message : String(error)
notify({ type: 'error', message })
}
finally {
setIsStopping(false)
}
}, [appId, currentTaskId, appSourceType, isStopping, isWorkflow, notify])
useEffect(() => {
if (!onRunControlChange)
return
if (isResponding && currentTaskId) {
onRunControlChange({
onStop: handleStop,
isStopping,
})
}
else {
onRunControlChange(null)
}
}, [currentTaskId, handleStop, isResponding, isStopping, onRunControlChange])
const checkCanSend = () => {
// batch will check outer
if (isCallBatchAPI)
return true
const prompt_variables = promptConfig?.prompt_variables
if (!prompt_variables || prompt_variables?.length === 0) {
if (completionFiles.some(item => item.transfer_method === TransferMethod.local_file && !item.upload_file_id)) {
notify({ type: 'info', message: t('errorMessage.waitForFileUpload', { ns: 'appDebug' }) })
return false
}
return true
}
let hasEmptyInput = ''
const requiredVars = prompt_variables?.filter(({ key, name, required, type }) => {
if (type === 'boolean' || type === 'checkbox')
return false // boolean/checkbox input is not required
const res = (!key || !key.trim()) || (!name || !name.trim()) || (required || required === undefined || required === null)
return res
}) || [] // compatible with old version
requiredVars.forEach(({ key, name }) => {
if (hasEmptyInput)
return
if (!inputs[key])
hasEmptyInput = name
})
if (hasEmptyInput) {
logError(t('errorMessage.valueOfVarRequired', { ns: 'appDebug', key: hasEmptyInput }))
return false
}
if (completionFiles.some(item => item.transfer_method === TransferMethod.local_file && !item.upload_file_id)) {
notify({ type: 'info', message: t('errorMessage.waitForFileUpload', { ns: 'appDebug' }) })
return false
}
return !hasEmptyInput
}
const handleSend = async () => {
if (isResponding) {
notify({ type: 'info', message: t('errorMessage.waitForResponse', { ns: 'appDebug' }) })
return false
}
if (!checkCanSend())
return
// Process inputs: convert file entities to API format
const processedInputs = { ...formatBooleanInputs(promptConfig?.prompt_variables, inputs) }
promptConfig?.prompt_variables.forEach((variable) => {
const value = processedInputs[variable.key]
if (variable.type === 'file' && value && typeof value === 'object' && !Array.isArray(value)) {
// Convert single file entity to API format
processedInputs[variable.key] = getProcessedFiles([value as FileEntity])[0]
}
else if (variable.type === 'file-list' && Array.isArray(value) && value.length > 0) {
// Convert file entity array to API format
processedInputs[variable.key] = getProcessedFiles(value as FileEntity[])
}
})
const data: Record<string, any> = {
inputs: processedInputs,
}
if (visionConfig.enabled && completionFiles && completionFiles?.length > 0) {
data.files = completionFiles.map((item) => {
if (item.transfer_method === TransferMethod.local_file) {
return {
...item,
url: '',
}
}
return item
})
}
setMessageId(null)
setFeedback({
rating: null,
})
setCompletionRes('')
setWorkflowProcessData(undefined)
resetRunState()
let res: string[] = []
let tempMessageId = ''
if (!isPC) {
onShowRes()
onRunStart()
}
setRespondingTrue()
let isEnd = false
let isTimeout = false;
(async () => {
await sleep(TEXT_GENERATION_TIMEOUT_MS)
if (!isEnd) {
setRespondingFalse()
onCompleted(getCompletionRes(), taskId, false)
resetRunState()
isTimeout = true
}
})()
if (isWorkflow) {
const otherOptions: IOtherOptions = {
isPublicAPI: appSourceType === AppSourceType.webApp,
onWorkflowStarted: ({ workflow_run_id, task_id }) => {
const workflowProcessData = getWorkflowProcessData()
if (workflowProcessData && workflowProcessData.tracing.length > 0) {
setWorkflowProcessData(produce(workflowProcessData, (draft) => {
draft.expand = true
draft.status = WorkflowRunningStatus.Running
}))
}
else {
tempMessageId = workflow_run_id
setCurrentTaskId(task_id || null)
setIsStopping(false)
setWorkflowProcessData({
status: WorkflowRunningStatus.Running,
tracing: [],
expand: false,
resultText: '',
})
}
},
onIterationStart: ({ data }) => {
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.expand = true
draft.tracing!.push({
...data,
status: NodeRunningStatus.Running,
expand: true,
})
}))
},
onIterationNext: () => {
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.expand = true
const iterations = draft.tracing.find(item => item.node_id === data.node_id
&& (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))!
iterations?.details!.push([])
}))
},
onIterationFinish: ({ data }) => {
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.expand = true
const iterationsIndex = draft.tracing.findIndex(item => item.node_id === data.node_id
&& (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))!
draft.tracing[iterationsIndex] = {
...data,
expand: !!data.error,
}
}))
},
onLoopStart: ({ data }) => {
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.expand = true
draft.tracing!.push({
...data,
status: NodeRunningStatus.Running,
expand: true,
})
}))
},
onLoopNext: () => {
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.expand = true
const loops = draft.tracing.find(item => item.node_id === data.node_id
&& (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))!
loops?.details!.push([])
}))
},
onLoopFinish: ({ data }) => {
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.expand = true
const loopsIndex = draft.tracing.findIndex(item => item.node_id === data.node_id
&& (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))!
draft.tracing[loopsIndex] = {
...data,
expand: !!data.error,
}
}))
},
onNodeStarted: ({ data }) => {
if (data.iteration_id)
return
if (data.loop_id)
return
const workflowProcessData = getWorkflowProcessData()
setWorkflowProcessData(produce(workflowProcessData!, (draft) => {
if (draft.tracing.length > 0) {
const currentIndex = draft.tracing.findIndex(item => item.node_id === data.node_id)
if (currentIndex > -1) {
draft.expand = true
draft.tracing![currentIndex] = {
...data,
status: NodeRunningStatus.Running,
expand: true,
}
}
else {
draft.expand = true
draft.tracing.push({
...data,
status: NodeRunningStatus.Running,
expand: true,
})
}
}
else {
draft.expand = true
draft.tracing!.push({
...data,
status: NodeRunningStatus.Running,
expand: true,
})
}
}))
},
onNodeFinished: ({ data }) => {
if (data.iteration_id)
return
if (data.loop_id)
return
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
const currentIndex = draft.tracing!.findIndex(trace => trace.node_id === data.node_id
&& (trace.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || trace.parallel_id === data.execution_metadata?.parallel_id))
if (currentIndex > -1 && draft.tracing) {
draft.tracing[currentIndex] = {
...(draft.tracing[currentIndex].extras
? { extras: draft.tracing[currentIndex].extras }
: {}),
...data,
expand: !!data.error,
}
}
}))
},
onWorkflowFinished: ({ data }) => {
if (isTimeout) {
notify({ type: 'warning', message: t('warningMessage.timeoutExceeded', { ns: 'appDebug' }) })
return
}
const workflowStatus = data.status as WorkflowRunningStatus | undefined
const markNodesStopped = (traces?: WorkflowProcess['tracing']) => {
if (!traces)
return
const markTrace = (trace: WorkflowProcess['tracing'][number]) => {
if ([NodeRunningStatus.Running, NodeRunningStatus.Waiting].includes(trace.status as NodeRunningStatus))
trace.status = NodeRunningStatus.Stopped
trace.details?.forEach(detailGroup => detailGroup.forEach(markTrace))
trace.retryDetail?.forEach(markTrace)
trace.parallelDetail?.children?.forEach(markTrace)
}
traces.forEach(markTrace)
}
if (workflowStatus === WorkflowRunningStatus.Stopped) {
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.status = WorkflowRunningStatus.Stopped
markNodesStopped(draft.tracing)
}))
setRespondingFalse()
resetRunState()
onCompleted(getCompletionRes(), taskId, false)
isEnd = true
return
}
if (data.error) {
notify({ type: 'error', message: data.error })
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.status = WorkflowRunningStatus.Failed
markNodesStopped(draft.tracing)
}))
setRespondingFalse()
resetRunState()
onCompleted(getCompletionRes(), taskId, false)
isEnd = true
return
}
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.status = WorkflowRunningStatus.Succeeded
draft.files = getFilesInLogs(data.outputs || []) as any[]
}))
if (!data.outputs) {
setCompletionRes('')
}
else {
setCompletionRes(data.outputs)
const isStringOutput = Object.keys(data.outputs).length === 1 && typeof data.outputs[Object.keys(data.outputs)[0]] === 'string'
if (isStringOutput) {
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.resultText = data.outputs[Object.keys(data.outputs)[0]]
}))
}
}
setRespondingFalse()
resetRunState()
setMessageId(tempMessageId)
onCompleted(getCompletionRes(), taskId, true)
isEnd = true
},
onTextChunk: (params) => {
const { data: { text } } = params
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.resultText += text
}))
},
onTextReplace: (params) => {
const { data: { text } } = params
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.resultText = text
}))
},
onHumanInputRequired: ({ data: humanInputRequiredData }) => {
const workflowProcessData = getWorkflowProcessData()
setWorkflowProcessData(produce(workflowProcessData!, (draft) => {
if (!draft.humanInputFormDataList) {
draft.humanInputFormDataList = [humanInputRequiredData]
}
else {
const currentFormIndex = draft.humanInputFormDataList.findIndex(item => item.node_id === humanInputRequiredData.node_id)
if (currentFormIndex > -1) {
draft.humanInputFormDataList[currentFormIndex] = humanInputRequiredData
}
else {
draft.humanInputFormDataList.push(humanInputRequiredData)
}
}
const currentIndex = draft.tracing!.findIndex(item => item.node_id === humanInputRequiredData.node_id)
if (currentIndex > -1) {
draft.tracing![currentIndex].status = NodeRunningStatus.Paused
}
}))
},
onHumanInputFormFilled: ({ data: humanInputFilledFormData }) => {
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
if (draft.humanInputFormDataList?.length) {
const currentFormIndex = draft.humanInputFormDataList.findIndex(item => item.node_id === humanInputFilledFormData.node_id)
draft.humanInputFormDataList.splice(currentFormIndex, 1)
}
if (!draft.humanInputFilledFormDataList) {
draft.humanInputFilledFormDataList = [humanInputFilledFormData]
}
else {
draft.humanInputFilledFormDataList.push(humanInputFilledFormData)
}
}))
},
onHumanInputFormTimeout: ({ data: humanInputFormTimeoutData }) => {
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
if (draft.humanInputFormDataList?.length) {
const currentFormIndex = draft.humanInputFormDataList.findIndex(item => item.node_id === humanInputFormTimeoutData.node_id)
draft.humanInputFormDataList[currentFormIndex].expiration_time = humanInputFormTimeoutData.expiration_time
}
}))
},
onWorkflowPaused: ({ data: workflowPausedData }) => {
tempMessageId = workflowPausedData.workflow_run_id
const url = `/workflow/${workflowPausedData.workflow_run_id}/events`
sseGet(
url,
{},
otherOptions,
)
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.expand = false
draft.status = WorkflowRunningStatus.Paused
}))
},
}
sendWorkflowMessage(
data,
otherOptions,
appSourceType,
appId,
).catch((error) => {
setRespondingFalse()
resetRunState()
const message = error instanceof Error ? error.message : String(error)
notify({ type: 'error', message })
})
}
else {
sendCompletionMessage(data, {
onData: (data: string, _isFirstMessage: boolean, { messageId, taskId }) => {
tempMessageId = messageId
if (taskId && typeof taskId === 'string' && taskId.trim() !== '')
setCurrentTaskId(prev => prev ?? taskId)
res.push(data)
setCompletionRes(res.join(''))
},
onCompleted: () => {
if (isTimeout) {
notify({ type: 'warning', message: t('warningMessage.timeoutExceeded', { ns: 'appDebug' }) })
return
}
setRespondingFalse()
resetRunState()
setMessageId(tempMessageId)
onCompleted(getCompletionRes(), taskId, true)
isEnd = true
},
onMessageReplace: (messageReplace) => {
res = [messageReplace.answer]
setCompletionRes(res.join(''))
},
onError() {
if (isTimeout) {
notify({ type: 'warning', message: t('warningMessage.timeoutExceeded', { ns: 'appDebug' }) })
return
}
setRespondingFalse()
resetRunState()
onCompleted(getCompletionRes(), taskId, false)
isEnd = true
},
getAbortController: (abortController) => {
abortControllerRef.current = abortController
},
}, appSourceType, appId)
}
}
const [controlClearMoreLikeThis, setControlClearMoreLikeThis] = useState(0)
useEffect(() => {
if (controlSend) {
handleSend()
setControlClearMoreLikeThis(Date.now())
}
}, [controlSend])
useEffect(() => {
if (controlRetry)
handleSend()
}, [controlRetry])
const isNoData = !runState.completionRes
const renderTextGenerationRes = () => (
<>
{!hideInlineStopButton && isResponding && currentTaskId && (
{!hideInlineStopButton && runState.isResponding && runState.currentTaskId && (
<div className={`mb-3 flex ${isPC ? 'justify-end' : 'justify-center'}`}>
<Button
variant="secondary"
disabled={isStopping}
onClick={handleStop}
disabled={runState.isStopping}
onClick={runState.handleStop}
>
{
isStopping
? <RiLoader2Line className="mr-[5px] h-3.5 w-3.5 animate-spin" />
: <StopCircle className="mr-[5px] h-3.5 w-3.5" />
runState.isStopping
? <span aria-hidden className="i-ri-loader-2-line mr-[5px] h-3.5 w-3.5 animate-spin" />
: <span aria-hidden className="i-ri-stop-circle-fill mr-[5px] h-3.5 w-3.5" />
}
<span className="text-xs font-normal">{t('operation.stopResponding', { ns: 'appDebug' })}</span>
</Button>
@ -650,15 +120,15 @@ const Result: FC<IResultProps> = ({
)}
<TextGenerationRes
isWorkflow={isWorkflow}
workflowProcessData={workflowProcessData}
workflowProcessData={runState.workflowProcessData}
isError={isError}
onRetry={handleSend}
content={completionRes}
messageId={messageId}
content={runState.completionRes}
messageId={runState.messageId}
isInWebApp
moreLikeThis={moreLikeThisEnabled}
onFeedback={handleFeedback}
feedback={feedback}
onFeedback={runState.handleFeedback}
feedback={runState.feedback}
onSave={handleSaveMessage}
isMobile={isMobile}
appSourceType={appSourceType}
@ -666,7 +136,7 @@ const Result: FC<IResultProps> = ({
// isLoading={isCallBatchAPI ? (!completionRes && isResponding) : false}
isLoading={false}
taskId={isCallBatchAPI ? ((taskId as number) < 10 ? `0${taskId}` : `${taskId}`) : undefined}
controlClearMoreLikeThis={controlClearMoreLikeThis}
controlClearMoreLikeThis={runState.controlClearMoreLikeThis}
isShowTextToSpeech={isShowTextToSpeech}
hideProcessDetail
siteInfo={siteInfo}
@ -677,7 +147,7 @@ const Result: FC<IResultProps> = ({
return (
<>
{!isCallBatchAPI && !isWorkflow && (
(isResponding && !completionRes)
(runState.isResponding && !runState.completionRes)
? (
<div className="flex h-full w-full items-center justify-center">
<Loading type="area" />
@ -692,13 +162,13 @@ const Result: FC<IResultProps> = ({
)
)}
{!isCallBatchAPI && isWorkflow && (
(isResponding && !workflowProcessData)
(runState.isResponding && !runState.workflowProcessData)
? (
<div className="flex h-full w-full items-center justify-center">
<Loading type="area" />
</div>
)
: !workflowProcessData
: !runState.workflowProcessData
? <NoData />
: renderTextGenerationRes()
)}

View File

@ -0,0 +1,156 @@
import type { FileEntity } from '@/app/components/base/file-uploader/types'
import type { PromptConfig } from '@/models/debug'
import type { VisionFile, VisionSettings } from '@/types/app'
import { getProcessedFiles } from '@/app/components/base/file-uploader/utils'
import { TransferMethod } from '@/types/app'
import { formatBooleanInputs } from '@/utils/model-config'
export type ResultInputValue
= | string
| boolean
| number
| string[]
| Record<string, unknown>
| FileEntity
| FileEntity[]
| undefined
type Translate = (key: string, options?: Record<string, unknown>) => string
type ValidationResult = {
canSend: boolean
notification?: {
type: 'error' | 'info'
message: string
}
}
type ValidateResultRequestParams = {
completionFiles: VisionFile[]
inputs: Record<string, ResultInputValue>
isCallBatchAPI: boolean
promptConfig: PromptConfig | null
t: Translate
}
type BuildResultRequestDataParams = {
completionFiles: VisionFile[]
inputs: Record<string, ResultInputValue>
promptConfig: PromptConfig | null
visionConfig: VisionSettings
}
const isMissingRequiredInput = (
variable: PromptConfig['prompt_variables'][number],
value: ResultInputValue,
) => {
if (value === undefined || value === null)
return true
if (variable.type === 'file-list')
return !Array.isArray(value) || value.length === 0
if (['string', 'paragraph', 'number', 'json_object', 'select'].includes(variable.type))
return typeof value !== 'string' ? false : value.trim() === ''
return false
}
const hasPendingLocalFiles = (completionFiles: VisionFile[]) => {
return completionFiles.some(item => item.transfer_method === TransferMethod.local_file && !item.upload_file_id)
}
export const validateResultRequest = ({
completionFiles,
inputs,
isCallBatchAPI,
promptConfig,
t,
}: ValidateResultRequestParams): ValidationResult => {
if (isCallBatchAPI)
return { canSend: true }
const promptVariables = promptConfig?.prompt_variables
if (!promptVariables?.length) {
if (hasPendingLocalFiles(completionFiles)) {
return {
canSend: false,
notification: {
type: 'info',
message: t('errorMessage.waitForFileUpload', { ns: 'appDebug' }),
},
}
}
return { canSend: true }
}
const requiredVariables = promptVariables.filter(({ key, name, required, type }) => {
if (type === 'boolean' || type === 'checkbox')
return false
return (!key || !key.trim()) || (!name || !name.trim()) || required === undefined || required === null || required
})
const missingRequiredVariable = requiredVariables.find(variable => isMissingRequiredInput(variable, inputs[variable.key]))?.name
if (missingRequiredVariable) {
return {
canSend: false,
notification: {
type: 'error',
message: t('errorMessage.valueOfVarRequired', {
ns: 'appDebug',
key: missingRequiredVariable,
}),
},
}
}
if (hasPendingLocalFiles(completionFiles)) {
return {
canSend: false,
notification: {
type: 'info',
message: t('errorMessage.waitForFileUpload', { ns: 'appDebug' }),
},
}
}
return { canSend: true }
}
export const buildResultRequestData = ({
completionFiles,
inputs,
promptConfig,
visionConfig,
}: BuildResultRequestDataParams) => {
const processedInputs = {
...formatBooleanInputs(promptConfig?.prompt_variables, inputs as Record<string, string | number | boolean | object>),
}
promptConfig?.prompt_variables.forEach((variable) => {
const value = processedInputs[variable.key]
if (variable.type === 'file' && value && typeof value === 'object' && !Array.isArray(value)) {
processedInputs[variable.key] = getProcessedFiles([value as FileEntity])[0]
return
}
if (variable.type === 'file-list' && Array.isArray(value) && value.length > 0)
processedInputs[variable.key] = getProcessedFiles(value as FileEntity[])
})
return {
inputs: processedInputs,
...(visionConfig.enabled && completionFiles.length > 0
? {
files: completionFiles.map((item) => {
if (item.transfer_method === TransferMethod.local_file)
return { ...item, url: '' }
return item
}),
}
: {}),
}
}

View File

@ -0,0 +1,404 @@
import type { Dispatch, SetStateAction } from 'react'
import type { WorkflowProcess } from '@/app/components/base/chat/types'
import type { IOtherOptions } from '@/service/base'
import type { HumanInputFormTimeoutData, NodeTracing, WorkflowFinishedResponse } from '@/types/workflow'
import { produce } from 'immer'
import { getFilesInLogs } from '@/app/components/base/file-uploader/utils'
import { NodeRunningStatus, WorkflowRunningStatus } from '@/app/components/workflow/types'
import { sseGet } from '@/service/base'
type Notify = (payload: { type: 'error' | 'warning', message: string }) => void
type Translate = (key: string, options?: Record<string, unknown>) => string
type CreateWorkflowStreamHandlersParams = {
getCompletionRes: () => string
getWorkflowProcessData: () => WorkflowProcess | undefined
isTimedOut: () => boolean
markEnded: () => void
notify: Notify
onCompleted: (completionRes: string, taskId?: number, success?: boolean) => void
resetRunState: () => void
setCompletionRes: (res: string) => void
setCurrentTaskId: Dispatch<SetStateAction<string | null>>
setIsStopping: Dispatch<SetStateAction<boolean>>
setMessageId: Dispatch<SetStateAction<string | null>>
setRespondingFalse: () => void
setWorkflowProcessData: (data: WorkflowProcess | undefined) => void
t: Translate
taskId?: number
}
const createInitialWorkflowProcess = (): WorkflowProcess => ({
status: WorkflowRunningStatus.Running,
tracing: [],
expand: false,
resultText: '',
})
const updateWorkflowProcess = (
current: WorkflowProcess | undefined,
updater: (draft: WorkflowProcess) => void,
) => {
return produce(current ?? createInitialWorkflowProcess(), updater)
}
const matchParallelTrace = (trace: WorkflowProcess['tracing'][number], data: NodeTracing) => {
return trace.node_id === data.node_id
&& (trace.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id
|| trace.parallel_id === data.execution_metadata?.parallel_id)
}
const ensureParallelTraceDetails = (details?: NodeTracing['details']) => {
return details?.length ? details : [[]]
}
const appendParallelStart = (current: WorkflowProcess | undefined, data: NodeTracing) => {
return updateWorkflowProcess(current, (draft) => {
draft.expand = true
draft.tracing.push({
...data,
details: ensureParallelTraceDetails(data.details),
status: NodeRunningStatus.Running,
expand: true,
})
})
}
const appendParallelNext = (current: WorkflowProcess | undefined, data: NodeTracing) => {
return updateWorkflowProcess(current, (draft) => {
draft.expand = true
const trace = draft.tracing.find(item => matchParallelTrace(item, data))
if (!trace)
return
trace.details = ensureParallelTraceDetails(trace.details)
trace.details.push([])
})
}
const finishParallelTrace = (current: WorkflowProcess | undefined, data: NodeTracing) => {
return updateWorkflowProcess(current, (draft) => {
draft.expand = true
const traceIndex = draft.tracing.findIndex(item => matchParallelTrace(item, data))
if (traceIndex > -1) {
draft.tracing[traceIndex] = {
...data,
expand: !!data.error,
}
}
})
}
const upsertWorkflowNode = (current: WorkflowProcess | undefined, data: NodeTracing) => {
if (data.iteration_id || data.loop_id)
return current
return updateWorkflowProcess(current, (draft) => {
draft.expand = true
const currentIndex = draft.tracing.findIndex(item => item.node_id === data.node_id)
const nextTrace = {
...data,
status: NodeRunningStatus.Running,
expand: true,
}
if (currentIndex > -1)
draft.tracing[currentIndex] = nextTrace
else
draft.tracing.push(nextTrace)
})
}
const finishWorkflowNode = (current: WorkflowProcess | undefined, data: NodeTracing) => {
if (data.iteration_id || data.loop_id)
return current
return updateWorkflowProcess(current, (draft) => {
const currentIndex = draft.tracing.findIndex(trace => matchParallelTrace(trace, data))
if (currentIndex > -1) {
draft.tracing[currentIndex] = {
...(draft.tracing[currentIndex].extras
? { extras: draft.tracing[currentIndex].extras }
: {}),
...data,
expand: !!data.error,
}
}
})
}
const markNodesStopped = (traces?: WorkflowProcess['tracing']) => {
if (!traces)
return
const markTrace = (trace: WorkflowProcess['tracing'][number]) => {
if ([NodeRunningStatus.Running, NodeRunningStatus.Waiting].includes(trace.status as NodeRunningStatus))
trace.status = NodeRunningStatus.Stopped
trace.details?.forEach(detailGroup => detailGroup.forEach(markTrace))
trace.retryDetail?.forEach(markTrace)
trace.parallelDetail?.children?.forEach(markTrace)
}
traces.forEach(markTrace)
}
const applyWorkflowFinishedState = (
current: WorkflowProcess | undefined,
status: WorkflowRunningStatus,
) => {
return updateWorkflowProcess(current, (draft) => {
draft.status = status
if ([WorkflowRunningStatus.Stopped, WorkflowRunningStatus.Failed].includes(status))
markNodesStopped(draft.tracing)
})
}
const applyWorkflowOutputs = (
current: WorkflowProcess | undefined,
outputs: WorkflowFinishedResponse['data']['outputs'],
) => {
return updateWorkflowProcess(current, (draft) => {
draft.status = WorkflowRunningStatus.Succeeded
draft.files = getFilesInLogs(outputs || []) as unknown as WorkflowProcess['files']
})
}
const appendResultText = (current: WorkflowProcess | undefined, text: string) => {
return updateWorkflowProcess(current, (draft) => {
draft.resultText = `${draft.resultText || ''}${text}`
})
}
const replaceResultText = (current: WorkflowProcess | undefined, text: string) => {
return updateWorkflowProcess(current, (draft) => {
draft.resultText = text
})
}
const updateHumanInputRequired = (
current: WorkflowProcess | undefined,
data: NonNullable<WorkflowProcess['humanInputFormDataList']>[number],
) => {
return updateWorkflowProcess(current, (draft) => {
if (!draft.humanInputFormDataList) {
draft.humanInputFormDataList = [data]
}
else {
const currentFormIndex = draft.humanInputFormDataList.findIndex(item => item.node_id === data.node_id)
if (currentFormIndex > -1)
draft.humanInputFormDataList[currentFormIndex] = data
else
draft.humanInputFormDataList.push(data)
}
const currentIndex = draft.tracing.findIndex(item => item.node_id === data.node_id)
if (currentIndex > -1)
draft.tracing[currentIndex].status = NodeRunningStatus.Paused
})
}
const updateHumanInputFilled = (
current: WorkflowProcess | undefined,
data: NonNullable<WorkflowProcess['humanInputFilledFormDataList']>[number],
) => {
return updateWorkflowProcess(current, (draft) => {
if (draft.humanInputFormDataList?.length) {
const currentFormIndex = draft.humanInputFormDataList.findIndex(item => item.node_id === data.node_id)
if (currentFormIndex > -1)
draft.humanInputFormDataList.splice(currentFormIndex, 1)
}
if (!draft.humanInputFilledFormDataList)
draft.humanInputFilledFormDataList = [data]
else
draft.humanInputFilledFormDataList.push(data)
})
}
const updateHumanInputTimeout = (
current: WorkflowProcess | undefined,
data: HumanInputFormTimeoutData,
) => {
return updateWorkflowProcess(current, (draft) => {
if (!draft.humanInputFormDataList?.length)
return
const currentFormIndex = draft.humanInputFormDataList.findIndex(item => item.node_id === data.node_id)
if (currentFormIndex > -1)
draft.humanInputFormDataList[currentFormIndex].expiration_time = data.expiration_time
})
}
const applyWorkflowPaused = (current: WorkflowProcess | undefined) => {
return updateWorkflowProcess(current, (draft) => {
draft.expand = false
draft.status = WorkflowRunningStatus.Paused
})
}
const serializeWorkflowOutputs = (outputs: WorkflowFinishedResponse['data']['outputs']) => {
if (outputs === undefined || outputs === null)
return ''
if (typeof outputs === 'string')
return outputs
try {
return JSON.stringify(outputs) ?? ''
}
catch {
return String(outputs)
}
}
export const createWorkflowStreamHandlers = ({
getCompletionRes,
getWorkflowProcessData,
isTimedOut,
markEnded,
notify,
onCompleted,
resetRunState,
setCompletionRes,
setCurrentTaskId,
setIsStopping,
setMessageId,
setRespondingFalse,
setWorkflowProcessData,
t,
taskId,
}: CreateWorkflowStreamHandlersParams): IOtherOptions => {
let tempMessageId = ''
const finishWithFailure = () => {
setRespondingFalse()
resetRunState()
onCompleted(getCompletionRes(), taskId, false)
markEnded()
}
const finishWithSuccess = () => {
setRespondingFalse()
resetRunState()
setMessageId(tempMessageId)
onCompleted(getCompletionRes(), taskId, true)
markEnded()
}
const otherOptions: IOtherOptions = {
onWorkflowStarted: ({ workflow_run_id, task_id }) => {
const workflowProcessData = getWorkflowProcessData()
if (workflowProcessData?.tracing.length) {
setWorkflowProcessData(updateWorkflowProcess(workflowProcessData, (draft) => {
draft.expand = true
draft.status = WorkflowRunningStatus.Running
}))
return
}
tempMessageId = workflow_run_id
setCurrentTaskId(task_id || null)
setIsStopping(false)
setWorkflowProcessData(createInitialWorkflowProcess())
},
onIterationStart: ({ data }) => {
setWorkflowProcessData(appendParallelStart(getWorkflowProcessData(), data))
},
onIterationNext: ({ data }) => {
setWorkflowProcessData(appendParallelNext(getWorkflowProcessData(), data))
},
onIterationFinish: ({ data }) => {
setWorkflowProcessData(finishParallelTrace(getWorkflowProcessData(), data))
},
onLoopStart: ({ data }) => {
setWorkflowProcessData(appendParallelStart(getWorkflowProcessData(), data))
},
onLoopNext: ({ data }) => {
setWorkflowProcessData(appendParallelNext(getWorkflowProcessData(), data))
},
onLoopFinish: ({ data }) => {
setWorkflowProcessData(finishParallelTrace(getWorkflowProcessData(), data))
},
onNodeStarted: ({ data }) => {
setWorkflowProcessData(upsertWorkflowNode(getWorkflowProcessData(), data))
},
onNodeFinished: ({ data }) => {
setWorkflowProcessData(finishWorkflowNode(getWorkflowProcessData(), data))
},
onWorkflowFinished: ({ data }) => {
if (isTimedOut()) {
notify({ type: 'warning', message: t('warningMessage.timeoutExceeded', { ns: 'appDebug' }) })
return
}
const workflowStatus = data.status as WorkflowRunningStatus | undefined
if (workflowStatus === WorkflowRunningStatus.Stopped) {
setWorkflowProcessData(applyWorkflowFinishedState(getWorkflowProcessData(), WorkflowRunningStatus.Stopped))
finishWithFailure()
return
}
if (data.error) {
notify({ type: 'error', message: data.error })
setWorkflowProcessData(applyWorkflowFinishedState(getWorkflowProcessData(), WorkflowRunningStatus.Failed))
finishWithFailure()
return
}
setWorkflowProcessData(applyWorkflowOutputs(getWorkflowProcessData(), data.outputs))
const serializedOutputs = serializeWorkflowOutputs(data.outputs)
setCompletionRes(serializedOutputs)
if (data.outputs) {
const outputKeys = Object.keys(data.outputs)
const isStringOutput = outputKeys.length === 1 && typeof data.outputs[outputKeys[0]] === 'string'
if (isStringOutput) {
setWorkflowProcessData(updateWorkflowProcess(getWorkflowProcessData(), (draft) => {
draft.resultText = data.outputs[outputKeys[0]]
}))
}
}
finishWithSuccess()
},
onTextChunk: ({ data: { text } }) => {
setWorkflowProcessData(appendResultText(getWorkflowProcessData(), text))
},
onTextReplace: ({ data: { text } }) => {
setWorkflowProcessData(replaceResultText(getWorkflowProcessData(), text))
},
onHumanInputRequired: ({ data }) => {
setWorkflowProcessData(updateHumanInputRequired(getWorkflowProcessData(), data))
},
onHumanInputFormFilled: ({ data }) => {
setWorkflowProcessData(updateHumanInputFilled(getWorkflowProcessData(), data))
},
onHumanInputFormTimeout: ({ data }) => {
setWorkflowProcessData(updateHumanInputTimeout(getWorkflowProcessData(), data))
},
onWorkflowPaused: ({ data }) => {
tempMessageId = data.workflow_run_id
void sseGet(`/workflow/${data.workflow_run_id}/events`, {}, otherOptions)
setWorkflowProcessData(applyWorkflowPaused(getWorkflowProcessData()))
},
}
return otherOptions
}
export {
appendParallelNext,
appendParallelStart,
appendResultText,
applyWorkflowFinishedState,
applyWorkflowOutputs,
applyWorkflowPaused,
finishParallelTrace,
finishWorkflowNode,
markNodesStopped,
replaceResultText,
updateHumanInputFilled,
updateHumanInputRequired,
updateHumanInputTimeout,
upsertWorkflowNode,
}

View File

@ -5646,11 +5646,8 @@
}
},
"app/components/share/text-generation/result/index.tsx": {
"react-hooks-extra/no-direct-set-state-in-use-effect": {
"count": 3
},
"ts/no-explicit-any": {
"count": 3
"count": 1
}
},
"app/components/share/text-generation/run-batch/csv-download/index.tsx": {

View File

@ -0,0 +1,256 @@
import fs from 'node:fs'
import path from 'node:path'
const DIFF_COVERAGE_IGNORE_LINE_TOKEN = 'diff-coverage-ignore-line:'
export function parseChangedLineMap(diff, isTrackedComponentSourceFile) {
const lineMap = new Map()
let currentFile = null
for (const line of diff.split('\n')) {
if (line.startsWith('+++ b/')) {
currentFile = line.slice(6).trim()
continue
}
if (!currentFile || !isTrackedComponentSourceFile(currentFile))
continue
const match = line.match(/^@@ -\d+(?:,\d+)? \+(\d+)(?:,(\d+))? @@/)
if (!match)
continue
const start = Number(match[1])
const count = match[2] ? Number(match[2]) : 1
if (count === 0)
continue
const linesForFile = lineMap.get(currentFile) ?? new Set()
for (let offset = 0; offset < count; offset += 1)
linesForFile.add(start + offset)
lineMap.set(currentFile, linesForFile)
}
return lineMap
}
export function normalizeToRepoRelative(filePath, {
appComponentsCoveragePrefix,
appComponentsPrefix,
repoRoot,
sharedTestPrefix,
webRoot,
}) {
if (!filePath)
return ''
if (filePath.startsWith(appComponentsPrefix) || filePath.startsWith(sharedTestPrefix))
return filePath
if (filePath.startsWith(appComponentsCoveragePrefix))
return `web/${filePath}`
const absolutePath = path.isAbsolute(filePath)
? filePath
: path.resolve(webRoot, filePath)
return path.relative(repoRoot, absolutePath).split(path.sep).join('/')
}
export function getLineHits(entry) {
if (entry?.l && Object.keys(entry.l).length > 0)
return entry.l
const lineHits = {}
for (const [statementId, statement] of Object.entries(entry?.statementMap ?? {})) {
const line = statement?.start?.line
if (!line)
continue
const hits = entry?.s?.[statementId] ?? 0
const previous = lineHits[line]
lineHits[line] = previous === undefined ? hits : Math.max(previous, hits)
}
return lineHits
}
export function getChangedStatementCoverage(entry, changedLines) {
const normalizedChangedLines = [...(changedLines ?? [])].sort((a, b) => a - b)
if (!entry) {
return {
covered: 0,
total: normalizedChangedLines.length,
uncoveredLines: normalizedChangedLines,
}
}
const uncoveredLines = []
let covered = 0
let total = 0
for (const [statementId, statement] of Object.entries(entry.statementMap ?? {})) {
if (!rangeIntersectsChangedLines(statement, changedLines))
continue
total += 1
const hits = entry.s?.[statementId] ?? 0
if (hits > 0) {
covered += 1
continue
}
uncoveredLines.push(statement.start.line)
}
return {
covered,
total,
uncoveredLines: uncoveredLines.sort((a, b) => a - b),
}
}
export function getChangedBranchCoverage(entry, changedLines) {
if (!entry) {
return {
covered: 0,
total: 0,
uncoveredBranches: [],
}
}
const uncoveredBranches = []
let covered = 0
let total = 0
for (const [branchId, branch] of Object.entries(entry.branchMap ?? {})) {
if (!branchIntersectsChangedLines(branch, changedLines))
continue
const hits = Array.isArray(entry.b?.[branchId]) ? entry.b[branchId] : []
const locations = getBranchLocations(branch)
const armCount = Math.max(locations.length, hits.length)
for (let armIndex = 0; armIndex < armCount; armIndex += 1) {
total += 1
if ((hits[armIndex] ?? 0) > 0) {
covered += 1
continue
}
const location = locations[armIndex] ?? branch.loc ?? branch
uncoveredBranches.push({
armIndex,
line: getLocationStartLine(location) ?? branch.line ?? 1,
})
}
}
uncoveredBranches.sort((a, b) => a.line - b.line || a.armIndex - b.armIndex)
return {
covered,
total,
uncoveredBranches,
}
}
export function getIgnoredChangedLinesFromFile(filePath, changedLines) {
if (!fs.existsSync(filePath))
return emptyIgnoreResult(changedLines)
const sourceCode = fs.readFileSync(filePath, 'utf8')
return getIgnoredChangedLinesFromSource(sourceCode, changedLines)
}
export function getIgnoredChangedLinesFromSource(sourceCode, changedLines) {
const ignoredLines = new Map()
const invalidPragmas = []
const changedLineSet = new Set(changedLines ?? [])
const sourceLines = sourceCode.split('\n')
sourceLines.forEach((lineText, index) => {
const lineNumber = index + 1
const commentIndex = lineText.indexOf('//')
if (commentIndex < 0)
return
const tokenIndex = lineText.indexOf(DIFF_COVERAGE_IGNORE_LINE_TOKEN, commentIndex + 2)
if (tokenIndex < 0)
return
const reason = lineText.slice(tokenIndex + DIFF_COVERAGE_IGNORE_LINE_TOKEN.length).trim()
if (!changedLineSet.has(lineNumber))
return
if (!reason) {
invalidPragmas.push({
line: lineNumber,
reason: 'missing ignore reason',
})
return
}
ignoredLines.set(lineNumber, reason)
})
const effectiveChangedLines = new Set(
[...changedLineSet].filter(lineNumber => !ignoredLines.has(lineNumber)),
)
return {
effectiveChangedLines,
ignoredLines,
invalidPragmas,
}
}
function emptyIgnoreResult(changedLines = []) {
return {
effectiveChangedLines: new Set(changedLines),
ignoredLines: new Map(),
invalidPragmas: [],
}
}
function branchIntersectsChangedLines(branch, changedLines) {
if (!changedLines || changedLines.size === 0)
return false
if (rangeIntersectsChangedLines(branch.loc, changedLines))
return true
const locations = getBranchLocations(branch)
if (locations.some(location => rangeIntersectsChangedLines(location, changedLines)))
return true
return branch.line ? changedLines.has(branch.line) : false
}
function getBranchLocations(branch) {
return Array.isArray(branch?.locations) ? branch.locations.filter(Boolean) : []
}
function rangeIntersectsChangedLines(location, changedLines) {
if (!location || !changedLines || changedLines.size === 0)
return false
const startLine = getLocationStartLine(location)
const endLine = getLocationEndLine(location) ?? startLine
if (!startLine || !endLine)
return false
for (const lineNumber of changedLines) {
if (lineNumber >= startLine && lineNumber <= endLine)
return true
}
return false
}
function getLocationStartLine(location) {
return location?.start?.line ?? location?.line ?? null
}
function getLocationEndLine(location) {
return location?.end?.line ?? location?.line ?? null
}

View File

@ -1,6 +1,14 @@
import { execFileSync } from 'node:child_process'
import fs from 'node:fs'
import path from 'node:path'
import {
getChangedBranchCoverage,
getChangedStatementCoverage,
getIgnoredChangedLinesFromFile,
getLineHits,
normalizeToRepoRelative,
parseChangedLineMap,
} from './check-components-diff-coverage-lib.mjs'
import {
collectComponentCoverageExcludedFiles,
COMPONENT_COVERAGE_EXCLUDE_LABEL,
@ -54,7 +62,13 @@ if (changedSourceFiles.length === 0) {
const coverageEntries = new Map()
for (const [file, entry] of Object.entries(coverage)) {
const repoRelativePath = normalizeToRepoRelative(entry.path ?? file)
const repoRelativePath = normalizeToRepoRelative(entry.path ?? file, {
appComponentsCoveragePrefix: APP_COMPONENTS_COVERAGE_PREFIX,
appComponentsPrefix: APP_COMPONENTS_PREFIX,
repoRoot,
sharedTestPrefix: SHARED_TEST_PREFIX,
webRoot,
})
if (!isTrackedComponentSourceFile(repoRelativePath))
continue
@ -74,46 +88,53 @@ for (const [file, entry] of coverageEntries.entries()) {
const overallCoverage = sumCoverageStats(fileCoverageRows)
const diffChanges = getChangedLineMap(baseSha, headSha)
const diffRows = []
const ignoredDiffLines = []
const invalidIgnorePragmas = []
for (const [file, changedLines] of diffChanges.entries()) {
if (!isTrackedComponentSourceFile(file))
continue
const entry = coverageEntries.get(file)
const lineHits = entry ? getLineHits(entry) : {}
const executableChangedLines = [...changedLines]
.filter(line => !entry || lineHits[line] !== undefined)
.sort((a, b) => a - b)
if (executableChangedLines.length === 0) {
diffRows.push({
const ignoreInfo = getIgnoredChangedLinesFromFile(path.join(repoRoot, file), changedLines)
for (const [line, reason] of ignoreInfo.ignoredLines.entries()) {
ignoredDiffLines.push({
file,
moduleName: getModuleName(file),
total: 0,
covered: 0,
uncoveredLines: [],
line,
reason,
})
}
for (const invalidPragma of ignoreInfo.invalidPragmas) {
invalidIgnorePragmas.push({
file,
...invalidPragma,
})
continue
}
const uncoveredLines = executableChangedLines.filter(line => (lineHits[line] ?? 0) === 0)
const statements = getChangedStatementCoverage(entry, ignoreInfo.effectiveChangedLines)
const branches = getChangedBranchCoverage(entry, ignoreInfo.effectiveChangedLines)
diffRows.push({
branches,
file,
ignoredLineCount: ignoreInfo.ignoredLines.size,
moduleName: getModuleName(file),
total: executableChangedLines.length,
covered: executableChangedLines.length - uncoveredLines.length,
uncoveredLines,
statements,
})
}
const diffTotals = diffRows.reduce((acc, row) => {
acc.total += row.total
acc.covered += row.covered
acc.statements.total += row.statements.total
acc.statements.covered += row.statements.covered
acc.branches.total += row.branches.total
acc.branches.covered += row.branches.covered
return acc
}, { total: 0, covered: 0 })
}, {
branches: { total: 0, covered: 0 },
statements: { total: 0, covered: 0 },
})
const diffCoveragePct = percentage(diffTotals.covered, diffTotals.total)
const diffFailures = diffRows.filter(row => row.uncoveredLines.length > 0)
const diffStatementFailures = diffRows.filter(row => row.statements.uncoveredLines.length > 0)
const diffBranchFailures = diffRows.filter(row => row.branches.uncoveredBranches.length > 0)
const overallThresholdFailures = getThresholdFailures(overallCoverage, COMPONENTS_GLOBAL_THRESHOLDS)
const moduleCoverageRows = [...moduleCoverageMap.entries()]
.map(([moduleName, stats]) => ({
@ -139,25 +160,38 @@ appendSummary(buildSummary({
overallThresholdFailures,
moduleCoverageRows,
moduleThresholdFailures,
diffBranchFailures,
diffRows,
diffFailures,
diffCoveragePct,
diffStatementFailures,
diffTotals,
changedSourceFiles,
changedTestFiles,
ignoredDiffLines,
invalidIgnorePragmas,
missingTestTouch,
}))
if (diffFailures.length > 0 && process.env.CI) {
for (const failure of diffFailures.slice(0, 20)) {
const firstLine = failure.uncoveredLines[0] ?? 1
console.log(`::error file=${failure.file},line=${firstLine}::Uncovered changed lines: ${formatLineRanges(failure.uncoveredLines)}`)
if (process.env.CI) {
for (const failure of diffStatementFailures.slice(0, 20)) {
const firstLine = failure.statements.uncoveredLines[0] ?? 1
console.log(`::error file=${failure.file},line=${firstLine}::Uncovered changed statements: ${formatLineRanges(failure.statements.uncoveredLines)}`)
}
for (const failure of diffBranchFailures.slice(0, 20)) {
const firstBranch = failure.branches.uncoveredBranches[0]
const line = firstBranch?.line ?? 1
console.log(`::error file=${failure.file},line=${line}::Uncovered changed branches: ${formatBranchRefs(failure.branches.uncoveredBranches)}`)
}
for (const invalidPragma of invalidIgnorePragmas.slice(0, 20)) {
console.log(`::error file=${invalidPragma.file},line=${invalidPragma.line}::Invalid diff coverage ignore pragma: ${invalidPragma.reason}`)
}
}
if (
overallThresholdFailures.length > 0
|| moduleThresholdFailures.length > 0
|| diffFailures.length > 0
|| diffStatementFailures.length > 0
|| diffBranchFailures.length > 0
|| invalidIgnorePragmas.length > 0
|| (STRICT_TEST_FILE_TOUCH && missingTestTouch)
) {
process.exit(1)
@ -168,11 +202,14 @@ function buildSummary({
overallThresholdFailures,
moduleCoverageRows,
moduleThresholdFailures,
diffBranchFailures,
diffRows,
diffFailures,
diffCoveragePct,
diffStatementFailures,
diffTotals,
changedSourceFiles,
changedTestFiles,
ignoredDiffLines,
invalidIgnorePragmas,
missingTestTouch,
}) {
const lines = [
@ -189,7 +226,8 @@ function buildSummary({
`| Overall tracked statements | ${formatPercent(overallCoverage.statements)} | ${overallCoverage.statements.covered}/${overallCoverage.statements.total}; threshold ${COMPONENTS_GLOBAL_THRESHOLDS.statements}% |`,
`| Overall tracked functions | ${formatPercent(overallCoverage.functions)} | ${overallCoverage.functions.covered}/${overallCoverage.functions.total}; threshold ${COMPONENTS_GLOBAL_THRESHOLDS.functions}% |`,
`| Overall tracked branches | ${formatPercent(overallCoverage.branches)} | ${overallCoverage.branches.covered}/${overallCoverage.branches.total}; threshold ${COMPONENTS_GLOBAL_THRESHOLDS.branches}% |`,
`| Changed executable lines | ${formatPercent({ covered: diffTotals.covered, total: diffTotals.total })} | ${diffTotals.covered}/${diffTotals.total} |`,
`| Changed statements | ${formatDiffPercent(diffTotals.statements)} | ${diffTotals.statements.covered}/${diffTotals.statements.total} |`,
`| Changed branches | ${formatDiffPercent(diffTotals.branches)} | ${diffTotals.branches.covered}/${diffTotals.branches.total} |`,
'',
]
@ -239,20 +277,19 @@ function buildSummary({
lines.push('')
const changedRows = diffRows
.filter(row => row.total > 0)
.filter(row => row.statements.total > 0 || row.branches.total > 0)
.sort((a, b) => {
const aPct = percentage(rowCovered(a), rowTotal(a))
const bPct = percentage(rowCovered(b), rowTotal(b))
return aPct - bPct || a.file.localeCompare(b.file)
const aScore = percentage(a.statements.covered + a.branches.covered, a.statements.total + a.branches.total)
const bScore = percentage(b.statements.covered + b.branches.covered, b.statements.total + b.branches.total)
return aScore - bScore || a.file.localeCompare(b.file)
})
lines.push('<details><summary>Changed file coverage</summary>')
lines.push('')
lines.push('| File | Module | Changed executable lines | Coverage | Uncovered lines |')
lines.push('|---|---|---:|---:|---|')
lines.push('| File | Module | Changed statements | Statement coverage | Uncovered statements | Changed branches | Branch coverage | Uncovered branches | Ignored lines |')
lines.push('|---|---|---:|---:|---|---:|---:|---|---:|')
for (const row of changedRows) {
const rowPct = percentage(row.covered, row.total)
lines.push(`| ${row.file.replace('web/', '')} | ${row.moduleName} | ${row.total} | ${rowPct.toFixed(2)}% | ${formatLineRanges(row.uncoveredLines)} |`)
lines.push(`| ${row.file.replace('web/', '')} | ${row.moduleName} | ${row.statements.total} | ${formatDiffPercent(row.statements)} | ${formatLineRanges(row.statements.uncoveredLines)} | ${row.branches.total} | ${formatDiffPercent(row.branches)} | ${formatBranchRefs(row.branches.uncoveredBranches)} | ${row.ignoredLineCount} |`)
}
lines.push('</details>')
lines.push('')
@ -268,16 +305,41 @@ function buildSummary({
lines.push('')
}
if (diffFailures.length > 0) {
lines.push('Uncovered changed lines:')
for (const row of diffFailures) {
lines.push(`- ${row.file.replace('web/', '')}: ${formatLineRanges(row.uncoveredLines)}`)
if (diffStatementFailures.length > 0) {
lines.push('Uncovered changed statements:')
for (const row of diffStatementFailures) {
lines.push(`- ${row.file.replace('web/', '')}: ${formatLineRanges(row.statements.uncoveredLines)}`)
}
lines.push('')
}
if (diffBranchFailures.length > 0) {
lines.push('Uncovered changed branches:')
for (const row of diffBranchFailures) {
lines.push(`- ${row.file.replace('web/', '')}: ${formatBranchRefs(row.branches.uncoveredBranches)}`)
}
lines.push('')
}
if (ignoredDiffLines.length > 0) {
lines.push('Ignored changed lines via pragma:')
for (const ignoredLine of ignoredDiffLines) {
lines.push(`- ${ignoredLine.file.replace('web/', '')}:${ignoredLine.line} - ${ignoredLine.reason}`)
}
lines.push('')
}
if (invalidIgnorePragmas.length > 0) {
lines.push('Invalid diff coverage ignore pragmas:')
for (const invalidPragma of invalidIgnorePragmas) {
lines.push(`- ${invalidPragma.file.replace('web/', '')}:${invalidPragma.line} - ${invalidPragma.reason}`)
}
lines.push('')
}
lines.push(`Changed source files checked: ${changedSourceFiles.length}`)
lines.push(`Changed executable line coverage: ${diffCoveragePct.toFixed(2)}%`)
lines.push(`Changed statement coverage: ${percentage(diffTotals.statements.covered, diffTotals.statements.total).toFixed(2)}%`)
lines.push(`Changed branch coverage: ${percentage(diffTotals.branches.covered, diffTotals.branches.total).toFixed(2)}%`)
return lines
}
@ -312,34 +374,7 @@ function getChangedFiles(base, head) {
function getChangedLineMap(base, head) {
const diff = execGit(['diff', '--unified=0', '--no-color', '--diff-filter=ACMR', `${base}...${head}`, '--', 'web/app/components'])
const lineMap = new Map()
let currentFile = null
for (const line of diff.split('\n')) {
if (line.startsWith('+++ b/')) {
currentFile = line.slice(6).trim()
continue
}
if (!currentFile || !isTrackedComponentSourceFile(currentFile))
continue
const match = line.match(/^@@ -\d+(?:,\d+)? \+(\d+)(?:,(\d+))? @@/)
if (!match)
continue
const start = Number(match[1])
const count = match[2] ? Number(match[2]) : 1
if (count === 0)
continue
const linesForFile = lineMap.get(currentFile) ?? new Set()
for (let offset = 0; offset < count; offset += 1)
linesForFile.add(start + offset)
lineMap.set(currentFile, linesForFile)
}
return lineMap
return parseChangedLineMap(diff, isTrackedComponentSourceFile)
}
function isAnyComponentSourceFile(filePath) {
@ -407,24 +442,6 @@ function getCoverageStats(entry) {
}
}
function getLineHits(entry) {
if (entry.l && Object.keys(entry.l).length > 0)
return entry.l
const lineHits = {}
for (const [statementId, statement] of Object.entries(entry.statementMap ?? {})) {
const line = statement?.start?.line
if (!line)
continue
const hits = entry.s?.[statementId] ?? 0
const previous = lineHits[line]
lineHits[line] = previous === undefined ? hits : Math.max(previous, hits)
}
return lineHits
}
function sumCoverageStats(rows) {
const total = createEmptyCoverageStats()
for (const row of rows)
@ -479,23 +496,6 @@ function getModuleName(filePath) {
return segments.length === 1 ? '(root)' : segments[0]
}
function normalizeToRepoRelative(filePath) {
if (!filePath)
return ''
if (filePath.startsWith(APP_COMPONENTS_PREFIX) || filePath.startsWith(SHARED_TEST_PREFIX))
return filePath
if (filePath.startsWith(APP_COMPONENTS_COVERAGE_PREFIX))
return `web/${filePath}`
const absolutePath = path.isAbsolute(filePath)
? filePath
: path.resolve(webRoot, filePath)
return path.relative(repoRoot, absolutePath).split(path.sep).join('/')
}
function formatLineRanges(lines) {
if (!lines || lines.length === 0)
return ''
@ -520,6 +520,13 @@ function formatLineRanges(lines) {
return ranges.join(', ')
}
function formatBranchRefs(branches) {
if (!branches || branches.length === 0)
return ''
return branches.map(branch => `${branch.line}[${branch.armIndex}]`).join(', ')
}
function percentage(covered, total) {
if (total === 0)
return 100
@ -530,6 +537,13 @@ function formatPercent(metric) {
return `${percentage(metric.covered, metric.total).toFixed(2)}%`
}
function formatDiffPercent(metric) {
if (metric.total === 0)
return 'n/a'
return `${percentage(metric.covered, metric.total).toFixed(2)}%`
}
function appendSummary(lines) {
const content = `${lines.join('\n')}\n`
if (process.env.GITHUB_STEP_SUMMARY)
@ -550,11 +564,3 @@ function repoRootFromCwd() {
encoding: 'utf8',
}).trim()
}
function rowCovered(row) {
return row.covered
}
function rowTotal(row) {
return row.total
}

View File

@ -92,10 +92,10 @@ export const COMPONENT_MODULE_THRESHOLDS = {
branches: 90,
},
'share': {
lines: 15,
statements: 15,
functions: 20,
branches: 20,
lines: 95,
statements: 95,
functions: 95,
branches: 95,
},
'signin': {
lines: 95,