diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 4571fd1cd1..4a8c61e7d2 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -79,6 +79,29 @@ jobs: find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \; find . -name "*.py.bak" -type f -delete + - name: Install pnpm + uses: pnpm/action-setup@v4 + with: + package_json_file: web/package.json + run_install: false + + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: 24 + cache: pnpm + cache-dependency-path: ./web/pnpm-lock.yaml + + - name: Install web dependencies + run: | + cd web + pnpm install --frozen-lockfile + + - name: ESLint autofix + run: | + cd web + pnpm lint:fix || true + # mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter. - name: mdformat run: | diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 5551030f1e..fdc05d1d65 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -125,7 +125,7 @@ jobs: - name: Web type check if: steps.changed-files.outputs.any_changed == 'true' working-directory: ./web - run: pnpm run type-check:tsgo + run: pnpm run type-check - name: Web dead code check if: steps.changed-files.outputs.any_changed == 'true' diff --git a/api/.importlinter b/api/.importlinter index 2dec958788..b676e97591 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -27,7 +27,9 @@ ignore_imports = core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events core.workflow.nodes.loop.loop_node -> core.workflow.graph_events - core.workflow.nodes.node_factory -> core.workflow.graph + core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory + core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory + core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine core.workflow.nodes.iteration.iteration_node -> core.workflow.graph core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels @@ -57,6 +59,252 @@ ignore_imports = core.workflow.graph_engine.manager -> extensions.ext_redis core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis +[importlinter:contract:workflow-external-imports] +name = Workflow External Imports +type = forbidden +source_modules = + core.workflow +forbidden_modules = + configs + controllers + extensions + models + services + tasks + core.agent + core.app + core.base + core.callback_handler + core.datasource + core.db + core.entities + core.errors + core.extension + core.external_data_tool + core.file + core.helper + core.hosting_configuration + core.indexing_runner + core.llm_generator + core.logging + core.mcp + core.memory + core.model_manager + core.moderation + core.ops + core.plugin + core.prompt + core.provider_manager + core.rag + core.repositories + core.schemas + core.tools + core.trigger + core.variables +ignore_imports = + core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory + core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis + core.workflow.graph_engine.layers.observability -> configs + core.workflow.graph_engine.layers.observability -> extensions.otel.runtime + core.workflow.graph_engine.layers.persistence -> core.ops.ops_trace_manager + core.workflow.graph_engine.worker_management.worker_pool -> configs + 
core.workflow.nodes.agent.agent_node -> core.model_manager + core.workflow.nodes.agent.agent_node -> core.provider_manager + core.workflow.nodes.agent.agent_node -> core.tools.tool_manager + core.workflow.nodes.code.code_node -> core.helper.code_executor.code_executor + core.workflow.nodes.datasource.datasource_node -> models.model + core.workflow.nodes.datasource.datasource_node -> models.tools + core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service + core.workflow.nodes.document_extractor.node -> configs + core.workflow.nodes.document_extractor.node -> core.file.file_manager + core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy + core.workflow.nodes.http_request.entities -> configs + core.workflow.nodes.http_request.executor -> configs + core.workflow.nodes.http_request.executor -> core.file.file_manager + core.workflow.nodes.http_request.node -> configs + core.workflow.nodes.http_request.node -> core.tools.tool_file_manager + core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory + core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.datasource.retrieval_service + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.dataset_retrieval + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> models.dataset + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> services.feature_service + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_runtime.model_providers.__base.large_language_model + core.workflow.nodes.llm.llm_utils -> configs + core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities + core.workflow.nodes.llm.llm_utils -> core.file.models + core.workflow.nodes.llm.llm_utils -> core.model_manager + core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model + core.workflow.nodes.llm.llm_utils -> models.model + core.workflow.nodes.llm.llm_utils -> models.provider + core.workflow.nodes.llm.llm_utils -> services.credit_pool_service + core.workflow.nodes.llm.node -> core.tools.signature + core.workflow.nodes.template_transform.template_transform_node -> configs + core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler + core.workflow.nodes.tool.tool_node -> core.tools.tool_engine + core.workflow.nodes.tool.tool_node -> core.tools.tool_manager + core.workflow.workflow_entry -> configs + core.workflow.workflow_entry -> models.workflow + core.workflow.nodes.agent.agent_node -> core.agent.entities + core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities + core.workflow.graph_engine.layers.persistence -> core.app.entities.app_invoke_entities + core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities + core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities + core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> 
core.prompt.advanced_prompt_transform + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model + core.workflow.nodes.question_classifier.question_classifier_node -> core.app.entities.app_invoke_entities + core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.advanced_prompt_transform + core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform + core.workflow.nodes.start.entities -> core.app.app_config.entities + core.workflow.nodes.start.start_node -> core.app.app_config.entities + core.workflow.workflow_entry -> core.app.apps.exc + core.workflow.workflow_entry -> core.app.entities.app_invoke_entities + core.workflow.workflow_entry -> core.app.workflow.node_factory + core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager + core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.agent_entities + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.model_entities + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_manager + core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager + core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager + core.workflow.node_events.node -> core.file + core.workflow.nodes.agent.agent_node -> core.file + core.workflow.nodes.datasource.datasource_node -> core.file + core.workflow.nodes.datasource.datasource_node -> core.file.enums + core.workflow.nodes.document_extractor.node -> core.file + core.workflow.nodes.http_request.executor -> core.file.enums + core.workflow.nodes.http_request.node -> core.file + core.workflow.nodes.http_request.node -> core.file.file_manager + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.file.models + core.workflow.nodes.list_operator.node -> core.file + core.workflow.nodes.llm.file_saver -> core.file + core.workflow.nodes.llm.llm_utils -> core.variables.segments + core.workflow.nodes.llm.node -> core.file + core.workflow.nodes.llm.node -> core.file.file_manager + core.workflow.nodes.llm.node -> core.file.models + core.workflow.nodes.loop.entities -> core.variables.types + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.file + core.workflow.nodes.protocols -> core.file + core.workflow.nodes.question_classifier.question_classifier_node -> core.file.models + core.workflow.nodes.tool.tool_node -> core.file + core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer + core.workflow.nodes.tool.tool_node -> models + core.workflow.nodes.trigger_webhook.node -> core.file + core.workflow.runtime.variable_pool -> core.file + core.workflow.runtime.variable_pool -> core.file.file_manager + core.workflow.system_variable -> core.file.models + core.workflow.utils.condition.processor -> core.file + core.workflow.utils.condition.processor -> core.file.file_manager + core.workflow.workflow_entry -> core.file.models + core.workflow.workflow_type_encoder -> core.file.models + core.workflow.nodes.agent.agent_node -> models.model + core.workflow.nodes.code.code_node -> 
core.helper.code_executor.code_node_provider + core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider + core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider + core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor + core.workflow.nodes.datasource.datasource_node -> core.variables.variables + core.workflow.nodes.http_request.executor -> core.helper.ssrf_proxy + core.workflow.nodes.http_request.node -> core.helper.ssrf_proxy + core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy + core.workflow.nodes.llm.node -> core.helper.code_executor + core.workflow.nodes.template_transform.template_renderer -> core.helper.code_executor.code_executor + core.workflow.nodes.llm.node -> core.llm_generator.output_parser.errors + core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output + core.workflow.nodes.llm.node -> core.model_manager + core.workflow.graph_engine.layers.persistence -> core.ops.entities.trace_entity + core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.prompt.simple_prompt_transform + core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities + core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities + core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities + core.workflow.nodes.llm.node -> core.prompt.utils.prompt_message_util + core.workflow.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util + core.workflow.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities + core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util + core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods + core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods + core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods + core.workflow.nodes.llm.node -> models.dataset + core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer + core.workflow.nodes.llm.file_saver -> core.tools.signature + core.workflow.nodes.llm.file_saver -> core.tools.tool_file_manager + core.workflow.nodes.tool.tool_node -> core.tools.errors + core.workflow.conversation_variable_updater -> core.variables + core.workflow.graph_engine.entities.commands -> core.variables.variables + core.workflow.nodes.agent.agent_node -> core.variables.segments + core.workflow.nodes.answer.answer_node -> core.variables + core.workflow.nodes.code.code_node -> core.variables.segments + core.workflow.nodes.code.code_node -> core.variables.types + core.workflow.nodes.code.entities -> core.variables.types + core.workflow.nodes.datasource.datasource_node -> core.variables.segments + core.workflow.nodes.document_extractor.node -> core.variables + core.workflow.nodes.document_extractor.node -> core.variables.segments + core.workflow.nodes.http_request.executor -> core.variables.segments + 
core.workflow.nodes.http_request.node -> core.variables.segments + core.workflow.nodes.iteration.iteration_node -> core.variables + core.workflow.nodes.iteration.iteration_node -> core.variables.segments + core.workflow.nodes.iteration.iteration_node -> core.variables.variables + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables.segments + core.workflow.nodes.list_operator.node -> core.variables + core.workflow.nodes.list_operator.node -> core.variables.segments + core.workflow.nodes.llm.node -> core.variables + core.workflow.nodes.loop.loop_node -> core.variables + core.workflow.nodes.parameter_extractor.entities -> core.variables.types + core.workflow.nodes.parameter_extractor.exc -> core.variables.types + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.variables.types + core.workflow.nodes.tool.tool_node -> core.variables.segments + core.workflow.nodes.tool.tool_node -> core.variables.variables + core.workflow.nodes.trigger_webhook.node -> core.variables.types + core.workflow.nodes.trigger_webhook.node -> core.variables.variables + core.workflow.nodes.variable_aggregator.entities -> core.variables.types + core.workflow.nodes.variable_aggregator.variable_aggregator_node -> core.variables.segments + core.workflow.nodes.variable_assigner.common.helpers -> core.variables + core.workflow.nodes.variable_assigner.common.helpers -> core.variables.consts + core.workflow.nodes.variable_assigner.common.helpers -> core.variables.types + core.workflow.nodes.variable_assigner.v1.node -> core.variables + core.workflow.nodes.variable_assigner.v2.helpers -> core.variables + core.workflow.nodes.variable_assigner.v2.node -> core.variables + core.workflow.nodes.variable_assigner.v2.node -> core.variables.consts + core.workflow.runtime.graph_runtime_state_protocol -> core.variables.segments + core.workflow.runtime.read_only_wrappers -> core.variables.segments + core.workflow.runtime.variable_pool -> core.variables + core.workflow.runtime.variable_pool -> core.variables.consts + core.workflow.runtime.variable_pool -> core.variables.segments + core.workflow.runtime.variable_pool -> core.variables.variables + core.workflow.utils.condition.processor -> core.variables + core.workflow.utils.condition.processor -> core.variables.segments + core.workflow.variable_loader -> core.variables + core.workflow.variable_loader -> core.variables.consts + core.workflow.workflow_type_encoder -> core.variables + core.workflow.graph_engine.manager -> extensions.ext_redis + core.workflow.nodes.agent.agent_node -> extensions.ext_database + core.workflow.nodes.datasource.datasource_node -> extensions.ext_database + core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis + core.workflow.nodes.llm.file_saver -> extensions.ext_database + core.workflow.nodes.llm.llm_utils -> extensions.ext_database + core.workflow.nodes.llm.node -> extensions.ext_database + core.workflow.nodes.tool.tool_node -> extensions.ext_database + core.workflow.workflow_entry -> extensions.otel.runtime + core.workflow.nodes.agent.agent_node -> models + core.workflow.nodes.base.node -> models.enums + core.workflow.nodes.llm.llm_utils -> models.provider_ids + core.workflow.nodes.llm.node -> models.model + 
core.workflow.workflow_entry -> models.enums + core.workflow.nodes.agent.agent_node -> services + core.workflow.nodes.tool.tool_node -> services + [importlinter:contract:rsc] name = RSC type = layers diff --git a/api/commands.py b/api/commands.py index aa7b731a27..3d68de4cb4 100644 --- a/api/commands.py +++ b/api/commands.py @@ -950,6 +950,346 @@ def clean_workflow_runs( ) +@click.command( + "archive-workflow-runs", + help="Archive workflow runs for paid plan tenants to S3-compatible storage.", +) +@click.option("--tenant-ids", default=None, help="Optional comma-separated tenant IDs for grayscale rollout.") +@click.option("--before-days", default=90, show_default=True, help="Archive runs older than N days.") +@click.option( + "--from-days-ago", + default=None, + type=click.IntRange(min=0), + help="Lower bound in days ago (older). Must be paired with --to-days-ago.", +) +@click.option( + "--to-days-ago", + default=None, + type=click.IntRange(min=0), + help="Upper bound in days ago (newer). Must be paired with --from-days-ago.", +) +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Archive runs created at or after this timestamp (UTC if no timezone).", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Archive runs created before this timestamp (UTC if no timezone).", +) +@click.option("--batch-size", default=100, show_default=True, help="Batch size for processing.") +@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to archive.") +@click.option("--limit", default=None, type=int, help="Maximum number of runs to archive.") +@click.option("--dry-run", is_flag=True, help="Preview without archiving.") +@click.option("--delete-after-archive", is_flag=True, help="Delete runs and related data after archiving.") +def archive_workflow_runs( + tenant_ids: str | None, + before_days: int, + from_days_ago: int | None, + to_days_ago: int | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + batch_size: int, + workers: int, + limit: int | None, + dry_run: bool, + delete_after_archive: bool, +): + """ + Archive workflow runs for paid plan tenants older than the specified days. + + This command archives the following tables to storage: + - workflow_node_executions + - workflow_node_execution_offload + - workflow_pauses + - workflow_pause_reasons + - workflow_trigger_logs + + The workflow_runs and workflow_app_logs tables are preserved for UI listing. 
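+
+    Example (illustrative invocation via the Flask CLI, using the flags defined above):
+        flask archive-workflow-runs --before-days 90 --batch-size 100 --dry-run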
+ """ + from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver + + run_started_at = datetime.datetime.now(datetime.UTC) + click.echo( + click.style( + f"Starting workflow run archiving at {run_started_at.isoformat()}.", + fg="white", + ) + ) + + if (start_from is None) ^ (end_before is None): + click.echo(click.style("start-from and end-before must be provided together.", fg="red")) + return + + if (from_days_ago is None) ^ (to_days_ago is None): + click.echo(click.style("from-days-ago and to-days-ago must be provided together.", fg="red")) + return + + if from_days_ago is not None and to_days_ago is not None: + if start_from or end_before: + click.echo(click.style("Choose either day offsets or explicit dates, not both.", fg="red")) + return + if from_days_ago <= to_days_ago: + click.echo(click.style("from-days-ago must be greater than to-days-ago.", fg="red")) + return + now = datetime.datetime.now() + start_from = now - datetime.timedelta(days=from_days_ago) + end_before = now - datetime.timedelta(days=to_days_ago) + before_days = 0 + + if start_from and end_before and start_from >= end_before: + click.echo(click.style("start-from must be earlier than end-before.", fg="red")) + return + if workers < 1: + click.echo(click.style("workers must be at least 1.", fg="red")) + return + + archiver = WorkflowRunArchiver( + days=before_days, + batch_size=batch_size, + start_from=start_from, + end_before=end_before, + workers=workers, + tenant_ids=[tid.strip() for tid in tenant_ids.split(",")] if tenant_ids else None, + limit=limit, + dry_run=dry_run, + delete_after_archive=delete_after_archive, + ) + summary = archiver.run() + click.echo( + click.style( + f"Summary: processed={summary.total_runs_processed}, archived={summary.runs_archived}, " + f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, " + f"time={summary.total_elapsed_time:.2f}s", + fg="cyan", + ) + ) + + run_finished_at = datetime.datetime.now(datetime.UTC) + elapsed = run_finished_at - run_started_at + click.echo( + click.style( + f"Workflow run archiving completed. start={run_started_at.isoformat()} " + f"end={run_finished_at.isoformat()} duration={elapsed}", + fg="green", + ) + ) + + +@click.command( + "restore-workflow-runs", + help="Restore archived workflow runs from S3-compatible storage.", +) +@click.option( + "--tenant-ids", + required=False, + help="Tenant IDs (comma-separated).", +) +@click.option("--run-id", required=False, help="Workflow run ID to restore.") +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", +) +@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to restore.") +@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to restore.") +@click.option("--dry-run", is_flag=True, help="Preview without restoring.") +def restore_workflow_runs( + tenant_ids: str | None, + run_id: str | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + workers: int, + limit: int, + dry_run: bool, +): + """ + Restore an archived workflow run from storage to the database. 
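+
+    Example (illustrative invocation via the Flask CLI, using the flags defined above):
+        flask restore-workflow-runs --run-id <workflow_run_id> --dry-run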
+ + This restores the following tables: + - workflow_node_executions + - workflow_node_execution_offload + - workflow_pauses + - workflow_pause_reasons + - workflow_trigger_logs + """ + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + parsed_tenant_ids = None + if tenant_ids: + parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()] + if not parsed_tenant_ids: + raise click.BadParameter("tenant-ids must not be empty") + + if (start_from is None) ^ (end_before is None): + raise click.UsageError("--start-from and --end-before must be provided together.") + if run_id is None and (start_from is None or end_before is None): + raise click.UsageError("--start-from and --end-before are required for batch restore.") + if workers < 1: + raise click.BadParameter("workers must be at least 1") + + start_time = datetime.datetime.now(datetime.UTC) + click.echo( + click.style( + f"Starting restore of workflow run {run_id} at {start_time.isoformat()}.", + fg="white", + ) + ) + + restorer = WorkflowRunRestore(dry_run=dry_run, workers=workers) + if run_id: + results = [restorer.restore_by_run_id(run_id)] + else: + assert start_from is not None + assert end_before is not None + results = restorer.restore_batch( + parsed_tenant_ids, + start_date=start_from, + end_date=end_before, + limit=limit, + ) + + end_time = datetime.datetime.now(datetime.UTC) + elapsed = end_time - start_time + + successes = sum(1 for result in results if result.success) + failures = len(results) - successes + + if failures == 0: + click.echo( + click.style( + f"Restore completed successfully. success={successes} duration={elapsed}", + fg="green", + ) + ) + else: + click.echo( + click.style( + f"Restore completed with failures. success={successes} failed={failures} duration={elapsed}", + fg="red", + ) + ) + + +@click.command( + "delete-archived-workflow-runs", + help="Delete archived workflow runs from the database.", +) +@click.option( + "--tenant-ids", + required=False, + help="Tenant IDs (comma-separated).", +) +@click.option("--run-id", required=False, help="Workflow run ID to delete.") +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", +) +@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to delete.") +@click.option("--dry-run", is_flag=True, help="Preview without deleting.") +def delete_archived_workflow_runs( + tenant_ids: str | None, + run_id: str | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + limit: int, + dry_run: bool, +): + """ + Delete archived workflow runs from the database. 
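+
+    Example (illustrative invocation via the Flask CLI, using the flags defined above):
+        flask delete-archived-workflow-runs --tenant-ids <tenant_id> --start-from 2025-01-01 --end-before 2025-02-01 --dry-run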
+ """ + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + parsed_tenant_ids = None + if tenant_ids: + parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()] + if not parsed_tenant_ids: + raise click.BadParameter("tenant-ids must not be empty") + + if (start_from is None) ^ (end_before is None): + raise click.UsageError("--start-from and --end-before must be provided together.") + if run_id is None and (start_from is None or end_before is None): + raise click.UsageError("--start-from and --end-before are required for batch delete.") + + start_time = datetime.datetime.now(datetime.UTC) + target_desc = f"workflow run {run_id}" if run_id else "workflow runs" + click.echo( + click.style( + f"Starting delete of {target_desc} at {start_time.isoformat()}.", + fg="white", + ) + ) + + deleter = ArchivedWorkflowRunDeletion(dry_run=dry_run) + if run_id: + results = [deleter.delete_by_run_id(run_id)] + else: + assert start_from is not None + assert end_before is not None + results = deleter.delete_batch( + parsed_tenant_ids, + start_date=start_from, + end_date=end_before, + limit=limit, + ) + + for result in results: + if result.success: + click.echo( + click.style( + f"{'[DRY RUN] Would delete' if dry_run else 'Deleted'} " + f"workflow run {result.run_id} (tenant={result.tenant_id})", + fg="green", + ) + ) + else: + click.echo( + click.style( + f"Failed to delete workflow run {result.run_id}: {result.error}", + fg="red", + ) + ) + + end_time = datetime.datetime.now(datetime.UTC) + elapsed = end_time - start_time + + successes = sum(1 for result in results if result.success) + failures = len(results) - successes + + if failures == 0: + click.echo( + click.style( + f"Delete completed successfully. success={successes} duration={elapsed}", + fg="green", + ) + ) + else: + click.echo( + click.style( + f"Delete completed with failures. success={successes} failed={failures} duration={elapsed}", + fg="red", + ) + ) + + @click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.") @click.command("clear-orphaned-file-records", help="Clear orphaned file records.") def clear_orphaned_file_records(force: bool): diff --git a/api/context/flask_app_context.py b/api/context/flask_app_context.py index 360be16beb..2d465c8cf4 100644 --- a/api/context/flask_app_context.py +++ b/api/context/flask_app_context.py @@ -3,6 +3,7 @@ Flask App Context - Flask implementation of AppContext interface. 
""" import contextvars +import threading from collections.abc import Generator from contextlib import contextmanager from typing import Any, final @@ -118,6 +119,7 @@ class FlaskExecutionContext: self._context_vars = context_vars self._user = user self._flask_app = flask_app + self._local = threading.local() @property def app_context(self) -> FlaskAppContext: @@ -136,47 +138,39 @@ class FlaskExecutionContext: def __enter__(self) -> "FlaskExecutionContext": """Enter the Flask execution context.""" - # Restore context variables + # Restore non-Flask context variables to avoid leaking Flask tokens across threads for var, val in self._context_vars.items(): var.set(val) - # Save current user from g if available - saved_user = None - if hasattr(g, "_login_user"): - saved_user = g._login_user - # Enter Flask app context - self._cm = self._app_context.enter() - self._cm.__enter__() + cm = self._app_context.enter() + self._local.cm = cm + cm.__enter__() # Restore user in new app context - if saved_user is not None: - g._login_user = saved_user + if self._user is not None: + g._login_user = self._user return self def __exit__(self, *args: Any) -> None: """Exit the Flask execution context.""" - if hasattr(self, "_cm"): - self._cm.__exit__(*args) + cm = getattr(self._local, "cm", None) + if cm is not None: + cm.__exit__(*args) @contextmanager def enter(self) -> Generator[None, None, None]: """Enter Flask execution context as context manager.""" - # Restore context variables + # Restore non-Flask context variables to avoid leaking Flask tokens across threads for var, val in self._context_vars.items(): var.set(val) - # Save current user from g if available - saved_user = None - if hasattr(g, "_login_user"): - saved_user = g._login_user - # Enter Flask app context with self._flask_app.app_context(): # Restore user in new app context - if saved_user is not None: - g._login_user = saved_user + if self._user is not None: + g._login_user = self._user yield diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index fa67fb8154..6736f24a2e 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -11,7 +11,10 @@ from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.workflow.enums import WorkflowExecutionStatus from extensions.ext_database import db -from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model +from fields.workflow_app_log_fields import ( + build_workflow_app_log_pagination_model, + build_workflow_archived_log_pagination_model, +) from libs.login import login_required from models import App from models.model import AppMode @@ -61,6 +64,7 @@ console_ns.schema_model( # Register model for flask_restx to avoid dict type issues in Swagger workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns) +workflow_archived_log_pagination_model = build_workflow_archived_log_pagination_model(console_ns) @console_ns.route("/apps//workflow-app-logs") @@ -99,3 +103,33 @@ class WorkflowAppLogApi(Resource): ) return workflow_app_log_pagination + + +@console_ns.route("/apps//workflow-archived-logs") +class WorkflowArchivedLogApi(Resource): + @console_ns.doc("get_workflow_archived_logs") + @console_ns.doc(description="Get workflow archived execution logs") + @console_ns.doc(params={"app_id": "Application ID"}) + 
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__]) + @console_ns.response(200, "Workflow archived logs retrieved successfully", workflow_archived_log_pagination_model) + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.WORKFLOW]) + @marshal_with(workflow_archived_log_pagination_model) + def get(self, app_model: App): + """ + Get workflow archived logs + """ + args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + + workflow_app_service = WorkflowAppService() + with Session(db.engine) as session: + workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs( + session=session, + app_model=app_model, + page=args.page, + limit=args.limit, + ) + + return workflow_app_log_pagination diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 8f1871f1e9..fa74f8aea1 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -1,12 +1,15 @@ +from datetime import UTC, datetime, timedelta from typing import Literal, cast from flask import request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field, field_validator +from sqlalchemy import select from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required +from extensions.ext_database import db from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields from fields.workflow_run_fields import ( @@ -19,14 +22,17 @@ from fields.workflow_run_fields import ( workflow_run_node_execution_list_fields, workflow_run_pagination_fields, ) +from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage from libs.custom_inputs import time_duration from libs.helper import uuid_value from libs.login import current_user, login_required -from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom +from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom +from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME from services.workflow_run_service import WorkflowRunService # Workflow run status choices for filtering WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"] +EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600 # Register models for flask_restx to avoid dict type issues in Swagger # Register in dependency order: base models first, then dependent models @@ -93,6 +99,15 @@ workflow_run_node_execution_list_model = console_ns.model( "WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy ) +workflow_run_export_fields = console_ns.model( + "WorkflowRunExport", + { + "status": fields.String(description="Export status: success/failed"), + "presigned_url": fields.String(description="Pre-signed URL for download", required=False), + "presigned_url_expires_at": fields.String(description="Pre-signed URL expiration time", required=False), + }, +) + DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -181,6 +196,56 @@ class AdvancedChatAppWorkflowRunListApi(Resource): return result +@console_ns.route("/apps//workflow-runs//export") +class WorkflowRunExportApi(Resource): + @console_ns.doc("get_workflow_run_export_url") + @console_ns.doc(description="Generate a download URL for an 
archived workflow run.") + @console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"}) + @console_ns.response(200, "Export URL generated", workflow_run_export_fields) + @setup_required + @login_required + @account_initialization_required + @get_app_model() + def get(self, app_model: App, run_id: str): + tenant_id = str(app_model.tenant_id) + app_id = str(app_model.id) + run_id_str = str(run_id) + + run_created_at = db.session.scalar( + select(WorkflowArchiveLog.run_created_at) + .where( + WorkflowArchiveLog.tenant_id == tenant_id, + WorkflowArchiveLog.app_id == app_id, + WorkflowArchiveLog.workflow_run_id == run_id_str, + ) + .limit(1) + ) + if not run_created_at: + return {"code": "archive_log_not_found", "message": "workflow run archive not found"}, 404 + + prefix = ( + f"{tenant_id}/app_id={app_id}/year={run_created_at.strftime('%Y')}/" + f"month={run_created_at.strftime('%m')}/workflow_run_id={run_id_str}" + ) + archive_key = f"{prefix}/{ARCHIVE_BUNDLE_NAME}" + + try: + archive_storage = get_archive_storage() + except ArchiveStorageNotConfiguredError as e: + return {"code": "archive_storage_not_configured", "message": str(e)}, 500 + + presigned_url = archive_storage.generate_presigned_url( + archive_key, + expires_in=EXPORT_SIGNED_URL_EXPIRE_SECONDS, + ) + expires_at = datetime.now(UTC) + timedelta(seconds=EXPORT_SIGNED_URL_EXPIRE_SECONDS) + return { + "status": "success", + "presigned_url": presigned_url, + "presigned_url_expires_at": expires_at.isoformat(), + }, 200 + + @console_ns.route("/apps//advanced-chat/workflow-runs/count") class AdvancedChatAppWorkflowRunCountApi(Resource): @console_ns.doc("get_advanced_chat_workflow_runs_count") diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 6951c906e9..d3811e2d1b 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,6 +1,7 @@ from flask_restx import Resource, fields +from werkzeug.exceptions import Unauthorized -from libs.login import current_account_with_tenant, login_required +from libs.login import current_account_with_tenant, current_user, login_required from services.feature_service import FeatureService from . import console_ns @@ -39,5 +40,21 @@ class SystemFeatureApi(Resource): ), ) def get(self): - """Get system-wide feature configuration""" - return FeatureService.get_system_features().model_dump() + """Get system-wide feature configuration + + NOTE: This endpoint is unauthenticated by design, as it provides system features + data required for dashboard initialization. + + Authentication would create circular dependency (can't login without dashboard loading). + + Only non-sensitive configuration data should be returned by this endpoint. + """ + # NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated` + # without a try-catch. However, due to the implementation of user loader (the `load_user_from_request` + # in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will + # raise `Unauthorized` exception if authentication token is not provided. 
+ try: + is_authenticated = current_user.is_authenticated + except Unauthorized: + is_authenticated = False + return FeatureService.get_system_features(is_authenticated=is_authenticated).model_dump() diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index c800c0e4e1..49ff4f57dc 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -261,17 +261,6 @@ class DocumentAddByFileApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Create document by upload file.""" - args = {} - if "data" in request.form: - args = json.loads(request.form["data"]) - if "doc_form" not in args: - args["doc_form"] = "text_model" - if "doc_language" not in args: - args["doc_language"] = "English" - - # get dataset info - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: @@ -280,6 +269,18 @@ class DocumentAddByFileApi(DatasetApiResource): if dataset.provider == "external": raise ValueError("External datasets are not supported.") + args = {} + if "data" in request.form: + args = json.loads(request.form["data"]) + if "doc_form" not in args: + args["doc_form"] = dataset.chunk_structure or "text_model" + if "doc_language" not in args: + args["doc_language"] = "English" + + # get dataset info + dataset_id = str(dataset_id) + tenant_id = str(tenant_id) + indexing_technique = args.get("indexing_technique") or dataset.indexing_technique if not indexing_technique: raise ValueError("indexing_technique is required.") @@ -370,17 +371,6 @@ class DocumentUpdateByFileApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by upload file.""" - args = {} - if "data" in request.form: - args = json.loads(request.form["data"]) - if "doc_form" not in args: - args["doc_form"] = "text_model" - if "doc_language" not in args: - args["doc_language"] = "English" - - # get dataset info - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: @@ -389,6 +379,18 @@ class DocumentUpdateByFileApi(DatasetApiResource): if dataset.provider == "external": raise ValueError("External datasets are not supported.") + args = {} + if "data" in request.form: + args = json.loads(request.form["data"]) + if "doc_form" not in args: + args["doc_form"] = dataset.chunk_structure or "text_model" + if "doc_language" not in args: + args["doc_language"] = "English" + + # get dataset info + dataset_id = str(dataset_id) + tenant_id = str(tenant_id) + # indexing_technique is already set in dataset since this is an update args["indexing_technique"] = dataset.indexing_technique diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 0157521ae9..34d02a1e51 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -9,13 +9,13 @@ from core.app.entities.app_invoke_entities import ( InvokeFrom, RagPipelineGenerateEntity, ) +from core.app.workflow.node_factory import DifyNodeFactory from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from core.workflow.entities.graph_init_params import 
GraphInitParams from core.workflow.enums import WorkflowType from core.workflow.graph import Graph from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.runtime import GraphRuntimeState, VariablePool diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 7adf3504ac..2ca153f835 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -25,6 +25,7 @@ from core.app.entities.queue_entities import ( QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) +from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph from core.workflow.graph_engine.layers.base import GraphEngineLayer @@ -53,7 +54,6 @@ from core.workflow.graph_events import ( ) from core.workflow.graph_events.graph import GraphRunAbortedEvent from core.workflow.nodes import NodeType -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable diff --git a/api/core/app/workflow/__init__.py b/api/core/app/workflow/__init__.py new file mode 100644 index 0000000000..172ee5d703 --- /dev/null +++ b/api/core/app/workflow/__init__.py @@ -0,0 +1,3 @@ +from .node_factory import DifyNodeFactory + +__all__ = ["DifyNodeFactory"] diff --git a/api/core/workflow/nodes/node_factory.py b/api/core/app/workflow/node_factory.py similarity index 98% rename from api/core/workflow/nodes/node_factory.py rename to api/core/app/workflow/node_factory.py index 5c04e5110f..e0a0059a38 100644 --- a/api/core/workflow/nodes/node_factory.py +++ b/api/core/app/workflow/node_factory.py @@ -15,6 +15,7 @@ from core.workflow.nodes.base.node import Node from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.limits import CodeNodeLimits from core.workflow.nodes.http_request.node import HttpRequestNode +from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol from core.workflow.nodes.template_transform.template_renderer import ( CodeExecutorJinja2TemplateRenderer, @@ -23,8 +24,6 @@ from core.workflow.nodes.template_transform.template_renderer import ( from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from libs.typing import is_str, is_str_dict -from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING - if TYPE_CHECKING: from core.workflow.entities import GraphInitParams from core.workflow.runtime import GraphRuntimeState diff --git a/api/core/workflow/context/execution_context.py b/api/core/workflow/context/execution_context.py index d951c95d68..e3007530f0 100644 --- a/api/core/workflow/context/execution_context.py +++ b/api/core/workflow/context/execution_context.py @@ -3,6 +3,7 @@ Execution Context - Abstracted context management for workflow execution. 
""" import contextvars +import threading from abc import ABC, abstractmethod from collections.abc import Callable, Generator from contextlib import AbstractContextManager, contextmanager @@ -88,6 +89,7 @@ class ExecutionContext: self._app_context = app_context self._context_vars = context_vars self._user = user + self._local = threading.local() @property def app_context(self) -> AppContext | None: @@ -125,14 +127,16 @@ class ExecutionContext: def __enter__(self) -> "ExecutionContext": """Enter the execution context.""" - self._cm = self.enter() - self._cm.__enter__() + cm = self.enter() + self._local.cm = cm + cm.__enter__() return self def __exit__(self, *args: Any) -> None: """Exit the execution context.""" - if hasattr(self, "_cm"): - self._cm.__exit__(*args) + cm = getattr(self._local, "cm", None) + if cm is not None: + cm.__exit__(*args) class NullAppContext(AppContext): diff --git a/api/core/workflow/graph_engine/worker.py b/api/core/workflow/graph_engine/worker.py index 95db5c5c92..6c69ea5df0 100644 --- a/api/core/workflow/graph_engine/worker.py +++ b/api/core/workflow/graph_engine/worker.py @@ -11,7 +11,6 @@ import time from collections.abc import Sequence from datetime import datetime from typing import TYPE_CHECKING, final -from uuid import uuid4 from typing_extensions import override @@ -113,7 +112,7 @@ class Worker(threading.Thread): self._ready_queue.task_done() except Exception as e: error_event = NodeRunFailedEvent( - id=str(uuid4()), + id=node.execution_id, node_id=node.id, node_type=node.node_type, in_iteration_id=None, diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 55c8db40ea..63e0260341 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -469,12 +469,8 @@ class Node(Generic[NodeDataT]): import core.workflow.nodes as _nodes_pkg for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."): - # Avoid importing modules that depend on the registry to prevent circular imports - # e.g. node_factory imports node_mapping which builds the mapping here. - if _modname in { - "core.workflow.nodes.node_factory", - "core.workflow.nodes.node_mapping", - }: + # Avoid importing modules that depend on the registry to prevent circular imports. 
+ if _modname == "core.workflow.nodes.node_mapping": continue importlib.import_module(_modname) diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 569a4196fb..ced996e7e0 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -588,11 +588,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): def _create_graph_engine(self, index: int, item: object): # Import dependencies + from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel - from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState # Create GraphInitParams from node attributes diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 1f9fc8a115..07d05966cc 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -413,11 +413,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): def _create_graph_engine(self, start_at: datetime, root_node_id: str): # Import dependencies + from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel - from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState # Create GraphInitParams from node attributes diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index ee37314721..c7bcc66c8b 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -7,6 +7,7 @@ from typing import Any from configs import dify_config from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.file.models import File from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID from core.workflow.entities import GraphInitParams @@ -19,7 +20,6 @@ from core.workflow.graph_engine.protocols.command_channel import CommandChannel from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent from core.workflow.nodes import NodeType from core.workflow.nodes.base.node import Node -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index 51e2c6cdd5..46885761a1 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -4,6 +4,7 @@ from dify_app import DifyApp def init_app(app: DifyApp): from commands import ( add_qdrant_index, + archive_workflow_runs, clean_expired_messages, clean_workflow_runs, cleanup_orphaned_draft_variables, @@ -11,6 +12,7 @@ def init_app(app: DifyApp): clear_orphaned_file_records, convert_to_agent_apps, create_tenant, + delete_archived_workflow_runs, extract_plugins, extract_unique_plugins, 
file_usage, @@ -24,6 +26,7 @@ def init_app(app: DifyApp): reset_email, reset_encrypt_key_pair, reset_password, + restore_workflow_runs, setup_datasource_oauth_client, setup_system_tool_oauth_client, setup_system_trigger_oauth_client, @@ -58,6 +61,9 @@ def init_app(app: DifyApp): setup_datasource_oauth_client, transform_datasource_credentials, install_rag_pipeline_plugins, + archive_workflow_runs, + delete_archived_workflow_runs, + restore_workflow_runs, clean_workflow_runs, clean_expired_messages, ] diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index 0ebc03a98c..ae70356322 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -2,7 +2,12 @@ from flask_restx import Namespace, fields from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields from fields.member_fields import build_simple_account_model, simple_account_fields -from fields.workflow_run_fields import build_workflow_run_for_log_model, workflow_run_for_log_fields +from fields.workflow_run_fields import ( + build_workflow_run_for_archived_log_model, + build_workflow_run_for_log_model, + workflow_run_for_archived_log_fields, + workflow_run_for_log_fields, +) from libs.helper import TimestampField workflow_app_log_partial_fields = { @@ -34,6 +39,33 @@ def build_workflow_app_log_partial_model(api_or_ns: Namespace): return api_or_ns.model("WorkflowAppLogPartial", copied_fields) +workflow_archived_log_partial_fields = { + "id": fields.String, + "workflow_run": fields.Nested(workflow_run_for_archived_log_fields, allow_null=True), + "trigger_metadata": fields.Raw, + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), + "created_at": TimestampField, +} + + +def build_workflow_archived_log_partial_model(api_or_ns: Namespace): + """Build the workflow archived log partial model for the API or Namespace.""" + workflow_run_model = build_workflow_run_for_archived_log_model(api_or_ns) + simple_account_model = build_simple_account_model(api_or_ns) + simple_end_user_model = build_simple_end_user_model(api_or_ns) + + copied_fields = workflow_archived_log_partial_fields.copy() + copied_fields["workflow_run"] = fields.Nested(workflow_run_model, allow_null=True) + copied_fields["created_by_account"] = fields.Nested( + simple_account_model, attribute="created_by_account", allow_null=True + ) + copied_fields["created_by_end_user"] = fields.Nested( + simple_end_user_model, attribute="created_by_end_user", allow_null=True + ) + return api_or_ns.model("WorkflowArchivedLogPartial", copied_fields) + + workflow_app_log_pagination_fields = { "page": fields.Integer, "limit": fields.Integer, @@ -51,3 +83,21 @@ def build_workflow_app_log_pagination_model(api_or_ns: Namespace): copied_fields = workflow_app_log_pagination_fields.copy() copied_fields["data"] = fields.List(fields.Nested(workflow_app_log_partial_model)) return api_or_ns.model("WorkflowAppLogPagination", copied_fields) + + +workflow_archived_log_pagination_fields = { + "page": fields.Integer, + "limit": fields.Integer, + "total": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(workflow_archived_log_partial_fields)), +} + + +def build_workflow_archived_log_pagination_model(api_or_ns: Namespace): + """Build the workflow archived log pagination model for the API or Namespace.""" + 
workflow_archived_log_partial_model = build_workflow_archived_log_partial_model(api_or_ns) + + copied_fields = workflow_archived_log_pagination_fields.copy() + copied_fields["data"] = fields.List(fields.Nested(workflow_archived_log_partial_model)) + return api_or_ns.model("WorkflowArchivedLogPagination", copied_fields) diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 476025064f..35bb442c59 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -23,6 +23,19 @@ def build_workflow_run_for_log_model(api_or_ns: Namespace): return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields) +workflow_run_for_archived_log_fields = { + "id": fields.String, + "status": fields.String, + "triggered_from": fields.String, + "elapsed_time": fields.Float, + "total_tokens": fields.Integer, +} + + +def build_workflow_run_for_archived_log_model(api_or_ns: Namespace): + return api_or_ns.model("WorkflowRunForArchivedLog", workflow_run_for_archived_log_fields) + + workflow_run_for_list_fields = { "id": fields.String, "version": fields.String, diff --git a/api/libs/archive_storage.py b/api/libs/archive_storage.py index f84d226447..66b57ac661 100644 --- a/api/libs/archive_storage.py +++ b/api/libs/archive_storage.py @@ -7,7 +7,6 @@ to S3-compatible object storage. import base64 import datetime -import gzip import hashlib import logging from collections.abc import Generator @@ -39,7 +38,7 @@ class ArchiveStorage: """ S3-compatible storage client for archiving or exporting. - This client provides methods for storing and retrieving archived data in JSONL+gzip format. + This client provides methods for storing and retrieving archived data in JSONL format. """ def __init__(self, bucket: str): @@ -69,7 +68,10 @@ class ArchiveStorage: aws_access_key_id=dify_config.ARCHIVE_STORAGE_ACCESS_KEY, aws_secret_access_key=dify_config.ARCHIVE_STORAGE_SECRET_KEY, region_name=dify_config.ARCHIVE_STORAGE_REGION, - config=Config(s3={"addressing_style": "path"}), + config=Config( + s3={"addressing_style": "path"}, + max_pool_connections=64, + ), ) # Verify bucket accessibility @@ -100,12 +102,18 @@ class ArchiveStorage: """ checksum = hashlib.md5(data).hexdigest() try: - self.client.put_object( + response = self.client.put_object( Bucket=self.bucket, Key=key, Body=data, ContentMD5=self._content_md5(data), ) + etag = response.get("ETag") + if not etag: + raise ArchiveStorageError(f"Missing ETag for '{key}'") + normalized_etag = etag.strip('"') + if normalized_etag != checksum: + raise ArchiveStorageError(f"ETag mismatch for '{key}': expected={checksum}, actual={normalized_etag}") logger.debug("Uploaded object: %s (size=%d, checksum=%s)", key, len(data), checksum) return checksum except ClientError as e: @@ -240,19 +248,18 @@ class ArchiveStorage: return base64.b64encode(hashlib.md5(data).digest()).decode() @staticmethod - def serialize_to_jsonl_gz(records: list[dict[str, Any]]) -> bytes: + def serialize_to_jsonl(records: list[dict[str, Any]]) -> bytes: """ - Serialize records to gzipped JSONL format. + Serialize records to JSONL format. 
Args: records: List of dictionaries to serialize Returns: - Gzipped JSONL bytes + JSONL bytes """ lines = [] for record in records: - # Convert datetime objects to ISO format strings serialized = ArchiveStorage._serialize_record(record) lines.append(orjson.dumps(serialized)) @@ -260,23 +267,22 @@ class ArchiveStorage: if jsonl_content: jsonl_content += b"\n" - return gzip.compress(jsonl_content) + return jsonl_content @staticmethod - def deserialize_from_jsonl_gz(data: bytes) -> list[dict[str, Any]]: + def deserialize_from_jsonl(data: bytes) -> list[dict[str, Any]]: """ - Deserialize gzipped JSONL data to records. + Deserialize JSONL data to records. Args: - data: Gzipped JSONL bytes + data: JSONL bytes Returns: List of dictionaries """ - jsonl_content = gzip.decompress(data) records = [] - for line in jsonl_content.splitlines(): + for line in data.splitlines(): if line: records.append(orjson.loads(line)) diff --git a/api/migrations/versions/2026_01_21_1718-9d77545f524e_add_workflow_archive_logs.py b/api/migrations/versions/2026_01_21_1718-9d77545f524e_add_workflow_archive_logs.py new file mode 100644 index 0000000000..5e7298af54 --- /dev/null +++ b/api/migrations/versions/2026_01_21_1718-9d77545f524e_add_workflow_archive_logs.py @@ -0,0 +1,95 @@ +"""create workflow_archive_logs + +Revision ID: 9d77545f524e +Revises: f9f6d18a37f9 +Create Date: 2026-01-06 17:18:56.292479 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +def _is_pg(conn): + return conn.dialect.name == "postgresql" + +# revision identifiers, used by Alembic. +revision = '9d77545f524e' +down_revision = 'f9f6d18a37f9' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + conn = op.get_bind() + if _is_pg(conn): + op.create_table('workflow_archive_logs', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), + sa.Column('log_id', models.types.StringUUID(), nullable=True), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('workflow_id', models.types.StringUUID(), nullable=False), + sa.Column('workflow_run_id', models.types.StringUUID(), nullable=False), + sa.Column('created_by_role', sa.String(length=255), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('log_created_at', sa.DateTime(), nullable=True), + sa.Column('log_created_from', sa.String(length=255), nullable=True), + sa.Column('run_version', sa.String(length=255), nullable=False), + sa.Column('run_status', sa.String(length=255), nullable=False), + sa.Column('run_triggered_from', sa.String(length=255), nullable=False), + sa.Column('run_error', models.types.LongText(), nullable=True), + sa.Column('run_elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False), + sa.Column('run_total_tokens', sa.BigInteger(), server_default=sa.text('0'), nullable=False), + sa.Column('run_total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True), + sa.Column('run_created_at', sa.DateTime(), nullable=False), + sa.Column('run_finished_at', sa.DateTime(), nullable=True), + sa.Column('run_exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True), + sa.Column('trigger_metadata', models.types.LongText(), nullable=True), + sa.Column('archived_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', 
name='workflow_archive_log_pkey') + ) + else: + op.create_table('workflow_archive_logs', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('log_id', models.types.StringUUID(), nullable=True), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('workflow_id', models.types.StringUUID(), nullable=False), + sa.Column('workflow_run_id', models.types.StringUUID(), nullable=False), + sa.Column('created_by_role', sa.String(length=255), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('log_created_at', sa.DateTime(), nullable=True), + sa.Column('log_created_from', sa.String(length=255), nullable=True), + sa.Column('run_version', sa.String(length=255), nullable=False), + sa.Column('run_status', sa.String(length=255), nullable=False), + sa.Column('run_triggered_from', sa.String(length=255), nullable=False), + sa.Column('run_error', models.types.LongText(), nullable=True), + sa.Column('run_elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False), + sa.Column('run_total_tokens', sa.BigInteger(), server_default=sa.text('0'), nullable=False), + sa.Column('run_total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True), + sa.Column('run_created_at', sa.DateTime(), nullable=False), + sa.Column('run_finished_at', sa.DateTime(), nullable=True), + sa.Column('run_exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True), + sa.Column('trigger_metadata', models.types.LongText(), nullable=True), + sa.Column('archived_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='workflow_archive_log_pkey') + ) + with op.batch_alter_table('workflow_archive_logs', schema=None) as batch_op: + batch_op.create_index('workflow_archive_log_app_idx', ['tenant_id', 'app_id'], unique=False) + batch_op.create_index('workflow_archive_log_run_created_at_idx', ['run_created_at'], unique=False) + batch_op.create_index('workflow_archive_log_workflow_run_id_idx', ['workflow_run_id'], unique=False) + + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_archive_logs', schema=None) as batch_op: + batch_op.drop_index('workflow_archive_log_workflow_run_id_idx') + batch_op.drop_index('workflow_archive_log_run_created_at_idx') + batch_op.drop_index('workflow_archive_log_app_idx') + + op.drop_table('workflow_archive_logs') + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index 91171a4bef..74b33130ef 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -103,6 +103,7 @@ from .workflow import ( Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, + WorkflowArchiveLog, WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom, @@ -203,6 +204,7 @@ __all__ = [ "Workflow", "WorkflowAppLog", "WorkflowAppLogCreatedFrom", + "WorkflowArchiveLog", "WorkflowNodeExecutionModel", "WorkflowNodeExecutionOffload", "WorkflowNodeExecutionTriggeredFrom", diff --git a/api/models/workflow.py b/api/models/workflow.py index 2ff47e87b9..0efb3a4e44 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1163,6 +1163,69 @@ class WorkflowAppLog(TypeBase): } +class WorkflowArchiveLog(TypeBase): + """ + Workflow archive log. 
+ + Stores essential workflow run snapshot data for archived app logs. + + Field sources: + - Shared fields (tenant/app/workflow/run ids, created_by*): from WorkflowRun for consistency. + - log_* fields: from WorkflowAppLog when present; null if the run has no app log. + - run_* fields: workflow run snapshot fields from WorkflowRun. + - trigger_metadata: snapshot from WorkflowTriggerLog when present. + """ + + __tablename__ = "workflow_archive_logs" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="workflow_archive_log_pkey"), + sa.Index("workflow_archive_log_app_idx", "tenant_id", "app_id"), + sa.Index("workflow_archive_log_workflow_run_id_idx", "workflow_run_id"), + sa.Index("workflow_archive_log_run_created_at_idx", "run_created_at"), + ) + + id: Mapped[str] = mapped_column( + StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False + ) + + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + + log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + log_created_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + log_created_from: Mapped[str | None] = mapped_column(String(255), nullable=True) + + run_version: Mapped[str] = mapped_column(String(255), nullable=False) + run_status: Mapped[str] = mapped_column(String(255), nullable=False) + run_triggered_from: Mapped[str] = mapped_column(String(255), nullable=False) + run_error: Mapped[str | None] = mapped_column(LongText, nullable=True) + run_elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) + run_total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) + run_total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) + run_created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) + run_finished_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + run_exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) + + trigger_metadata: Mapped[str | None] = mapped_column(LongText, nullable=True) + archived_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + + @property + def workflow_run_summary(self) -> dict[str, Any]: + return { + "id": self.workflow_run_id, + "status": self.run_status, + "triggered_from": self.run_triggered_from, + "elapsed_time": self.run_elapsed_time, + "total_tokens": self.run_total_tokens, + } + + class ConversationVariable(TypeBase): __tablename__ = "workflow_conversation_variables" diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py index 479eb1ff54..5b3f635301 100644 --- a/api/repositories/api_workflow_node_execution_repository.py +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -16,7 +16,7 @@ from typing import Protocol from sqlalchemy.orm import Session from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from models.workflow import WorkflowNodeExecutionModel 
+from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol): @@ -209,3 +209,23 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr The number of executions deleted """ ... + + def get_offloads_by_execution_ids( + self, + session: Session, + node_execution_ids: Sequence[str], + ) -> Sequence[WorkflowNodeExecutionOffload]: + """ + Get offload records by node execution IDs. + + This method retrieves workflow node execution offload records + that belong to the given node execution IDs. + + Args: + session: The database session to use + node_execution_ids: List of node execution IDs to filter by + + Returns: + A sequence of WorkflowNodeExecutionOffload instances + """ + ... diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index 1a2b84fdf9..1d3954571f 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -45,7 +45,7 @@ from core.workflow.enums import WorkflowType from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom -from models.workflow import WorkflowRun +from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.types import ( AverageInteractionStats, @@ -270,6 +270,58 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): """ ... + def get_archived_run_ids( + self, + session: Session, + run_ids: Sequence[str], + ) -> set[str]: + """ + Fetch workflow run IDs that already have archive log records. + """ + ... + + def get_archived_logs_by_time_range( + self, + session: Session, + tenant_ids: Sequence[str] | None, + start_date: datetime, + end_date: datetime, + limit: int, + ) -> Sequence[WorkflowArchiveLog]: + """ + Fetch archived workflow logs by time range for restore. + """ + ... + + def get_archived_log_by_run_id( + self, + run_id: str, + ) -> WorkflowArchiveLog | None: + """ + Fetch a workflow archive log by workflow run ID. + """ + ... + + def delete_archive_log_by_run_id( + self, + session: Session, + run_id: str, + ) -> int: + """ + Delete archive log by workflow run ID. + + Used after restoring a workflow run to remove the archive log record, + allowing the run to be archived again if needed. + + Args: + session: Database session + run_id: Workflow run ID + + Returns: + Number of records deleted (0 or 1) + """ + ... + def delete_runs_with_related( self, runs: Sequence[WorkflowRun], @@ -282,6 +334,61 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): """ ... + def get_pause_records_by_run_id( + self, + session: Session, + run_id: str, + ) -> Sequence[WorkflowPause]: + """ + Fetch workflow pause records by workflow run ID. + """ + ... + + def get_pause_reason_records_by_run_id( + self, + session: Session, + pause_ids: Sequence[str], + ) -> Sequence[WorkflowPauseReason]: + """ + Fetch workflow pause reason records by pause IDs. + """ + ... + + def get_app_logs_by_run_id( + self, + session: Session, + run_id: str, + ) -> Sequence[WorkflowAppLog]: + """ + Fetch workflow app logs by workflow run ID. + """ + ... 
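
A minimal usage sketch for these new repository methods, assuming `repo` is an APIWorkflowRunRepository implementation and `session_maker` is a SQLAlchemy sessionmaker bound to the application engine; the archiver service introduced later in this diff follows the same sequence with extra bookkeeping (storage upload, manifests, optional deletion):

from sqlalchemy.orm import Session, sessionmaker

from models.workflow import WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository


def archive_one_run(
    repo: APIWorkflowRunRepository,
    session_maker: sessionmaker,
    run: WorkflowRun,
    trigger_metadata: str | None = None,
) -> int:
    # Illustrative sketch of the protocol methods above, not a full archiver.
    with session_maker() as session:
        # Idempotency check: skip runs that already have an archive log record.
        if repo.get_archived_run_ids(session, [run.id]):
            return 0
        # Snapshot the run's app logs into workflow_archive_logs.
        app_logs = repo.get_app_logs_by_run_id(session, run.id)
        created = repo.create_archive_logs(session, run, app_logs, trigger_metadata)
        session.commit()
        return created

The archive log record doubles as the idempotency marker: re-running the archiver skips any run already returned by get_archived_run_ids, and delete_archive_log_by_run_id removes that marker after a restore so the run can be archived again.
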
+ + def create_archive_logs( + self, + session: Session, + run: WorkflowRun, + app_logs: Sequence[WorkflowAppLog], + trigger_metadata: str | None, + ) -> int: + """ + Create archive log records for a workflow run. + """ + ... + + def get_archived_runs_by_time_range( + self, + session: Session, + tenant_ids: Sequence[str] | None, + start_date: datetime, + end_date: datetime, + limit: int, + ) -> Sequence[WorkflowRun]: + """ + Return workflow runs that already have archive logs, for cleanup of `workflow_runs`. + """ + ... + def count_runs_with_related( self, runs: Sequence[WorkflowRun], diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 4a7c975d2c..b19cc73bd1 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -351,3 +351,27 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut ) return int(node_executions_count), int(offloads_count) + + @staticmethod + def get_by_run( + session: Session, + run_id: str, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Fetch node executions for a run using workflow_run_id. + """ + stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.workflow_run_id == run_id) + return list(session.scalars(stmt)) + + def get_offloads_by_execution_ids( + self, + session: Session, + node_execution_ids: Sequence[str], + ) -> Sequence[WorkflowNodeExecutionOffload]: + if not node_execution_ids: + return [] + + stmt = select(WorkflowNodeExecutionOffload).where( + WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids) + ) + return list(session.scalars(stmt)) diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index 9d2d06e99f..d5214be042 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -40,14 +40,7 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.time_parser import get_time_threshold from libs.uuid_utils import uuidv7 from models.enums import WorkflowRunTriggeredFrom -from models.workflow import ( - WorkflowAppLog, - WorkflowPauseReason, - WorkflowRun, -) -from models.workflow import ( - WorkflowPause as WorkflowPauseModel, -) +from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.types import ( @@ -369,6 +362,53 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): return session.scalars(stmt).all() + def get_archived_run_ids( + self, + session: Session, + run_ids: Sequence[str], + ) -> set[str]: + if not run_ids: + return set() + + stmt = select(WorkflowArchiveLog.workflow_run_id).where(WorkflowArchiveLog.workflow_run_id.in_(run_ids)) + return set(session.scalars(stmt).all()) + + def get_archived_log_by_run_id( + self, + run_id: str, + ) -> WorkflowArchiveLog | None: + with self._session_maker() as session: + stmt = select(WorkflowArchiveLog).where(WorkflowArchiveLog.workflow_run_id == run_id).limit(1) + return session.scalar(stmt) + + def delete_archive_log_by_run_id( + self, + session: Session, + run_id: str, + ) -> int: + stmt = 
delete(WorkflowArchiveLog).where(WorkflowArchiveLog.workflow_run_id == run_id) + result = session.execute(stmt) + return cast(CursorResult, result).rowcount or 0 + + def get_pause_records_by_run_id( + self, + session: Session, + run_id: str, + ) -> Sequence[WorkflowPause]: + stmt = select(WorkflowPause).where(WorkflowPause.workflow_run_id == run_id) + return list(session.scalars(stmt)) + + def get_pause_reason_records_by_run_id( + self, + session: Session, + pause_ids: Sequence[str], + ) -> Sequence[WorkflowPauseReason]: + if not pause_ids: + return [] + + stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids)) + return list(session.scalars(stmt)) + def delete_runs_with_related( self, runs: Sequence[WorkflowRun], @@ -396,9 +436,8 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): app_logs_result = session.execute(delete(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids))) app_logs_deleted = cast(CursorResult, app_logs_result).rowcount or 0 - pause_ids = session.scalars( - select(WorkflowPauseModel.id).where(WorkflowPauseModel.workflow_run_id.in_(run_ids)) - ).all() + pause_stmt = select(WorkflowPause.id).where(WorkflowPause.workflow_run_id.in_(run_ids)) + pause_ids = session.scalars(pause_stmt).all() pause_reasons_deleted = 0 pauses_deleted = 0 @@ -407,7 +446,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): delete(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids)) ) pause_reasons_deleted = cast(CursorResult, pause_reasons_result).rowcount or 0 - pauses_result = session.execute(delete(WorkflowPauseModel).where(WorkflowPauseModel.id.in_(pause_ids))) + pauses_result = session.execute(delete(WorkflowPause).where(WorkflowPause.id.in_(pause_ids))) pauses_deleted = cast(CursorResult, pauses_result).rowcount or 0 trigger_logs_deleted = delete_trigger_logs(session, run_ids) if delete_trigger_logs else 0 @@ -427,6 +466,124 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): "pause_reasons": pause_reasons_deleted, } + def get_app_logs_by_run_id( + self, + session: Session, + run_id: str, + ) -> Sequence[WorkflowAppLog]: + stmt = select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == run_id) + return list(session.scalars(stmt)) + + def create_archive_logs( + self, + session: Session, + run: WorkflowRun, + app_logs: Sequence[WorkflowAppLog], + trigger_metadata: str | None, + ) -> int: + if not app_logs: + archive_log = WorkflowArchiveLog( + log_id=None, + log_created_at=None, + log_created_from=None, + tenant_id=run.tenant_id, + app_id=run.app_id, + workflow_id=run.workflow_id, + workflow_run_id=run.id, + created_by_role=run.created_by_role, + created_by=run.created_by, + run_version=run.version, + run_status=run.status, + run_triggered_from=run.triggered_from, + run_error=run.error, + run_elapsed_time=run.elapsed_time, + run_total_tokens=run.total_tokens, + run_total_steps=run.total_steps, + run_created_at=run.created_at, + run_finished_at=run.finished_at, + run_exceptions_count=run.exceptions_count, + trigger_metadata=trigger_metadata, + ) + session.add(archive_log) + return 1 + + archive_logs = [ + WorkflowArchiveLog( + log_id=app_log.id, + log_created_at=app_log.created_at, + log_created_from=app_log.created_from, + tenant_id=run.tenant_id, + app_id=run.app_id, + workflow_id=run.workflow_id, + workflow_run_id=run.id, + created_by_role=run.created_by_role, + created_by=run.created_by, + run_version=run.version, + run_status=run.status, + 
run_triggered_from=run.triggered_from, + run_error=run.error, + run_elapsed_time=run.elapsed_time, + run_total_tokens=run.total_tokens, + run_total_steps=run.total_steps, + run_created_at=run.created_at, + run_finished_at=run.finished_at, + run_exceptions_count=run.exceptions_count, + trigger_metadata=trigger_metadata, + ) + for app_log in app_logs + ] + session.add_all(archive_logs) + return len(archive_logs) + + def get_archived_runs_by_time_range( + self, + session: Session, + tenant_ids: Sequence[str] | None, + start_date: datetime, + end_date: datetime, + limit: int, + ) -> Sequence[WorkflowRun]: + """ + Retrieves WorkflowRun records by joining workflow_archive_logs. + + Used to identify runs that are already archived and ready for deletion. + """ + stmt = ( + select(WorkflowRun) + .join(WorkflowArchiveLog, WorkflowArchiveLog.workflow_run_id == WorkflowRun.id) + .where( + WorkflowArchiveLog.run_created_at >= start_date, + WorkflowArchiveLog.run_created_at < end_date, + ) + .order_by(WorkflowArchiveLog.run_created_at.asc(), WorkflowArchiveLog.workflow_run_id.asc()) + .limit(limit) + ) + if tenant_ids: + stmt = stmt.where(WorkflowArchiveLog.tenant_id.in_(tenant_ids)) + return list(session.scalars(stmt)) + + def get_archived_logs_by_time_range( + self, + session: Session, + tenant_ids: Sequence[str] | None, + start_date: datetime, + end_date: datetime, + limit: int, + ) -> Sequence[WorkflowArchiveLog]: + # Returns WorkflowArchiveLog rows directly; use this when workflow_runs may be deleted. + stmt = ( + select(WorkflowArchiveLog) + .where( + WorkflowArchiveLog.run_created_at >= start_date, + WorkflowArchiveLog.run_created_at < end_date, + ) + .order_by(WorkflowArchiveLog.run_created_at.asc(), WorkflowArchiveLog.workflow_run_id.asc()) + .limit(limit) + ) + if tenant_ids: + stmt = stmt.where(WorkflowArchiveLog.tenant_id.in_(tenant_ids)) + return list(session.scalars(stmt)) + def count_runs_with_related( self, runs: Sequence[WorkflowRun], @@ -459,7 +616,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): ) pause_ids = session.scalars( - select(WorkflowPauseModel.id).where(WorkflowPauseModel.workflow_run_id.in_(run_ids)) + select(WorkflowPause.id).where(WorkflowPause.workflow_run_id.in_(run_ids)) ).all() pauses_count = len(pause_ids) pause_reasons_count = 0 @@ -511,9 +668,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): ValueError: If workflow_run_id is invalid or workflow run doesn't exist RuntimeError: If workflow is already paused or in invalid state """ - previous_pause_model_query = select(WorkflowPauseModel).where( - WorkflowPauseModel.workflow_run_id == workflow_run_id - ) + previous_pause_model_query = select(WorkflowPause).where(WorkflowPause.workflow_run_id == workflow_run_id) with self._session_maker() as session, session.begin(): # Get the workflow run workflow_run = session.get(WorkflowRun, workflow_run_id) @@ -538,7 +693,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): # Upload the state file # Create the pause record - pause_model = WorkflowPauseModel() + pause_model = WorkflowPause() pause_model.id = str(uuidv7()) pause_model.workflow_id = workflow_run.workflow_id pause_model.workflow_run_id = workflow_run.id @@ -710,13 +865,13 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): """ with self._session_maker() as session, session.begin(): # Get the pause model by ID - pause_model = session.get(WorkflowPauseModel, pause_entity.id) + pause_model = session.get(WorkflowPause, 
pause_entity.id) if pause_model is None: raise _WorkflowRunError(f"WorkflowPause not found: {pause_entity.id}") self._delete_pause_model(session, pause_model) @staticmethod - def _delete_pause_model(session: Session, pause_model: WorkflowPauseModel): + def _delete_pause_model(session: Session, pause_model: WorkflowPause): storage.delete(pause_model.state_object_key) # Delete the pause record @@ -751,15 +906,15 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): _limit: int = limit or 1000 pruned_record_ids: list[str] = [] cond = or_( - WorkflowPauseModel.created_at < expiration, + WorkflowPause.created_at < expiration, and_( - WorkflowPauseModel.resumed_at.is_not(null()), - WorkflowPauseModel.resumed_at < resumption_expiration, + WorkflowPause.resumed_at.is_not(null()), + WorkflowPause.resumed_at < resumption_expiration, ), ) # First, collect pause records to delete with their state files # Expired pauses (created before expiration time) - stmt = select(WorkflowPauseModel).where(cond).limit(_limit) + stmt = select(WorkflowPause).where(cond).limit(_limit) with self._session_maker(expire_on_commit=False) as session: # Old resumed pauses (resumed more than resumption_duration ago) @@ -770,7 +925,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): # Delete state files from storage for pause in pauses_to_delete: with self._session_maker(expire_on_commit=False) as session, session.begin(): - # todo: this issues a separate query for each WorkflowPauseModel record. + # todo: this issues a separate query for each WorkflowPause record. # consider batching this lookup. try: storage.delete(pause.state_object_key) @@ -1022,7 +1177,7 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity): def __init__( self, *, - pause_model: WorkflowPauseModel, + pause_model: WorkflowPause, reason_models: Sequence[WorkflowPauseReason], human_input_form: Sequence = (), ) -> None: diff --git a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py index ebd3745d18..f3dc4cd60b 100644 --- a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py +++ b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @@ -46,6 +46,11 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository): return self.session.scalar(query) + def list_by_run_id(self, run_id: str) -> Sequence[WorkflowTriggerLog]: + """List trigger logs for a workflow run.""" + query = select(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id == run_id) + return list(self.session.scalars(query).all()) + def get_failed_for_retry( self, tenant_id: str, max_retry_count: int = 3, limit: int = 100 ) -> Sequence[WorkflowTriggerLog]: diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 4a911326fb..d94ae49d91 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -204,7 +204,7 @@ class FeatureService: return knowledge_rate_limit @classmethod - def get_system_features(cls) -> SystemFeatureModel: + def get_system_features(cls, is_authenticated: bool = False) -> SystemFeatureModel: system_features = SystemFeatureModel() cls._fulfill_system_params_from_env(system_features) @@ -214,7 +214,7 @@ class FeatureService: system_features.webapp_auth.enabled = True system_features.enable_change_email = False system_features.plugin_manager.enabled = True - cls._fulfill_params_from_enterprise(system_features) + cls._fulfill_params_from_enterprise(system_features, 
is_authenticated) if dify_config.MARKETPLACE_ENABLED: system_features.enable_marketplace = True @@ -324,7 +324,7 @@ class FeatureService: features.next_credit_reset_date = billing_info["next_credit_reset_date"] @classmethod - def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel): + def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel, is_authenticated: bool = False): enterprise_info = EnterpriseService.get_info() if "SSOEnforcedForSignin" in enterprise_info: @@ -361,19 +361,14 @@ class FeatureService: ) features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "") - if "License" in enterprise_info: - license_info = enterprise_info["License"] + if is_authenticated and (license_info := enterprise_info.get("License")): + features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE)) + features.license.expired_at = license_info.get("expiredAt", "") - if "status" in license_info: - features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE)) - - if "expiredAt" in license_info: - features.license.expired_at = license_info["expiredAt"] - - if "workspaces" in license_info: - features.license.workspaces.enabled = license_info["workspaces"]["enabled"] - features.license.workspaces.limit = license_info["workspaces"]["limit"] - features.license.workspaces.size = license_info["workspaces"]["used"] + if workspaces_info := license_info.get("workspaces"): + features.license.workspaces.enabled = workspaces_info.get("enabled", False) + features.license.workspaces.limit = workspaces_info.get("limit", 0) + features.license.workspaces.size = workspaces_info.get("used", 0) if "PluginInstallationPermission" in enterprise_info: plugin_installation_info = enterprise_info["PluginInstallationPermission"] diff --git a/api/services/retention/workflow_run/__init__.py b/api/services/retention/workflow_run/__init__.py index e69de29bb2..18dd42c91e 100644 --- a/api/services/retention/workflow_run/__init__.py +++ b/api/services/retention/workflow_run/__init__.py @@ -0,0 +1 @@ +"""Workflow run retention services.""" diff --git a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py new file mode 100644 index 0000000000..ea5cbb7740 --- /dev/null +++ b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py @@ -0,0 +1,531 @@ +""" +Archive Paid Plan Workflow Run Logs Service. + +This service archives workflow run logs for paid plan users older than the configured +retention period (default: 90 days) to S3-compatible storage. 
+ +Archived tables: +- workflow_runs +- workflow_app_logs +- workflow_node_executions +- workflow_node_execution_offload +- workflow_pauses +- workflow_pause_reasons +- workflow_trigger_logs + +""" + +import datetime +import io +import json +import logging +import time +import zipfile +from collections.abc import Sequence +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from typing import Any + +import click +from sqlalchemy import inspect +from sqlalchemy.orm import Session, sessionmaker + +from configs import dify_config +from core.workflow.enums import WorkflowType +from enums.cloud_plan import CloudPlan +from extensions.ext_database import db +from libs.archive_storage import ( + ArchiveStorage, + ArchiveStorageNotConfiguredError, + get_archive_storage, +) +from models.workflow import WorkflowAppLog, WorkflowRun +from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository +from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository +from services.billing_service import BillingService +from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME, ARCHIVE_SCHEMA_VERSION + +logger = logging.getLogger(__name__) + + +@dataclass +class TableStats: + """Statistics for a single archived table.""" + + table_name: str + row_count: int + checksum: str + size_bytes: int + + +@dataclass +class ArchiveResult: + """Result of archiving a single workflow run.""" + + run_id: str + tenant_id: str + success: bool + tables: list[TableStats] = field(default_factory=list) + error: str | None = None + elapsed_time: float = 0.0 + + +@dataclass +class ArchiveSummary: + """Summary of the entire archive operation.""" + + total_runs_processed: int = 0 + runs_archived: int = 0 + runs_skipped: int = 0 + runs_failed: int = 0 + total_elapsed_time: float = 0.0 + + +class WorkflowRunArchiver: + """ + Archive workflow run logs for paid plan users. + + Storage Layout: + {tenant_id}/app_id={app_id}/year={YYYY}/month={MM}/workflow_run_id={run_id}/ + └── archive.v1.0.zip + ├── manifest.json + ├── workflow_runs.jsonl + ├── workflow_app_logs.jsonl + ├── workflow_node_executions.jsonl + ├── workflow_node_execution_offload.jsonl + ├── workflow_pauses.jsonl + ├── workflow_pause_reasons.jsonl + └── workflow_trigger_logs.jsonl + """ + + ARCHIVED_TYPE = [ + WorkflowType.WORKFLOW, + WorkflowType.RAG_PIPELINE, + ] + ARCHIVED_TABLES = [ + "workflow_runs", + "workflow_app_logs", + "workflow_node_executions", + "workflow_node_execution_offload", + "workflow_pauses", + "workflow_pause_reasons", + "workflow_trigger_logs", + ] + + start_from: datetime.datetime | None + end_before: datetime.datetime + + def __init__( + self, + days: int = 90, + batch_size: int = 100, + start_from: datetime.datetime | None = None, + end_before: datetime.datetime | None = None, + workers: int = 1, + tenant_ids: Sequence[str] | None = None, + limit: int | None = None, + dry_run: bool = False, + delete_after_archive: bool = False, + workflow_run_repo: APIWorkflowRunRepository | None = None, + ): + """ + Initialize the archiver. 
+ + Args: + days: Archive runs older than this many days + batch_size: Number of runs to process per batch + start_from: Optional start time (inclusive) for archiving + end_before: Optional end time (exclusive) for archiving + workers: Number of concurrent workflow runs to archive + tenant_ids: Optional tenant IDs for grayscale rollout + limit: Maximum number of runs to archive (None for unlimited) + dry_run: If True, only preview without making changes + delete_after_archive: If True, delete runs and related data after archiving + """ + self.days = days + self.batch_size = batch_size + if start_from or end_before: + if start_from is None or end_before is None: + raise ValueError("start_from and end_before must be provided together") + if start_from >= end_before: + raise ValueError("start_from must be earlier than end_before") + self.start_from = start_from.replace(tzinfo=datetime.UTC) + self.end_before = end_before.replace(tzinfo=datetime.UTC) + else: + self.start_from = None + self.end_before = datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=days) + if workers < 1: + raise ValueError("workers must be at least 1") + self.workers = workers + self.tenant_ids = sorted(set(tenant_ids)) if tenant_ids else [] + self.limit = limit + self.dry_run = dry_run + self.delete_after_archive = delete_after_archive + self.workflow_run_repo = workflow_run_repo + + def run(self) -> ArchiveSummary: + """ + Main archiving loop. + + Returns: + ArchiveSummary with statistics about the operation + """ + summary = ArchiveSummary() + start_time = time.time() + + click.echo( + click.style( + self._build_start_message(), + fg="white", + ) + ) + + # Initialize archive storage (will raise if not configured) + try: + if not self.dry_run: + storage = get_archive_storage() + else: + storage = None + except ArchiveStorageNotConfiguredError as e: + click.echo(click.style(f"Archive storage not configured: {e}", fg="red")) + return summary + + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + repo = self._get_workflow_run_repo() + + def _archive_with_session(run: WorkflowRun) -> ArchiveResult: + with session_maker() as session: + return self._archive_run(session, storage, run) + + last_seen: tuple[datetime.datetime, str] | None = None + archived_count = 0 + + with ThreadPoolExecutor(max_workers=self.workers) as executor: + while True: + # Check limit + if self.limit and archived_count >= self.limit: + click.echo(click.style(f"Reached limit of {self.limit} runs", fg="yellow")) + break + + # Fetch batch of runs + runs = self._get_runs_batch(last_seen) + + if not runs: + break + + run_ids = [run.id for run in runs] + with session_maker() as session: + archived_run_ids = repo.get_archived_run_ids(session, run_ids) + + last_seen = (runs[-1].created_at, runs[-1].id) + + # Filter to paid tenants only + tenant_ids = {run.tenant_id for run in runs} + paid_tenants = self._filter_paid_tenants(tenant_ids) + + runs_to_process: list[WorkflowRun] = [] + for run in runs: + summary.total_runs_processed += 1 + + # Skip non-paid tenants + if run.tenant_id not in paid_tenants: + summary.runs_skipped += 1 + continue + + # Skip already archived runs + if run.id in archived_run_ids: + summary.runs_skipped += 1 + continue + + # Check limit + if self.limit and archived_count + len(runs_to_process) >= self.limit: + break + + runs_to_process.append(run) + + if not runs_to_process: + continue + + results = list(executor.map(_archive_with_session, runs_to_process)) + + for run, result in zip(runs_to_process, results): 
+ if result.success: + summary.runs_archived += 1 + archived_count += 1 + click.echo( + click.style( + f"{'[DRY RUN] Would archive' if self.dry_run else 'Archived'} " + f"run {run.id} (tenant={run.tenant_id}, " + f"tables={len(result.tables)}, time={result.elapsed_time:.2f}s)", + fg="green", + ) + ) + else: + summary.runs_failed += 1 + click.echo( + click.style( + f"Failed to archive run {run.id}: {result.error}", + fg="red", + ) + ) + + summary.total_elapsed_time = time.time() - start_time + click.echo( + click.style( + f"{'[DRY RUN] ' if self.dry_run else ''}Archive complete: " + f"processed={summary.total_runs_processed}, archived={summary.runs_archived}, " + f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, " + f"time={summary.total_elapsed_time:.2f}s", + fg="white", + ) + ) + + return summary + + def _get_runs_batch( + self, + last_seen: tuple[datetime.datetime, str] | None, + ) -> Sequence[WorkflowRun]: + """Fetch a batch of workflow runs to archive.""" + repo = self._get_workflow_run_repo() + return repo.get_runs_batch_by_time_range( + start_from=self.start_from, + end_before=self.end_before, + last_seen=last_seen, + batch_size=self.batch_size, + run_types=self.ARCHIVED_TYPE, + tenant_ids=self.tenant_ids or None, + ) + + def _build_start_message(self) -> str: + range_desc = f"before {self.end_before.isoformat()}" + if self.start_from: + range_desc = f"between {self.start_from.isoformat()} and {self.end_before.isoformat()}" + return ( + f"{'[DRY RUN] ' if self.dry_run else ''}Starting workflow run archiving " + f"for runs {range_desc} " + f"(batch_size={self.batch_size}, tenant_ids={','.join(self.tenant_ids) or 'all'})" + ) + + def _filter_paid_tenants(self, tenant_ids: set[str]) -> set[str]: + """Filter tenant IDs to only include paid tenants.""" + if not dify_config.BILLING_ENABLED: + # If billing is not enabled, treat all tenants as paid + return tenant_ids + + if not tenant_ids: + return set() + + try: + bulk_info = BillingService.get_plan_bulk_with_cache(list(tenant_ids)) + except Exception: + logger.exception("Failed to fetch billing plans for tenants") + # On error, skip all tenants in this batch + return set() + + # Filter to paid tenants (any plan except SANDBOX) + paid = set() + for tid, info in bulk_info.items(): + if info and info.get("plan") in (CloudPlan.PROFESSIONAL, CloudPlan.TEAM): + paid.add(tid) + + return paid + + def _archive_run( + self, + session: Session, + storage: ArchiveStorage | None, + run: WorkflowRun, + ) -> ArchiveResult: + """Archive a single workflow run.""" + start_time = time.time() + result = ArchiveResult(run_id=run.id, tenant_id=run.tenant_id, success=False) + + try: + # Extract data from all tables + table_data, app_logs, trigger_metadata = self._extract_data(session, run) + + if self.dry_run: + # In dry run, just report what would be archived + for table_name in self.ARCHIVED_TABLES: + records = table_data.get(table_name, []) + result.tables.append( + TableStats( + table_name=table_name, + row_count=len(records), + checksum="", + size_bytes=0, + ) + ) + result.success = True + else: + if storage is None: + raise ArchiveStorageNotConfiguredError("Archive storage not configured") + archive_key = self._get_archive_key(run) + + # Serialize tables for the archive bundle + table_stats: list[TableStats] = [] + table_payloads: dict[str, bytes] = {} + for table_name in self.ARCHIVED_TABLES: + records = table_data.get(table_name, []) + data = ArchiveStorage.serialize_to_jsonl(records) + table_payloads[table_name] = data + checksum = 
ArchiveStorage.compute_checksum(data) + + table_stats.append( + TableStats( + table_name=table_name, + row_count=len(records), + checksum=checksum, + size_bytes=len(data), + ) + ) + + # Generate and upload archive bundle + manifest = self._generate_manifest(run, table_stats) + manifest_data = json.dumps(manifest, indent=2, default=str).encode("utf-8") + archive_data = self._build_archive_bundle(manifest_data, table_payloads) + storage.put_object(archive_key, archive_data) + + repo = self._get_workflow_run_repo() + archived_log_count = repo.create_archive_logs(session, run, app_logs, trigger_metadata) + session.commit() + + deleted_counts = None + if self.delete_after_archive: + deleted_counts = repo.delete_runs_with_related( + [run], + delete_node_executions=self._delete_node_executions, + delete_trigger_logs=self._delete_trigger_logs, + ) + + logger.info( + "Archived workflow run %s: tables=%s, archived_logs=%s, deleted=%s", + run.id, + {s.table_name: s.row_count for s in table_stats}, + archived_log_count, + deleted_counts, + ) + + result.tables = table_stats + result.success = True + + except Exception as e: + logger.exception("Failed to archive workflow run %s", run.id) + result.error = str(e) + session.rollback() + + result.elapsed_time = time.time() - start_time + return result + + def _extract_data( + self, + session: Session, + run: WorkflowRun, + ) -> tuple[dict[str, list[dict[str, Any]]], Sequence[WorkflowAppLog], str | None]: + table_data: dict[str, list[dict[str, Any]]] = {} + table_data["workflow_runs"] = [self._row_to_dict(run)] + repo = self._get_workflow_run_repo() + app_logs = repo.get_app_logs_by_run_id(session, run.id) + table_data["workflow_app_logs"] = [self._row_to_dict(row) for row in app_logs] + node_exec_repo = self._get_workflow_node_execution_repo(session) + node_exec_records = node_exec_repo.get_executions_by_workflow_run( + tenant_id=run.tenant_id, + app_id=run.app_id, + workflow_run_id=run.id, + ) + node_exec_ids = [record.id for record in node_exec_records] + offload_records = node_exec_repo.get_offloads_by_execution_ids(session, node_exec_ids) + table_data["workflow_node_executions"] = [self._row_to_dict(row) for row in node_exec_records] + table_data["workflow_node_execution_offload"] = [self._row_to_dict(row) for row in offload_records] + repo = self._get_workflow_run_repo() + pause_records = repo.get_pause_records_by_run_id(session, run.id) + pause_ids = [pause.id for pause in pause_records] + pause_reason_records = repo.get_pause_reason_records_by_run_id( + session, + pause_ids, + ) + table_data["workflow_pauses"] = [self._row_to_dict(row) for row in pause_records] + table_data["workflow_pause_reasons"] = [self._row_to_dict(row) for row in pause_reason_records] + trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + trigger_records = trigger_repo.list_by_run_id(run.id) + table_data["workflow_trigger_logs"] = [self._row_to_dict(row) for row in trigger_records] + trigger_metadata = trigger_records[0].trigger_metadata if trigger_records else None + return table_data, app_logs, trigger_metadata + + @staticmethod + def _row_to_dict(row: Any) -> dict[str, Any]: + mapper = inspect(row).mapper + return {str(column.name): getattr(row, mapper.get_property_by_column(column).key) for column in mapper.columns} + + def _get_archive_key(self, run: WorkflowRun) -> str: + """Get the storage key for the archive bundle.""" + created_at = run.created_at + prefix = ( + f"{run.tenant_id}/app_id={run.app_id}/year={created_at.strftime('%Y')}/" + 
f"month={created_at.strftime('%m')}/workflow_run_id={run.id}" + ) + return f"{prefix}/{ARCHIVE_BUNDLE_NAME}" + + def _generate_manifest( + self, + run: WorkflowRun, + table_stats: list[TableStats], + ) -> dict[str, Any]: + """Generate a manifest for the archived workflow run.""" + return { + "schema_version": ARCHIVE_SCHEMA_VERSION, + "workflow_run_id": run.id, + "tenant_id": run.tenant_id, + "app_id": run.app_id, + "workflow_id": run.workflow_id, + "created_at": run.created_at.isoformat(), + "archived_at": datetime.datetime.now(datetime.UTC).isoformat(), + "tables": { + stat.table_name: { + "row_count": stat.row_count, + "checksum": stat.checksum, + "size_bytes": stat.size_bytes, + } + for stat in table_stats + }, + } + + def _build_archive_bundle(self, manifest_data: bytes, table_payloads: dict[str, bytes]) -> bytes: + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as archive: + archive.writestr("manifest.json", manifest_data) + for table_name in self.ARCHIVED_TABLES: + data = table_payloads.get(table_name) + if data is None: + raise ValueError(f"Missing archive payload for {table_name}") + archive.writestr(f"{table_name}.jsonl", data) + return buffer.getvalue() + + def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int: + trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + return trigger_repo.delete_by_run_ids(run_ids) + + def _delete_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]: + run_ids = [run.id for run in runs] + return self._get_workflow_node_execution_repo(session).delete_by_runs(session, run_ids) + + def _get_workflow_node_execution_repo( + self, + session: Session, + ) -> DifyAPIWorkflowNodeExecutionRepository: + from repositories.factory import DifyAPIRepositoryFactory + + session_maker = sessionmaker(bind=session.get_bind(), expire_on_commit=False) + return DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker) + + def _get_workflow_run_repo(self) -> APIWorkflowRunRepository: + if self.workflow_run_repo is not None: + return self.workflow_run_repo + + from repositories.factory import DifyAPIRepositoryFactory + + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + return self.workflow_run_repo diff --git a/api/services/retention/workflow_run/constants.py b/api/services/retention/workflow_run/constants.py new file mode 100644 index 0000000000..162bb4947d --- /dev/null +++ b/api/services/retention/workflow_run/constants.py @@ -0,0 +1,2 @@ +ARCHIVE_SCHEMA_VERSION = "1.0" +ARCHIVE_BUNDLE_NAME = f"archive.v{ARCHIVE_SCHEMA_VERSION}.zip" diff --git a/api/services/retention/workflow_run/delete_archived_workflow_run.py b/api/services/retention/workflow_run/delete_archived_workflow_run.py new file mode 100644 index 0000000000..11873bf1b9 --- /dev/null +++ b/api/services/retention/workflow_run/delete_archived_workflow_run.py @@ -0,0 +1,134 @@ +""" +Delete Archived Workflow Run Service. + +This service deletes archived workflow run data from the database while keeping +archive logs intact. 
+""" + +import time +from collections.abc import Sequence +from dataclasses import dataclass, field +from datetime import datetime + +from sqlalchemy.orm import Session, sessionmaker + +from extensions.ext_database import db +from models.workflow import WorkflowRun +from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository + + +@dataclass +class DeleteResult: + run_id: str + tenant_id: str + success: bool + deleted_counts: dict[str, int] = field(default_factory=dict) + error: str | None = None + elapsed_time: float = 0.0 + + +class ArchivedWorkflowRunDeletion: + def __init__(self, dry_run: bool = False): + self.dry_run = dry_run + self.workflow_run_repo: APIWorkflowRunRepository | None = None + + def delete_by_run_id(self, run_id: str) -> DeleteResult: + start_time = time.time() + result = DeleteResult(run_id=run_id, tenant_id="", success=False) + + repo = self._get_workflow_run_repo() + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + with session_maker() as session: + run = session.get(WorkflowRun, run_id) + if not run: + result.error = f"Workflow run {run_id} not found" + result.elapsed_time = time.time() - start_time + return result + + result.tenant_id = run.tenant_id + if not repo.get_archived_run_ids(session, [run.id]): + result.error = f"Workflow run {run_id} is not archived" + result.elapsed_time = time.time() - start_time + return result + + result = self._delete_run(run) + result.elapsed_time = time.time() - start_time + return result + + def delete_batch( + self, + tenant_ids: list[str] | None, + start_date: datetime, + end_date: datetime, + limit: int = 100, + ) -> list[DeleteResult]: + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + results: list[DeleteResult] = [] + + repo = self._get_workflow_run_repo() + with session_maker() as session: + runs = list( + repo.get_archived_runs_by_time_range( + session=session, + tenant_ids=tenant_ids, + start_date=start_date, + end_date=end_date, + limit=limit, + ) + ) + for run in runs: + results.append(self._delete_run(run)) + + return results + + def _delete_run(self, run: WorkflowRun) -> DeleteResult: + start_time = time.time() + result = DeleteResult(run_id=run.id, tenant_id=run.tenant_id, success=False) + if self.dry_run: + result.success = True + result.elapsed_time = time.time() - start_time + return result + + repo = self._get_workflow_run_repo() + try: + deleted_counts = repo.delete_runs_with_related( + [run], + delete_node_executions=self._delete_node_executions, + delete_trigger_logs=self._delete_trigger_logs, + ) + result.deleted_counts = deleted_counts + result.success = True + except Exception as e: + result.error = str(e) + result.elapsed_time = time.time() - start_time + return result + + @staticmethod + def _delete_trigger_logs(session: Session, run_ids: Sequence[str]) -> int: + trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + return trigger_repo.delete_by_run_ids(run_ids) + + @staticmethod + def _delete_node_executions( + session: Session, + runs: Sequence[WorkflowRun], + ) -> tuple[int, int]: + from repositories.factory import DifyAPIRepositoryFactory + + run_ids = [run.id for run in runs] + repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False) + ) + return repo.delete_by_runs(session, run_ids) + + def _get_workflow_run_repo(self) -> 
APIWorkflowRunRepository: + if self.workflow_run_repo is not None: + return self.workflow_run_repo + + from repositories.factory import DifyAPIRepositoryFactory + + self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository( + sessionmaker(bind=db.engine, expire_on_commit=False) + ) + return self.workflow_run_repo diff --git a/api/services/retention/workflow_run/restore_archived_workflow_run.py b/api/services/retention/workflow_run/restore_archived_workflow_run.py new file mode 100644 index 0000000000..d4a6e87585 --- /dev/null +++ b/api/services/retention/workflow_run/restore_archived_workflow_run.py @@ -0,0 +1,481 @@ +""" +Restore Archived Workflow Run Service. + +This service restores archived workflow run data from S3-compatible storage +back to the database. +""" + +import io +import json +import logging +import time +import zipfile +from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from datetime import datetime +from typing import Any, cast + +import click +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.engine import CursorResult +from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker + +from extensions.ext_database import db +from libs.archive_storage import ( + ArchiveStorage, + ArchiveStorageNotConfiguredError, + get_archive_storage, +) +from models.trigger import WorkflowTriggerLog +from models.workflow import ( + WorkflowAppLog, + WorkflowArchiveLog, + WorkflowNodeExecutionModel, + WorkflowNodeExecutionOffload, + WorkflowPause, + WorkflowPauseReason, + WorkflowRun, +) +from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.factory import DifyAPIRepositoryFactory +from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME + +logger = logging.getLogger(__name__) + + +# Mapping of table names to SQLAlchemy models +TABLE_MODELS = { + "workflow_runs": WorkflowRun, + "workflow_app_logs": WorkflowAppLog, + "workflow_node_executions": WorkflowNodeExecutionModel, + "workflow_node_execution_offload": WorkflowNodeExecutionOffload, + "workflow_pauses": WorkflowPause, + "workflow_pause_reasons": WorkflowPauseReason, + "workflow_trigger_logs": WorkflowTriggerLog, +} + +SchemaMapper = Callable[[dict[str, Any]], dict[str, Any]] + +SCHEMA_MAPPERS: dict[str, dict[str, SchemaMapper]] = { + "1.0": {}, +} + + +@dataclass +class RestoreResult: + """Result of restoring a single workflow run.""" + + run_id: str + tenant_id: str + success: bool + restored_counts: dict[str, int] + error: str | None = None + elapsed_time: float = 0.0 + + +class WorkflowRunRestore: + """ + Restore archived workflow run data from storage to database. + + This service reads archived data from storage and restores it to the + database tables. It handles idempotency by skipping records that already + exist in the database. + """ + + def __init__(self, dry_run: bool = False, workers: int = 1): + """ + Initialize the restore service. 
+ + Args: + dry_run: If True, only preview without making changes + workers: Number of concurrent workflow runs to restore + """ + self.dry_run = dry_run + if workers < 1: + raise ValueError("workers must be at least 1") + self.workers = workers + self.workflow_run_repo: APIWorkflowRunRepository | None = None + + def _restore_from_run( + self, + run: WorkflowRun | WorkflowArchiveLog, + *, + session_maker: sessionmaker, + ) -> RestoreResult: + start_time = time.time() + run_id = run.workflow_run_id if isinstance(run, WorkflowArchiveLog) else run.id + created_at = run.run_created_at if isinstance(run, WorkflowArchiveLog) else run.created_at + result = RestoreResult( + run_id=run_id, + tenant_id=run.tenant_id, + success=False, + restored_counts={}, + ) + + if not self.dry_run: + click.echo( + click.style( + f"Starting restore for workflow run {run_id} (tenant={run.tenant_id})", + fg="white", + ) + ) + + try: + storage = get_archive_storage() + except ArchiveStorageNotConfiguredError as e: + result.error = str(e) + click.echo(click.style(f"Archive storage not configured: {e}", fg="red")) + result.elapsed_time = time.time() - start_time + return result + + prefix = ( + f"{run.tenant_id}/app_id={run.app_id}/year={created_at.strftime('%Y')}/" + f"month={created_at.strftime('%m')}/workflow_run_id={run_id}" + ) + archive_key = f"{prefix}/{ARCHIVE_BUNDLE_NAME}" + try: + archive_data = storage.get_object(archive_key) + except FileNotFoundError: + result.error = f"Archive bundle not found: {archive_key}" + click.echo(click.style(result.error, fg="red")) + result.elapsed_time = time.time() - start_time + return result + + with session_maker() as session: + try: + with zipfile.ZipFile(io.BytesIO(archive_data), mode="r") as archive: + try: + manifest = self._load_manifest_from_zip(archive) + except ValueError as e: + result.error = f"Archive bundle invalid: {e}" + click.echo(click.style(result.error, fg="red")) + return result + + tables = manifest.get("tables", {}) + schema_version = self._get_schema_version(manifest) + for table_name, info in tables.items(): + row_count = info.get("row_count", 0) + if row_count == 0: + result.restored_counts[table_name] = 0 + continue + + if self.dry_run: + result.restored_counts[table_name] = row_count + continue + + member_path = f"{table_name}.jsonl" + try: + data = archive.read(member_path) + except KeyError: + click.echo( + click.style( + f" Warning: Table data not found in archive: {member_path}", + fg="yellow", + ) + ) + result.restored_counts[table_name] = 0 + continue + + records = ArchiveStorage.deserialize_from_jsonl(data) + restored = self._restore_table_records( + session, + table_name, + records, + schema_version=schema_version, + ) + result.restored_counts[table_name] = restored + if not self.dry_run: + click.echo( + click.style( + f" Restored {restored}/{len(records)} records to {table_name}", + fg="white", + ) + ) + + # Verify row counts match manifest + manifest_total = sum(info.get("row_count", 0) for info in tables.values()) + restored_total = sum(result.restored_counts.values()) + + if not self.dry_run: + # Note: restored count might be less than manifest count if records already exist + logger.info( + "Restore verification: manifest_total=%d, restored_total=%d", + manifest_total, + restored_total, + ) + + # Delete the archive log record after successful restore + repo = self._get_workflow_run_repo() + repo.delete_archive_log_by_run_id(session, run_id) + + session.commit() + + result.success = True + if not self.dry_run: + click.echo( + 
click.style( + f"Completed restore for workflow run {run_id}: restored={result.restored_counts}", + fg="green", + ) + ) + + except Exception as e: + logger.exception("Failed to restore workflow run %s", run_id) + result.error = str(e) + session.rollback() + click.echo(click.style(f"Restore failed: {e}", fg="red")) + + result.elapsed_time = time.time() - start_time + return result + + def _get_workflow_run_repo(self) -> APIWorkflowRunRepository: + if self.workflow_run_repo is not None: + return self.workflow_run_repo + + self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository( + sessionmaker(bind=db.engine, expire_on_commit=False) + ) + return self.workflow_run_repo + + @staticmethod + def _load_manifest_from_zip(archive: zipfile.ZipFile) -> dict[str, Any]: + try: + data = archive.read("manifest.json") + except KeyError as e: + raise ValueError("manifest.json missing from archive bundle") from e + return json.loads(data.decode("utf-8")) + + def _restore_table_records( + self, + session: Session, + table_name: str, + records: list[dict[str, Any]], + *, + schema_version: str, + ) -> int: + """ + Restore records to a table. + + Uses INSERT ... ON CONFLICT DO NOTHING for idempotency. + + Args: + session: Database session + table_name: Name of the table + records: List of record dictionaries + schema_version: Archived schema version from manifest + + Returns: + Number of records actually inserted + """ + if not records: + return 0 + + model = TABLE_MODELS.get(table_name) + if not model: + logger.warning("Unknown table: %s", table_name) + return 0 + + column_names, required_columns, non_nullable_with_default = self._get_model_column_info(model) + unknown_fields: set[str] = set() + + # Apply schema mapping, filter to current columns, then convert datetimes + converted_records = [] + for record in records: + mapped = self._apply_schema_mapping(table_name, schema_version, record) + unknown_fields.update(set(mapped.keys()) - column_names) + filtered = {key: value for key, value in mapped.items() if key in column_names} + for key in non_nullable_with_default: + if key in filtered and filtered[key] is None: + filtered.pop(key) + missing_required = [key for key in required_columns if key not in filtered or filtered.get(key) is None] + if missing_required: + missing_cols = ", ".join(sorted(missing_required)) + raise ValueError( + f"Missing required columns for {table_name} (schema_version={schema_version}): {missing_cols}" + ) + converted = self._convert_datetime_fields(filtered, model) + converted_records.append(converted) + if unknown_fields: + logger.warning( + "Dropped unknown columns for %s (schema_version=%s): %s", + table_name, + schema_version, + ", ".join(sorted(unknown_fields)), + ) + + # Use INSERT ... 
ON CONFLICT DO NOTHING for idempotency + stmt = pg_insert(model).values(converted_records) + stmt = stmt.on_conflict_do_nothing(index_elements=["id"]) + + result = session.execute(stmt) + return cast(CursorResult, result).rowcount or 0 + + def _convert_datetime_fields( + self, + record: dict[str, Any], + model: type[DeclarativeBase] | Any, + ) -> dict[str, Any]: + """Convert ISO datetime strings to datetime objects.""" + from sqlalchemy import DateTime + + result = dict(record) + + for column in model.__table__.columns: + if isinstance(column.type, DateTime): + value = result.get(column.key) + if isinstance(value, str): + try: + result[column.key] = datetime.fromisoformat(value) + except ValueError: + pass + + return result + + def _get_schema_version(self, manifest: dict[str, Any]) -> str: + schema_version = manifest.get("schema_version") + if not schema_version: + logger.warning("Manifest missing schema_version; defaulting to 1.0") + schema_version = "1.0" + schema_version = str(schema_version) + if schema_version not in SCHEMA_MAPPERS: + raise ValueError(f"Unsupported schema_version {schema_version}. Add a mapping before restoring.") + return schema_version + + def _apply_schema_mapping( + self, + table_name: str, + schema_version: str, + record: dict[str, Any], + ) -> dict[str, Any]: + # Keep hook for forward/backward compatibility when schema evolves. + mapper = SCHEMA_MAPPERS.get(schema_version, {}).get(table_name) + if mapper is None: + return dict(record) + return mapper(record) + + def _get_model_column_info( + self, + model: type[DeclarativeBase] | Any, + ) -> tuple[set[str], set[str], set[str]]: + columns = list(model.__table__.columns) + column_names = {column.key for column in columns} + required_columns = { + column.key + for column in columns + if not column.nullable + and column.default is None + and column.server_default is None + and not column.autoincrement + } + non_nullable_with_default = { + column.key + for column in columns + if not column.nullable + and (column.default is not None or column.server_default is not None or column.autoincrement) + } + return column_names, required_columns, non_nullable_with_default + + def restore_batch( + self, + tenant_ids: list[str] | None, + start_date: datetime, + end_date: datetime, + limit: int = 100, + ) -> list[RestoreResult]: + """ + Restore multiple workflow runs by time range. 
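+
+        A minimal usage sketch (the tenant ID, dates, and limit below are
+        placeholders for illustration, not values taken from this change):
+
+            from datetime import datetime
+
+            restore = WorkflowRunRestore(dry_run=True, workers=2)
+            results = restore.restore_batch(
+                tenant_ids=["tenant-1"],
+                start_date=datetime(2024, 1, 1),
+                end_date=datetime(2024, 2, 1),
+                limit=50,
+            )
+            failed = [r for r in results if not r.success]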
+ + Args: + tenant_ids: Optional tenant IDs + start_date: Start date filter + end_date: End date filter + limit: Maximum number of runs to restore (default: 100) + + Returns: + List of RestoreResult objects + """ + results: list[RestoreResult] = [] + if tenant_ids is not None and not tenant_ids: + return results + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + repo = self._get_workflow_run_repo() + + with session_maker() as session: + archive_logs = repo.get_archived_logs_by_time_range( + session=session, + tenant_ids=tenant_ids, + start_date=start_date, + end_date=end_date, + limit=limit, + ) + + click.echo( + click.style( + f"Found {len(archive_logs)} archived workflow runs to restore", + fg="white", + ) + ) + + def _restore_with_session(archive_log: WorkflowArchiveLog) -> RestoreResult: + return self._restore_from_run( + archive_log, + session_maker=session_maker, + ) + + with ThreadPoolExecutor(max_workers=self.workers) as executor: + results = list(executor.map(_restore_with_session, archive_logs)) + + total_counts: dict[str, int] = {} + for result in results: + for table_name, count in result.restored_counts.items(): + total_counts[table_name] = total_counts.get(table_name, 0) + count + success_count = sum(1 for result in results if result.success) + + if self.dry_run: + click.echo( + click.style( + f"[DRY RUN] Would restore {len(results)} workflow runs: totals={total_counts}", + fg="yellow", + ) + ) + else: + click.echo( + click.style( + f"Restored {success_count}/{len(results)} workflow runs: totals={total_counts}", + fg="green", + ) + ) + + return results + + def restore_by_run_id( + self, + run_id: str, + ) -> RestoreResult: + """ + Restore a single workflow run by run ID. + """ + repo = self._get_workflow_run_repo() + archive_log = repo.get_archived_log_by_run_id(run_id) + + if not archive_log: + click.echo(click.style(f"Workflow run archive {run_id} not found", fg="red")) + return RestoreResult( + run_id=run_id, + tenant_id="", + success=False, + restored_counts={}, + error=f"Workflow run archive {run_id} not found", + ) + + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + result = self._restore_from_run(archive_log, session_maker=session_maker) + if self.dry_run and result.success: + click.echo( + click.style( + f"[DRY RUN] Would restore workflow run {run_id}: totals={result.restored_counts}", + fg="yellow", + ) + ) + return result diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index 8574d30255..efc76c33bc 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -7,7 +7,7 @@ from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session from core.workflow.enums import WorkflowExecutionStatus -from models import Account, App, EndUser, WorkflowAppLog, WorkflowRun +from models import Account, App, EndUser, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole from models.trigger import WorkflowTriggerLog from services.plugin.plugin_service import PluginService @@ -173,7 +173,80 @@ class WorkflowAppService: "data": items, } - def handle_trigger_metadata(self, tenant_id: str, meta_val: str) -> dict[str, Any]: + def get_paginate_workflow_archive_logs( + self, + *, + session: Session, + app_model: App, + page: int = 1, + limit: int = 20, + ): + """ + Get paginate workflow archive logs using SQLAlchemy 2.0 style. 
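+
+        Returns a dict of the form (values shown are illustrative only):
+
+            {"page": 1, "limit": 20, "total": 57, "has_more": True, "data": [...]}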
+ """ + stmt = select(WorkflowArchiveLog).where( + WorkflowArchiveLog.tenant_id == app_model.tenant_id, + WorkflowArchiveLog.app_id == app_model.id, + WorkflowArchiveLog.log_id.isnot(None), + ) + + stmt = stmt.order_by(WorkflowArchiveLog.run_created_at.desc()) + + count_stmt = select(func.count()).select_from(stmt.subquery()) + total = session.scalar(count_stmt) or 0 + + offset_stmt = stmt.offset((page - 1) * limit).limit(limit) + + logs = list(session.scalars(offset_stmt).all()) + account_ids = {log.created_by for log in logs if log.created_by_role == CreatorUserRole.ACCOUNT} + end_user_ids = {log.created_by for log in logs if log.created_by_role == CreatorUserRole.END_USER} + + accounts_by_id = {} + if account_ids: + accounts_by_id = { + account.id: account + for account in session.scalars(select(Account).where(Account.id.in_(account_ids))).all() + } + + end_users_by_id = {} + if end_user_ids: + end_users_by_id = { + end_user.id: end_user + for end_user in session.scalars(select(EndUser).where(EndUser.id.in_(end_user_ids))).all() + } + + items = [] + for log in logs: + if log.created_by_role == CreatorUserRole.ACCOUNT: + created_by_account = accounts_by_id.get(log.created_by) + created_by_end_user = None + elif log.created_by_role == CreatorUserRole.END_USER: + created_by_account = None + created_by_end_user = end_users_by_id.get(log.created_by) + else: + created_by_account = None + created_by_end_user = None + + items.append( + { + "id": log.id, + "workflow_run": log.workflow_run_summary, + "trigger_metadata": self.handle_trigger_metadata(app_model.tenant_id, log.trigger_metadata), + "created_by_account": created_by_account, + "created_by_end_user": created_by_end_user, + "created_at": log.log_created_at, + } + ) + + return { + "page": page, + "limit": limit, + "total": total, + "has_more": total > page * limit, + "data": items, + } + + def handle_trigger_metadata(self, tenant_id: str, meta_val: str | None) -> dict[str, Any]: metadata: dict[str, Any] | None = self._safe_json_loads(meta_val) if not metadata: return {} diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 4e5fb08870..817249845a 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -11,8 +11,10 @@ from sqlalchemy.engine import CursorResult from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import sessionmaker +from configs import dify_config from core.db.session_factory import session_factory from extensions.ext_database import db +from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage from models import ( ApiToken, AppAnnotationHitHistory, @@ -43,6 +45,7 @@ from models.workflow import ( ConversationVariable, Workflow, WorkflowAppLog, + WorkflowArchiveLog, ) from repositories.factory import DifyAPIRepositoryFactory @@ -67,6 +70,9 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): _delete_app_workflow_runs(tenant_id, app_id) _delete_app_workflow_node_executions(tenant_id, app_id) _delete_app_workflow_app_logs(tenant_id, app_id) + if dify_config.BILLING_ENABLED and dify_config.ARCHIVE_STORAGE_ENABLED: + _delete_app_workflow_archive_logs(tenant_id, app_id) + _delete_archived_workflow_run_files(tenant_id, app_id) _delete_app_conversations(tenant_id, app_id) _delete_app_messages(tenant_id, app_id) _delete_workflow_tool_providers(tenant_id, app_id) @@ -252,6 +258,45 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): ) +def 
_delete_app_workflow_archive_logs(tenant_id: str, app_id: str): + def del_workflow_archive_log(workflow_archive_log_id: str): + db.session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete( + synchronize_session=False + ) + + _delete_records( + """select id from workflow_archive_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", + {"tenant_id": tenant_id, "app_id": app_id}, + del_workflow_archive_log, + "workflow archive log", + ) + + +def _delete_archived_workflow_run_files(tenant_id: str, app_id: str): + prefix = f"{tenant_id}/app_id={app_id}/" + try: + archive_storage = get_archive_storage() + except ArchiveStorageNotConfiguredError as e: + logger.info("Archive storage not configured, skipping archive file cleanup: %s", e) + return + + try: + keys = archive_storage.list_objects(prefix) + except Exception: + logger.exception("Failed to list archive files for app %s", app_id) + return + + deleted = 0 + for key in keys: + try: + archive_storage.delete_object(key) + deleted += 1 + except Exception: + logger.exception("Failed to delete archive object %s", key) + + logger.info("Deleted %s archive objects for app %s", deleted, app_id) + + def _delete_app_conversations(tenant_id: str, app_id: str): def del_conversation(session, conversation_id: str): session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete( diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 9b0bd6275b..1a9d69b2d2 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -5,13 +5,13 @@ import pytest from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.node_events import NodeRunResult from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.limits import CodeNodeLimits -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from models.enums import UserFrom diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index d814da8ec7..1bcac3b5fe 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -5,11 +5,11 @@ from urllib.parse import urlencode import pytest from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.nodes.http_request.node import HttpRequestNode -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from models.enums import UserFrom diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index d268c5da22..c361bfcc6f 100644 --- 
a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -5,13 +5,13 @@ from collections.abc import Generator from unittest.mock import MagicMock, patch from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.llm_generator.output_parser.structured_output import _parse_structured_output from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.node_events import StreamCompletedEvent from core.workflow.nodes.llm.node import LLMNode -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from extensions.ext_database import db diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 654db59bec..7445699a86 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -4,11 +4,11 @@ import uuid from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.model_runtime.entities import AssistantPromptMessage from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 3bcb9a3a34..bc03ce1b96 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -4,10 +4,10 @@ import uuid import pytest from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index d666f0ebe2..cfbef52c93 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -3,12 +3,12 @@ import uuid from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.tools.utils.configuration import ToolParameterConfigurationManager from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from 
core.workflow.graph import Graph from core.workflow.node_events import StreamCompletedEvent -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.nodes.tool.tool_node import ToolNode from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable diff --git a/api/tests/test_containers_integration_tests/services/test_feature_service.py b/api/tests/test_containers_integration_tests/services/test_feature_service.py index 40380b09d2..bd2fd14ffa 100644 --- a/api/tests/test_containers_integration_tests/services/test_feature_service.py +++ b/api/tests/test_containers_integration_tests/services/test_feature_service.py @@ -4,7 +4,13 @@ import pytest from faker import Faker from enums.cloud_plan import CloudPlan -from services.feature_service import FeatureModel, FeatureService, KnowledgeRateLimitModel, SystemFeatureModel +from services.feature_service import ( + FeatureModel, + FeatureService, + KnowledgeRateLimitModel, + LicenseStatus, + SystemFeatureModel, +) class TestFeatureService: @@ -274,7 +280,7 @@ class TestFeatureService: mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 # Act: Execute the method under test - result = FeatureService.get_system_features() + result = FeatureService.get_system_features(is_authenticated=True) # Assert: Verify the expected outcomes assert result is not None @@ -324,6 +330,61 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + def test_get_system_features_unauthenticated(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test system features retrieval for an unauthenticated user. + + This test verifies that: + - The response payload is minimized (e.g., verbose license details are excluded). + - Essential UI configuration (Branding, SSO, Marketplace) remains available. + - The response structure adheres to the public schema for unauthenticated clients. + """ + # Arrange: Setup test data with exact same config as success test + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = True + mock_config.ENABLE_EMAIL_CODE_LOGIN = True + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + # Act: Execute with is_authenticated=False + result = FeatureService.get_system_features(is_authenticated=False) + + # Assert: Basic structure + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # --- 1. Verify Response Payload Optimization (Data Minimization) --- + # Ensure only essential UI flags are returned to unauthenticated clients + # to keep the payload lightweight and adhere to architectural boundaries. + assert result.license.status == LicenseStatus.NONE + assert result.license.expired_at == "" + assert result.license.workspaces.enabled is False + assert result.license.workspaces.limit == 0 + assert result.license.workspaces.size == 0 + + # --- 2. Verify Public UI Configuration Availability --- + # Ensure that data required for frontend rendering remains accessible. 
+ + # Branding should match the mock data + assert result.branding.enabled is True + assert result.branding.application_title == "Test Enterprise" + assert result.branding.login_page_logo == "https://example.com/logo.png" + + # SSO settings should be visible for login page rendering + assert result.sso_enforced_for_signin is True + assert result.sso_enforced_for_signin_protocol == "saml" + + # General auth settings should be visible + assert result.enable_email_code_login is True + + # Marketplace should be visible + assert result.enable_marketplace is True + def test_get_system_features_basic_config(self, db_session_with_containers, mock_external_service_dependencies): """ Test system features retrieval with basic configuration (no enterprise). @@ -1031,7 +1092,7 @@ class TestFeatureService: } # Act: Execute the method under test - result = FeatureService.get_system_features() + result = FeatureService.get_system_features(is_authenticated=True) # Assert: Verify the expected outcomes assert result is not None @@ -1400,7 +1461,7 @@ class TestFeatureService: } # Act: Execute the method under test - result = FeatureService.get_system_features() + result = FeatureService.get_system_features(is_authenticated=True) # Assert: Verify the expected outcomes assert result is not None diff --git a/api/tests/unit_tests/core/workflow/context/test_execution_context.py b/api/tests/unit_tests/core/workflow/context/test_execution_context.py index 63466cfb5e..8dd669e17f 100644 --- a/api/tests/unit_tests/core/workflow/context/test_execution_context.py +++ b/api/tests/unit_tests/core/workflow/context/test_execution_context.py @@ -1,6 +1,8 @@ """Tests for execution context module.""" import contextvars +import threading +from contextlib import contextmanager from typing import Any from unittest.mock import MagicMock @@ -149,6 +151,54 @@ class TestExecutionContext: assert ctx.user == user + def test_thread_safe_context_manager(self): + """Test shared ExecutionContext works across threads without token mismatch.""" + test_var = contextvars.ContextVar("thread_safe_test_var") + + class TrackingAppContext(AppContext): + def get_config(self, key: str, default: Any = None) -> Any: + return default + + def get_extension(self, name: str) -> Any: + return None + + @contextmanager + def enter(self): + token = test_var.set(threading.get_ident()) + try: + yield + finally: + test_var.reset(token) + + ctx = ExecutionContext(app_context=TrackingAppContext()) + errors: list[Exception] = [] + barrier = threading.Barrier(2) + + def worker(): + try: + for _ in range(20): + with ctx: + try: + barrier.wait() + barrier.wait() + except threading.BrokenBarrierError: + return + except Exception as exc: + errors.append(exc) + try: + barrier.abort() + except Exception: + pass + + t1 = threading.Thread(target=worker) + t2 = threading.Thread(target=worker) + t1.start() + t2.start() + t1.join(timeout=5) + t2.join(timeout=5) + + assert not errors + class TestIExecutionContextProtocol: """Test IExecutionContext protocol.""" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 6e9a432745..170445225b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -7,9 +7,9 @@ requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request from typing import TYPE_CHECKING, Any +from core.app.workflow.node_factory import 
DifyNodeFactory from core.workflow.enums import NodeType from core.workflow.nodes.base.node import Node -from core.workflow.nodes.node_factory import DifyNodeFactory from .test_mock_nodes import ( MockAgentNode, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py index b76fe42fce..e8cd665107 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py @@ -13,6 +13,7 @@ from unittest.mock import patch from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.graph import Graph @@ -26,7 +27,6 @@ from core.workflow.graph_events import ( ) from core.workflow.node_events import NodeRunResult, StreamCompletedEvent from core.workflow.nodes.llm.node import LLMNode -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from models.enums import UserFrom diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index 08f7b00a33..10ac1206fb 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -19,6 +19,7 @@ from functools import lru_cache from pathlib import Path from typing import Any +from core.app.workflow.node_factory import DifyNodeFactory from core.tools.utils.yaml_utils import _load_yaml_file from core.variables import ( ArrayNumberVariable, @@ -38,7 +39,6 @@ from core.workflow.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, ) -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index 98d9560e64..1e95ec1970 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -3,11 +3,11 @@ import uuid from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from extensions.ext_database import db diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index dc7175f964..d700888c2f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock, Mock import 
pytest from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.file import File, FileTransferMethod, FileType from core.variables import ArrayFileSegment from core.workflow.entities import GraphInitParams @@ -12,7 +13,6 @@ from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.nodes.if_else.entities import IfElseNodeData from core.workflow.nodes.if_else.if_else_node import IfElseNode -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index 1df75380af..d4b7a017f9 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -3,11 +3,11 @@ import uuid from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.variables import ArrayStringVariable, StringVariable from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph from core.workflow.graph_events.node import NodeRunSucceededEvent -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index 353d56fe25..b08f9c37b4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -3,10 +3,10 @@ import uuid from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.variables import ArrayStringVariable from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation from core.workflow.runtime import GraphRuntimeState, VariablePool diff --git a/api/tests/unit_tests/core/workflow/test_enums.py b/api/tests/unit_tests/core/workflow/test_enums.py index 7cdb2328f2..078ec5f6ab 100644 --- a/api/tests/unit_tests/core/workflow/test_enums.py +++ b/api/tests/unit_tests/core/workflow/test_enums.py @@ -30,3 +30,12 @@ class TestWorkflowExecutionStatus: for status in non_ended_statuses: assert not status.is_ended(), f"{status} should not be considered ended" + + def test_ended_values(self): + """Test ended_values returns the expected status values.""" + assert set(WorkflowExecutionStatus.ended_values()) == { + WorkflowExecutionStatus.SUCCEEDED.value, + 
WorkflowExecutionStatus.FAILED.value, + WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value, + WorkflowExecutionStatus.STOPPED.value, + } diff --git a/api/tests/unit_tests/libs/test_archive_storage.py b/api/tests/unit_tests/libs/test_archive_storage.py index 697760e33a..de3c9c4737 100644 --- a/api/tests/unit_tests/libs/test_archive_storage.py +++ b/api/tests/unit_tests/libs/test_archive_storage.py @@ -37,6 +37,20 @@ def _client_error(code: str) -> ClientError: def _mock_client(monkeypatch): client = MagicMock() client.head_bucket.return_value = None + # Configure put_object to return a proper ETag that matches the MD5 hash + # The ETag format is typically the MD5 hash wrapped in quotes + + def mock_put_object(**kwargs): + md5_hash = kwargs.get("Body", b"") + if isinstance(md5_hash, bytes): + md5_hash = hashlib.md5(md5_hash).hexdigest() + else: + md5_hash = hashlib.md5(md5_hash.encode()).hexdigest() + response = MagicMock() + response.get.return_value = f'"{md5_hash}"' + return response + + client.put_object.side_effect = mock_put_object boto_client = MagicMock(return_value=client) monkeypatch.setattr(storage_module.boto3, "client", boto_client) return client, boto_client @@ -254,8 +268,8 @@ def test_serialization_roundtrip(): {"id": "2", "value": 123}, ] - data = ArchiveStorage.serialize_to_jsonl_gz(records) - decoded = ArchiveStorage.deserialize_from_jsonl_gz(data) + data = ArchiveStorage.serialize_to_jsonl(records) + decoded = ArchiveStorage.deserialize_from_jsonl(data) assert decoded[0]["id"] == "1" assert decoded[0]["payload"]["nested"] == "value" diff --git a/api/tests/unit_tests/services/test_archive_workflow_run_logs.py b/api/tests/unit_tests/services/test_archive_workflow_run_logs.py new file mode 100644 index 0000000000..ef62dacd6b --- /dev/null +++ b/api/tests/unit_tests/services/test_archive_workflow_run_logs.py @@ -0,0 +1,54 @@ +""" +Unit tests for workflow run archiving functionality. 
+ +This module contains tests for: +- Archive service +- Rollback service +""" + +from datetime import datetime +from unittest.mock import MagicMock, patch + +from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME + + +class TestWorkflowRunArchiver: + """Tests for the WorkflowRunArchiver class.""" + + @patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") + @patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage") + def test_archiver_initialization(self, mock_get_storage, mock_config): + """Test archiver can be initialized with various options.""" + from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver + + mock_config.BILLING_ENABLED = False + + archiver = WorkflowRunArchiver( + days=90, + batch_size=100, + tenant_ids=["test-tenant"], + limit=50, + dry_run=True, + ) + + assert archiver.days == 90 + assert archiver.batch_size == 100 + assert archiver.tenant_ids == ["test-tenant"] + assert archiver.limit == 50 + assert archiver.dry_run is True + + def test_get_archive_key(self): + """Test archive key generation.""" + from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver + + archiver = WorkflowRunArchiver.__new__(WorkflowRunArchiver) + + mock_run = MagicMock() + mock_run.tenant_id = "tenant-123" + mock_run.app_id = "app-999" + mock_run.id = "run-456" + mock_run.created_at = datetime(2024, 1, 15, 12, 0, 0) + + key = archiver._get_archive_key(mock_run) + + assert key == f"tenant-123/app_id=app-999/year=2024/month=01/workflow_run_id=run-456/{ARCHIVE_BUNDLE_NAME}" diff --git a/api/tests/unit_tests/services/test_delete_archived_workflow_run.py b/api/tests/unit_tests/services/test_delete_archived_workflow_run.py new file mode 100644 index 0000000000..2c9d946ea6 --- /dev/null +++ b/api/tests/unit_tests/services/test_delete_archived_workflow_run.py @@ -0,0 +1,180 @@ +""" +Unit tests for archived workflow run deletion service. 
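+
+These tests replace the database session, sessionmaker, and repository with
+MagicMock objects, so no database or archive storage is needed. They can be
+run on their own, for example:
+
+    pytest api/tests/unit_tests/services/test_delete_archived_workflow_run.py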
+""" + +from unittest.mock import MagicMock, patch + + +class TestArchivedWorkflowRunDeletion: + def test_delete_by_run_id_returns_error_when_run_missing(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion() + repo = MagicMock() + session = MagicMock() + session.get.return_value = None + + session_maker = MagicMock() + session_maker.return_value.__enter__.return_value = session + session_maker.return_value.__exit__.return_value = None + mock_db = MagicMock() + mock_db.engine = MagicMock() + + with ( + patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db), + patch( + "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker + ), + patch.object(deleter, "_get_workflow_run_repo", return_value=repo), + ): + result = deleter.delete_by_run_id("run-1") + + assert result.success is False + assert result.error == "Workflow run run-1 not found" + repo.get_archived_run_ids.assert_not_called() + + def test_delete_by_run_id_returns_error_when_not_archived(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion() + repo = MagicMock() + repo.get_archived_run_ids.return_value = set() + run = MagicMock() + run.id = "run-1" + run.tenant_id = "tenant-1" + + session = MagicMock() + session.get.return_value = run + + session_maker = MagicMock() + session_maker.return_value.__enter__.return_value = session + session_maker.return_value.__exit__.return_value = None + mock_db = MagicMock() + mock_db.engine = MagicMock() + + with ( + patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db), + patch( + "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker + ), + patch.object(deleter, "_get_workflow_run_repo", return_value=repo), + patch.object(deleter, "_delete_run") as mock_delete_run, + ): + result = deleter.delete_by_run_id("run-1") + + assert result.success is False + assert result.error == "Workflow run run-1 is not archived" + mock_delete_run.assert_not_called() + + def test_delete_by_run_id_calls_delete_run(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion() + repo = MagicMock() + repo.get_archived_run_ids.return_value = {"run-1"} + run = MagicMock() + run.id = "run-1" + run.tenant_id = "tenant-1" + + session = MagicMock() + session.get.return_value = run + + session_maker = MagicMock() + session_maker.return_value.__enter__.return_value = session + session_maker.return_value.__exit__.return_value = None + mock_db = MagicMock() + mock_db.engine = MagicMock() + + with ( + patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db), + patch( + "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker + ), + patch.object(deleter, "_get_workflow_run_repo", return_value=repo), + patch.object(deleter, "_delete_run", return_value=MagicMock(success=True)) as mock_delete_run, + ): + result = deleter.delete_by_run_id("run-1") + + assert result.success is True + mock_delete_run.assert_called_once_with(run) + + def test_delete_batch_uses_repo(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion() + repo = MagicMock() + 
run1 = MagicMock() + run1.id = "run-1" + run1.tenant_id = "tenant-1" + run2 = MagicMock() + run2.id = "run-2" + run2.tenant_id = "tenant-1" + repo.get_archived_runs_by_time_range.return_value = [run1, run2] + + session = MagicMock() + session_maker = MagicMock() + session_maker.return_value.__enter__.return_value = session + session_maker.return_value.__exit__.return_value = None + start_date = MagicMock() + end_date = MagicMock() + mock_db = MagicMock() + mock_db.engine = MagicMock() + + with ( + patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db), + patch( + "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker + ), + patch.object(deleter, "_get_workflow_run_repo", return_value=repo), + patch.object( + deleter, "_delete_run", side_effect=[MagicMock(success=True), MagicMock(success=True)] + ) as mock_delete_run, + ): + results = deleter.delete_batch( + tenant_ids=["tenant-1"], + start_date=start_date, + end_date=end_date, + limit=2, + ) + + assert len(results) == 2 + repo.get_archived_runs_by_time_range.assert_called_once_with( + session=session, + tenant_ids=["tenant-1"], + start_date=start_date, + end_date=end_date, + limit=2, + ) + assert mock_delete_run.call_count == 2 + + def test_delete_run_dry_run(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion(dry_run=True) + run = MagicMock() + run.id = "run-1" + run.tenant_id = "tenant-1" + + with patch.object(deleter, "_get_workflow_run_repo") as mock_get_repo: + result = deleter._delete_run(run) + + assert result.success is True + mock_get_repo.assert_not_called() + + def test_delete_run_calls_repo(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion() + run = MagicMock() + run.id = "run-1" + run.tenant_id = "tenant-1" + + repo = MagicMock() + repo.delete_runs_with_related.return_value = {"runs": 1} + + with patch.object(deleter, "_get_workflow_run_repo", return_value=repo): + result = deleter._delete_run(run) + + assert result.success is True + assert result.deleted_counts == {"runs": 1} + repo.delete_runs_with_related.assert_called_once() diff --git a/api/tests/unit_tests/services/test_restore_archived_workflow_run.py b/api/tests/unit_tests/services/test_restore_archived_workflow_run.py new file mode 100644 index 0000000000..68aa8c0fe1 --- /dev/null +++ b/api/tests/unit_tests/services/test_restore_archived_workflow_run.py @@ -0,0 +1,65 @@ +""" +Unit tests for workflow run restore functionality. 
+""" + +from datetime import datetime +from unittest.mock import MagicMock + + +class TestWorkflowRunRestore: + """Tests for the WorkflowRunRestore class.""" + + def test_restore_initialization(self): + """Restore service should respect dry_run flag.""" + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + restore = WorkflowRunRestore(dry_run=True) + + assert restore.dry_run is True + + def test_convert_datetime_fields(self): + """ISO datetime strings should be converted to datetime objects.""" + from models.workflow import WorkflowRun + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + record = { + "id": "test-id", + "created_at": "2024-01-01T12:00:00", + "finished_at": "2024-01-01T12:05:00", + "name": "test", + } + + restore = WorkflowRunRestore() + result = restore._convert_datetime_fields(record, WorkflowRun) + + assert isinstance(result["created_at"], datetime) + assert result["created_at"].year == 2024 + assert result["created_at"].month == 1 + assert result["name"] == "test" + + def test_restore_table_records_returns_rowcount(self): + """Restore should return inserted rowcount.""" + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + session = MagicMock() + session.execute.return_value = MagicMock(rowcount=2) + + restore = WorkflowRunRestore() + records = [{"id": "p1", "workflow_run_id": "r1", "created_at": "2024-01-01T00:00:00"}] + + restored = restore._restore_table_records(session, "workflow_pauses", records, schema_version="1.0") + + assert restored == 2 + session.execute.assert_called_once() + + def test_restore_table_records_unknown_table(self): + """Unknown table names should be ignored gracefully.""" + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + session = MagicMock() + + restore = WorkflowRunRestore() + restored = restore._restore_table_records(session, "unknown_table", [{"id": "x1"}], schema_version="1.0") + + assert restored == 0 + session.execute.assert_not_called() diff --git a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py index ccf43591f0..a14bbb01d0 100644 --- a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,7 +2,11 @@ from unittest.mock import ANY, MagicMock, call, patch import pytest +from libs.archive_storage import ArchiveStorageNotConfiguredError +from models.workflow import WorkflowArchiveLog from tasks.remove_app_and_related_data_task import ( + _delete_app_workflow_archive_logs, + _delete_archived_workflow_run_files, _delete_draft_variable_offload_data, _delete_draft_variables, delete_draft_variables_batch, @@ -324,3 +328,68 @@ class TestDeleteDraftVariableOffloadData: # Verify error was logged mock_logging.exception.assert_called_once_with("Error deleting draft variable offload data:") + + +class TestDeleteWorkflowArchiveLogs: + @patch("tasks.remove_app_and_related_data_task._delete_records") + @patch("tasks.remove_app_and_related_data_task.db") + def test_delete_app_workflow_archive_logs_calls_delete_records(self, mock_db, mock_delete_records): + tenant_id = "tenant-1" + app_id = "app-1" + + _delete_app_workflow_archive_logs(tenant_id, app_id) + + mock_delete_records.assert_called_once() + query_sql, params, delete_func, name = mock_delete_records.call_args[0] + assert 
"workflow_archive_logs" in query_sql + assert params == {"tenant_id": tenant_id, "app_id": app_id} + assert name == "workflow archive log" + + mock_query = MagicMock() + mock_delete_query = MagicMock() + mock_query.where.return_value = mock_delete_query + mock_db.session.query.return_value = mock_query + + delete_func("log-1") + + mock_db.session.query.assert_called_once_with(WorkflowArchiveLog) + mock_query.where.assert_called_once() + mock_delete_query.delete.assert_called_once_with(synchronize_session=False) + + +class TestDeleteArchivedWorkflowRunFiles: + @patch("tasks.remove_app_and_related_data_task.get_archive_storage") + @patch("tasks.remove_app_and_related_data_task.logger") + def test_delete_archived_workflow_run_files_not_configured(self, mock_logger, mock_get_storage): + mock_get_storage.side_effect = ArchiveStorageNotConfiguredError("missing config") + + _delete_archived_workflow_run_files("tenant-1", "app-1") + + assert mock_logger.info.call_count == 1 + assert "Archive storage not configured" in mock_logger.info.call_args[0][0] + + @patch("tasks.remove_app_and_related_data_task.get_archive_storage") + @patch("tasks.remove_app_and_related_data_task.logger") + def test_delete_archived_workflow_run_files_list_failure(self, mock_logger, mock_get_storage): + storage = MagicMock() + storage.list_objects.side_effect = Exception("list failed") + mock_get_storage.return_value = storage + + _delete_archived_workflow_run_files("tenant-1", "app-1") + + storage.list_objects.assert_called_once_with("tenant-1/app_id=app-1/") + storage.delete_object.assert_not_called() + mock_logger.exception.assert_called_once_with("Failed to list archive files for app %s", "app-1") + + @patch("tasks.remove_app_and_related_data_task.get_archive_storage") + @patch("tasks.remove_app_and_related_data_task.logger") + def test_delete_archived_workflow_run_files_success(self, mock_logger, mock_get_storage): + storage = MagicMock() + storage.list_objects.return_value = ["key-1", "key-2"] + mock_get_storage.return_value = storage + + _delete_archived_workflow_run_files("tenant-1", "app-1") + + storage.list_objects.assert_called_once_with("tenant-1/app_id=app-1/") + storage.delete_object.assert_has_calls([call("key-1"), call("key-2")], any_order=False) + mock_logger.info.assert_called_with("Deleted %s archive objects for app %s", 2, "app-1") diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx index 28489a6714..6d5eb1ef95 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx @@ -48,7 +48,7 @@ const CSVUploader: FC = ({ setDragging(false) if (!e.dataTransfer) return - const files = [...e.dataTransfer.files] + const files = Array.from(e.dataTransfer.files) if (files.length > 1) { notify({ type: 'error', message: t('stepOne.uploader.validation.count', { ns: 'datasetCreation' }) }) return diff --git a/web/app/components/app/configuration/config-var/index.tsx b/web/app/components/app/configuration/config-var/index.tsx index 1a8810f7cd..4d9a4e480f 100644 --- a/web/app/components/app/configuration/config-var/index.tsx +++ b/web/app/components/app/configuration/config-var/index.tsx @@ -271,9 +271,9 @@ const ConfigVar: FC = ({ promptVariables, readonly, onPromptVar )} {hasVar && ( -
+
{ onPromptVariablesChange?.(list.map(item => item.variable)) }} handle=".handle" diff --git a/web/app/components/app/configuration/config-var/var-item.tsx b/web/app/components/app/configuration/config-var/var-item.tsx index 1fc21e3d33..b26249dac8 100644 --- a/web/app/components/app/configuration/config-var/var-item.tsx +++ b/web/app/components/app/configuration/config-var/var-item.tsx @@ -39,7 +39,7 @@ const VarItem: FC = ({ const [isDeleting, setIsDeleting] = useState(false) return ( -
+
{canDrag && ( diff --git a/web/app/components/app/configuration/config-vision/index.tsx b/web/app/components/app/configuration/config-vision/index.tsx index bc313b9ac1..481e6b5ab6 100644 --- a/web/app/components/app/configuration/config-vision/index.tsx +++ b/web/app/components/app/configuration/config-vision/index.tsx @@ -1,5 +1,6 @@ 'use client' import type { FC } from 'react' +import { noop } from 'es-toolkit/function' import { produce } from 'immer' import * as React from 'react' import { useCallback } from 'react' @@ -10,14 +11,17 @@ import { useFeatures, useFeaturesStore } from '@/app/components/base/features/ho import { Vision } from '@/app/components/base/icons/src/vender/features' import Switch from '@/app/components/base/switch' import Tooltip from '@/app/components/base/tooltip' +import OptionCard from '@/app/components/workflow/nodes/_base/components/option-card' import { SupportUploadFileTypes } from '@/app/components/workflow/types' // import OptionCard from '@/app/components/workflow/nodes/_base/components/option-card' import ConfigContext from '@/context/debug-configuration' +import { Resolution } from '@/types/app' +import { cn } from '@/utils/classnames' import ParamConfig from './param-config' const ConfigVision: FC = () => { const { t } = useTranslation() - const { isShowVisionConfig, isAllowVideoUpload } = useContext(ConfigContext) + const { isShowVisionConfig, isAllowVideoUpload, readonly } = useContext(ConfigContext) const file = useFeatures(s => s.features.file) const featuresStore = useFeaturesStore() @@ -54,7 +58,7 @@ const ConfigVision: FC = () => { setFeatures(newFeatures) }, [featuresStore, isAllowVideoUpload]) - if (!isShowVisionConfig) + if (!isShowVisionConfig || (readonly && !isImageEnabled)) return null return ( @@ -75,37 +79,55 @@ const ConfigVision: FC = () => { />
- {/*
-
{t('appDebug.vision.visionSettings.resolution')}
- - {t('appDebug.vision.visionSettings.resolutionTooltip').split('\n').map(item => ( -
{item}
- ))} -
- } - /> -
*/} - {/*
- handleChange(Resolution.high)} - /> - handleChange(Resolution.low)} - /> -
*/} - -
- + {readonly + ? ( + <> +
+
{t('vision.visionSettings.resolution', { ns: 'appDebug' })}
+ + {t('vision.visionSettings.resolutionTooltip', { ns: 'appDebug' }).split('\n').map(item => ( +
{item}
+ ))} +
+ )} + /> +
+
+ + +
+ + ) + : ( + <> + +
+ + + )} +
) diff --git a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx index 7139ba66e0..486c0a8ac9 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx @@ -40,7 +40,7 @@ type AgentToolWithMoreInfo = AgentTool & { icon: any, collection?: Collection } const AgentTools: FC = () => { const { t } = useTranslation() const [isShowChooseTool, setIsShowChooseTool] = useState(false) - const { modelConfig, setModelConfig } = useContext(ConfigContext) + const { readonly, modelConfig, setModelConfig } = useContext(ConfigContext) const { data: buildInTools } = useAllBuiltInTools() const { data: customTools } = useAllCustomTools() const { data: workflowTools } = useAllWorkflowTools() @@ -168,10 +168,10 @@ const AgentTools: FC = () => { {tools.filter(item => !!item.enabled).length} / {tools.length} -  +   {t('agent.tools.enabled', { ns: 'appDebug' })} - {tools.length < MAX_TOOLS_NUM && ( + {tools.length < MAX_TOOLS_NUM && !readonly && ( <>
{ )} > -
+
{tools.map((item: AgentTool & { icon: any, collection?: Collection }, index) => (
{ > {getProviderShowName(item)} {item.tool_label} - {!item.isDeleted && ( + {!item.isDeleted && !readonly && ( @@ -259,7 +259,7 @@ const AgentTools: FC = () => {
)} - {!item.isDeleted && ( + {!item.isDeleted && !readonly && (
{!item.notAuthor && ( { {!item.notAuthor && ( { const newModelConfig = produce(modelConfig, (draft) => { @@ -312,6 +312,7 @@ const AgentTools: FC = () => { {item.notAuthor && (
-
-
- -
+ {!readonly && ( +
+
+ +
+ )}
) } diff --git a/web/app/components/app/configuration/config/config-document.tsx b/web/app/components/app/configuration/config/config-document.tsx index 3f192fd401..7d48c1582a 100644 --- a/web/app/components/app/configuration/config/config-document.tsx +++ b/web/app/components/app/configuration/config/config-document.tsx @@ -17,7 +17,7 @@ const ConfigDocument: FC = () => { const { t } = useTranslation() const file = useFeatures(s => s.features.file) const featuresStore = useFeaturesStore() - const { isShowDocumentConfig } = useContext(ConfigContext) + const { isShowDocumentConfig, readonly } = useContext(ConfigContext) const isDocumentEnabled = file?.allowed_file_types?.includes(SupportUploadFileTypes.document) ?? false @@ -45,7 +45,7 @@ const ConfigDocument: FC = () => { setFeatures(newFeatures) }, [featuresStore]) - if (!isShowDocumentConfig) + if (!isShowDocumentConfig || (readonly && !isDocumentEnabled)) return null return ( @@ -65,14 +65,16 @@ const ConfigDocument: FC = () => { )} /> -
-
- -
+ {!readonly && ( +
+
+ +
+ )} ) } diff --git a/web/app/components/app/configuration/config/index.tsx b/web/app/components/app/configuration/config/index.tsx index f208b99e59..3e2b201172 100644 --- a/web/app/components/app/configuration/config/index.tsx +++ b/web/app/components/app/configuration/config/index.tsx @@ -18,6 +18,7 @@ import ConfigDocument from './config-document' const Config: FC = () => { const { + readonly, mode, isAdvancedMode, modelModeType, @@ -27,6 +28,7 @@ const Config: FC = () => { modelConfig, setModelConfig, setPrevPromptConfig, + dataSets, } = useContext(ConfigContext) const isChatApp = [AppModeEnum.ADVANCED_CHAT, AppModeEnum.AGENT_CHAT, AppModeEnum.CHAT].includes(mode) const formattingChangedDispatcher = useFormattingChangedDispatcher() @@ -65,19 +67,27 @@ const Config: FC = () => { promptTemplate={promptTemplate} promptVariables={promptVariables} onChange={handlePromptChange} + readonly={readonly} /> {/* Variables */} - + {!(readonly && promptVariables.length === 0) && ( + + )} {/* Dataset */} - - + {!(readonly && dataSets.length === 0) && ( + + )} {/* Tools */} - {isAgent && ( + {isAgent && !(readonly && modelConfig.agentConfig.tools.length === 0) && ( )} @@ -88,7 +98,7 @@ const Config: FC = () => { {/* Chat History */} - {isAdvancedMode && isChatApp && modelModeType === ModelModeType.completion && ( + {!readonly && isAdvancedMode && isChatApp && modelModeType === ModelModeType.completion && ( { expect(onSave).toHaveBeenCalledWith(expect.objectContaining({ name: 'Updated dataset' })) }) await waitFor(() => { - expect(screen.getByText('Mock settings modal')).not.toBeVisible() + expect(screen.queryByText('Mock settings modal')).not.toBeInTheDocument() }) }) diff --git a/web/app/components/app/configuration/dataset-config/card-item/index.tsx b/web/app/components/app/configuration/dataset-config/card-item/index.tsx index 00d3f6d6ad..a5ad3312ec 100644 --- a/web/app/components/app/configuration/dataset-config/card-item/index.tsx +++ b/web/app/components/app/configuration/dataset-config/card-item/index.tsx @@ -30,6 +30,7 @@ const Item: FC = ({ config, onSave, onRemove, + readonly = false, editable = true, }) => { const media = useBreakpoints() @@ -56,6 +57,7 @@ const Item: FC = ({
@@ -70,7 +72,7 @@ const Item: FC = ({
{ - editable && ( + editable && !readonly && ( { e.stopPropagation() @@ -81,14 +83,18 @@ const Item: FC = ({ ) } - onRemove(config.id)} - state={isDeleting ? ActionButtonState.Destructive : ActionButtonState.Default} - onMouseEnter={() => setIsDeleting(true)} - onMouseLeave={() => setIsDeleting(false)} - > - - + { + !readonly && ( + onRemove(config.id)} + state={isDeleting ? ActionButtonState.Destructive : ActionButtonState.Default} + onMouseEnter={() => setIsDeleting(true)} + onMouseLeave={() => setIsDeleting(false)} + > + + + ) + }
{ !!config.indexing_technique && ( @@ -107,11 +113,13 @@ const Item: FC = ({ ) } setShowSettingsModal(false)} footer={null} mask={isMobile} panelClassName="mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[640px] rounded-xl"> - setShowSettingsModal(false)} - onSave={handleSave} - /> + {showSettingsModal && ( + setShowSettingsModal(false)} + onSave={handleSave} + /> + )}
) diff --git a/web/app/components/app/configuration/dataset-config/index.tsx b/web/app/components/app/configuration/dataset-config/index.tsx index 309c6e7ddb..6de77cad9e 100644 --- a/web/app/components/app/configuration/dataset-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/index.tsx @@ -30,6 +30,7 @@ import { import { useSelector as useAppContextSelector } from '@/context/app-context' import ConfigContext from '@/context/debug-configuration' import { AppModeEnum } from '@/types/app' +import { cn } from '@/utils/classnames' import { hasEditPermissionForDataset } from '@/utils/permission' import FeaturePanel from '../base/feature-panel' import OperationBtn from '../base/operation-btn' @@ -38,7 +39,11 @@ import CardItem from './card-item' import ContextVar from './context-var' import ParamsConfig from './params-config' -const DatasetConfig: FC = () => { +type Props = { + readonly?: boolean + hideMetadataFilter?: boolean +} +const DatasetConfig: FC = ({ readonly, hideMetadataFilter }) => { const { t } = useTranslation() const userProfile = useAppContextSelector(s => s.userProfile) const { @@ -259,17 +264,19 @@ const DatasetConfig: FC = () => { className="mt-2" title={t('feature.dataSet.title', { ns: 'appDebug' })} headerRight={( -
- {!isAgent && } - -
+ !readonly && ( +
+ {!isAgent && } + +
+ ) )} hasHeaderBottomBorder={!hasData} noBodySpacing > {hasData ? ( -
+
{formattedDataset.map(item => ( { onRemove={onRemove} onSave={handleSave} editable={item.editable} + readonly={readonly} /> ))}
@@ -287,27 +295,29 @@ const DatasetConfig: FC = () => {
)} -
- item.type === MetadataFilteringVariableType.string || item.type === MetadataFilteringVariableType.select)} - availableCommonNumberVars={promptVariablesToSelect.filter(item => item.type === MetadataFilteringVariableType.number)} - /> -
+ {!hideMetadataFilter && ( +
+ item.type === MetadataFilteringVariableType.string || item.type === MetadataFilteringVariableType.select)} + availableCommonNumberVars={promptVariablesToSelect.filter(item => item.type === MetadataFilteringVariableType.number)} + /> +
+ )} - {mode === AppModeEnum.COMPLETION && dataSet.length > 0 && ( + {!readonly && mode === AppModeEnum.COMPLETION && dataSet.length > 0 && ( { const { t } = useTranslation() - const { modelConfig, setInputs } = useContext(ConfigContext) + const { modelConfig, setInputs, readonly } = useContext(ConfigContext) const promptVariables = modelConfig.configs.prompt_variables.filter(({ key, name }) => { return key && key?.trim() && name && name?.trim() @@ -88,6 +88,7 @@ const ChatUserInput = ({ placeholder={name} autoFocus={index === 0} maxLength={max_length} + readOnly={readonly} /> )} {type === 'paragraph' && ( @@ -96,6 +97,7 @@ const ChatUserInput = ({ placeholder={name} value={inputs[key] ? `${inputs[key]}` : ''} onChange={(e) => { handleInputValueChange(key, e.target.value) }} + readOnly={readonly} /> )} {type === 'select' && ( @@ -105,6 +107,7 @@ const ChatUserInput = ({ onSelect={(i) => { handleInputValueChange(key, i.value as string) }} items={(options || []).map(i => ({ name: i, value: i }))} allowSearch={false} + disabled={readonly} /> )} {type === 'number' && ( @@ -115,6 +118,7 @@ const ChatUserInput = ({ placeholder={name} autoFocus={index === 0} maxLength={max_length} + readOnly={readonly} /> )} {type === 'checkbox' && ( @@ -123,6 +127,7 @@ const ChatUserInput = ({ value={!!inputs[key]} required={required} onChange={(value) => { handleInputValueChange(key, value) }} + readonly={readonly} /> )} diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/text-generation-item.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/text-generation-item.tsx index d7918e7ad6..eb18ca45b1 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/text-generation-item.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/text-generation-item.tsx @@ -15,6 +15,7 @@ import { DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/ import { useDebugConfigurationContext } from '@/context/debug-configuration' import { useEventEmitterContextContext } from '@/context/event-emitter' import { useProviderContext } from '@/context/provider-context' +import { AppSourceType } from '@/service/share' import { promptVariablesToUserInputsForm } from '@/utils/model-config' import { APP_CHAT_WITH_MULTIPLE_MODEL } from '../types' @@ -130,11 +131,11 @@ const TextGenerationItem: FC = ({ return ( { const { userProfile } = useAppContext() const { + readonly, modelConfig, appId, inputs, @@ -150,6 +151,7 @@ const DebugWithSingleModel = ( return ( = ({ }) => { const { t } = useTranslation() const { + readonly, appId, mode, modelModeType, @@ -416,25 +418,33 @@ const Debug: FC = ({ } {mode !== AppModeEnum.COMPLETION && ( <> - - - - - - {varList.length > 0 && ( -
+ { + !readonly && ( - setExpanded(!expanded)}> - + + + - {expanded &&
} -
- )} + ) + } + + { + varList.length > 0 && ( +
+ + !readonly && setExpanded(!expanded)}> + + + + {expanded &&
} +
+ ) + } )}
@@ -444,19 +454,21 @@ const Debug: FC = ({
)} - {mode === AppModeEnum.COMPLETION && ( - - )} + { + mode === AppModeEnum.COMPLETION && ( + + ) + } { debugWithMultipleModel && ( @@ -510,12 +522,12 @@ const Debug: FC = ({
const Debug: FC = ({
) } - {isShowFormattingChangeConfirm && ( - - )} - {!isAPIKeySet && ()} + { + isShowFormattingChangeConfirm && ( + + ) + } + {!isAPIKeySet && !readonly && ()} ) } diff --git a/web/app/components/app/configuration/prompt-value-panel/index.tsx b/web/app/components/app/configuration/prompt-value-panel/index.tsx index 613efb8710..e695616810 100644 --- a/web/app/components/app/configuration/prompt-value-panel/index.tsx +++ b/web/app/components/app/configuration/prompt-value-panel/index.tsx @@ -40,7 +40,7 @@ const PromptValuePanel: FC = ({ onVisionFilesChange, }) => { const { t } = useTranslation() - const { modelModeType, modelConfig, setInputs, mode, isAdvancedMode, completionPromptConfig, chatPromptConfig } = useContext(ConfigContext) + const { readonly, modelModeType, modelConfig, setInputs, mode, isAdvancedMode, completionPromptConfig, chatPromptConfig } = useContext(ConfigContext) const [userInputFieldCollapse, setUserInputFieldCollapse] = useState(false) const promptVariables = modelConfig.configs.prompt_variables.filter(({ key, name }) => { return key && key?.trim() && name && name?.trim() @@ -78,12 +78,12 @@ const PromptValuePanel: FC = ({ if (isAdvancedMode) { if (modelModeType === ModelModeType.chat) - return chatPromptConfig.prompt.every(({ text }) => !text) + return chatPromptConfig?.prompt.every(({ text }) => !text) return !completionPromptConfig.prompt?.text } else { return !modelConfig.configs.prompt_template } - }, [chatPromptConfig.prompt, completionPromptConfig.prompt?.text, isAdvancedMode, mode, modelConfig.configs.prompt_template, modelModeType]) + }, [chatPromptConfig?.prompt, completionPromptConfig.prompt?.text, isAdvancedMode, mode, modelConfig.configs.prompt_template, modelModeType]) const handleInputValueChange = (key: string, value: string | boolean) => { if (!(key in promptVariableObj)) @@ -142,6 +142,7 @@ const PromptValuePanel: FC = ({ placeholder={name} autoFocus={index === 0} maxLength={max_length} + readOnly={readonly} /> )} {type === 'paragraph' && ( @@ -150,6 +151,7 @@ const PromptValuePanel: FC = ({ placeholder={name} value={inputs[key] ? `${inputs[key]}` : ''} onChange={(e) => { handleInputValueChange(key, e.target.value) }} + readOnly={readonly} /> )} {type === 'select' && ( @@ -160,6 +162,7 @@ const PromptValuePanel: FC = ({ items={(options || []).map(i => ({ name: i, value: i }))} allowSearch={false} bgClassName="bg-gray-50" + disabled={readonly} /> )} {type === 'number' && ( @@ -170,6 +173,7 @@ const PromptValuePanel: FC = ({ placeholder={name} autoFocus={index === 0} maxLength={max_length} + readOnly={readonly} /> )} {type === 'checkbox' && ( @@ -178,6 +182,7 @@ const PromptValuePanel: FC = ({ value={!!inputs[key]} required={required} onChange={(value) => { handleInputValueChange(key, value) }} + readonly={readonly} /> )} @@ -196,6 +201,7 @@ const PromptValuePanel: FC = ({ url: fileItem.url, upload_file_id: fileItem.fileId, })))} + disabled={readonly} /> @@ -204,12 +210,12 @@ const PromptValuePanel: FC = ({ )} {!userInputFieldCollapse && (
- + {canNotRun && (
diff --git a/web/app/components/app/create-app-dialog/app-card/index.spec.tsx b/web/app/components/app/create-app-dialog/app-card/index.spec.tsx index e1f9773ac3..82e4fb8f94 100644 --- a/web/app/components/app/create-app-dialog/app-card/index.spec.tsx +++ b/web/app/components/app/create-app-dialog/app-card/index.spec.tsx @@ -10,6 +10,7 @@ vi.mock('@heroicons/react/20/solid', () => ({ })) const mockApp: App = { + can_trial: true, app: { id: 'test-app-id', mode: AppModeEnum.CHAT, diff --git a/web/app/components/app/create-app-dialog/app-card/index.tsx b/web/app/components/app/create-app-dialog/app-card/index.tsx index 695faed5e0..15cfbd5411 100644 --- a/web/app/components/app/create-app-dialog/app-card/index.tsx +++ b/web/app/components/app/create-app-dialog/app-card/index.tsx @@ -1,9 +1,14 @@ 'use client' import type { App } from '@/models/explore' import { PlusIcon } from '@heroicons/react/20/solid' +import { RiInformation2Line } from '@remixicon/react' +import { useCallback } from 'react' import { useTranslation } from 'react-i18next' +import { useContextSelector } from 'use-context-selector' import AppIcon from '@/app/components/base/app-icon' import Button from '@/app/components/base/button' +import AppListContext from '@/context/app-list-context' +import { useGlobalPublicStore } from '@/context/global-public-context' import { cn } from '@/utils/classnames' import { AppTypeIcon, AppTypeLabel } from '../../type-selector' @@ -20,6 +25,14 @@ const AppCard = ({ }: AppCardProps) => { const { t } = useTranslation() const { app: appBasicInfo } = app + const { systemFeatures } = useGlobalPublicStore() + const isTrialApp = app.can_trial && systemFeatures.enable_trial_app + const setShowTryAppPanel = useContextSelector(AppListContext, ctx => ctx.setShowTryAppPanel) + const showTryAPPPanel = useCallback((appId: string) => { + return () => { + setShowTryAppPanel?.(true, { appId, app }) + } + }, [setShowTryAppPanel, app.category]) return (
@@ -51,11 +64,17 @@ const AppCard = ({
{canCreate && ( )} diff --git a/web/app/components/app/create-from-dsl-modal/uploader.tsx b/web/app/components/app/create-from-dsl-modal/uploader.tsx index 133bd34dbc..778a2c1420 100644 --- a/web/app/components/app/create-from-dsl-modal/uploader.tsx +++ b/web/app/components/app/create-from-dsl-modal/uploader.tsx @@ -58,7 +58,7 @@ const Uploader: FC = ({ setDragging(false) if (!e.dataTransfer) return - const files = [...e.dataTransfer.files] + const files = Array.from(e.dataTransfer.files) if (files.length > 1) { notify({ type: 'error', message: t('stepOne.uploader.validation.count', { ns: 'datasetCreation' }) }) return diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 410953ccf7..5197a02bb3 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -39,6 +39,7 @@ import { useAppContext } from '@/context/app-context' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useTimestamp from '@/hooks/use-timestamp' import { fetchChatMessages, updateLogMessageAnnotations, updateLogMessageFeedbacks } from '@/service/log' +import { AppSourceType } from '@/service/share' import { useChatConversationDetail, useCompletionConversationDetail } from '@/service/use-log' import { AppModeEnum } from '@/types/app' import { cn } from '@/utils/classnames' @@ -638,12 +639,12 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
item.from_source === 'admin')} onFeedback={feedback => onFeedback(detail.message.id, feedback)} diff --git a/web/app/components/app/text-generate/item/index.tsx b/web/app/components/app/text-generate/item/index.tsx index 78f4f426f5..c39282a022 100644 --- a/web/app/components/app/text-generate/item/index.tsx +++ b/web/app/components/app/text-generate/item/index.tsx @@ -29,7 +29,7 @@ import { Markdown } from '@/app/components/base/markdown' import NewAudioButton from '@/app/components/base/new-audio-button' import Toast from '@/app/components/base/toast' import { fetchTextGenerationMessage } from '@/service/debug' -import { fetchMoreLikeThis, updateFeedback } from '@/service/share' +import { AppSourceType, fetchMoreLikeThis, updateFeedback } from '@/service/share' import { cn } from '@/utils/classnames' import ResultTab from './result-tab' @@ -53,7 +53,7 @@ export type IGenerationItemProps = { onFeedback?: (feedback: FeedbackType) => void onSave?: (messageId: string) => void isMobile?: boolean - isInstalledApp: boolean + appSourceType: AppSourceType installedAppId?: string taskId?: string controlClearMoreLikeThis?: number @@ -87,7 +87,7 @@ const GenerationItem: FC = ({ onSave, depth = 1, isMobile, - isInstalledApp, + appSourceType, installedAppId, taskId, controlClearMoreLikeThis, @@ -100,6 +100,7 @@ const GenerationItem: FC = ({ const { t } = useTranslation() const params = useParams() const isTop = depth === 1 + const isTryApp = appSourceType === AppSourceType.tryApp const [completionRes, setCompletionRes] = useState('') const [childMessageId, setChildMessageId] = useState(null) const [childFeedback, setChildFeedback] = useState({ @@ -113,7 +114,7 @@ const GenerationItem: FC = ({ const setShowPromptLogModal = useAppStore(s => s.setShowPromptLogModal) const handleFeedback = async (childFeedback: FeedbackType) => { - await updateFeedback({ url: `/messages/${childMessageId}/feedbacks`, body: { rating: childFeedback.rating } }, isInstalledApp, installedAppId) + await updateFeedback({ url: `/messages/${childMessageId}/feedbacks`, body: { rating: childFeedback.rating } }, appSourceType, installedAppId) setChildFeedback(childFeedback) } @@ -131,7 +132,7 @@ const GenerationItem: FC = ({ onSave, isShowTextToSpeech, isMobile, - isInstalledApp, + appSourceType, installedAppId, controlClearMoreLikeThis, isWorkflow, @@ -145,7 +146,7 @@ const GenerationItem: FC = ({ return } startQuerying() - const res: any = await fetchMoreLikeThis(messageId as string, isInstalledApp, installedAppId) + const res: any = await fetchMoreLikeThis(messageId as string, appSourceType, installedAppId) setCompletionRes(res.answer) setChildFeedback({ rating: null, @@ -310,7 +311,7 @@ const GenerationItem: FC = ({ )} {/* action buttons */}
- {!isInWebApp && !isInstalledApp && !isResponding && ( + {!isInWebApp && (appSourceType !== AppSourceType.installedApp) && !isResponding && (
@@ -319,12 +320,12 @@ const GenerationItem: FC = ({
)}
- {moreLikeThis && ( + {moreLikeThis && !isTryApp && ( )} - {isShowTextToSpeech && ( + {isShowTextToSpeech && !isTryApp && ( = ({ )} - {isInWebApp && !isWorkflow && ( + {isInWebApp && !isWorkflow && !isTryApp && ( { onSave?.(messageId as string) }}> )}
- {(supportFeedback || isInWebApp) && !isWorkflow && !isError && messageId && ( + {(supportFeedback || isInWebApp) && !isWorkflow && !isTryApp && !isError && messageId && (
{!feedback?.rating && ( <> diff --git a/web/app/components/apps/hooks/use-dsl-drag-drop.ts b/web/app/components/apps/hooks/use-dsl-drag-drop.ts index dda5773062..77d89b87da 100644 --- a/web/app/components/apps/hooks/use-dsl-drag-drop.ts +++ b/web/app/components/apps/hooks/use-dsl-drag-drop.ts @@ -36,7 +36,7 @@ export const useDSLDragDrop = ({ onDSLFileDropped, containerRef, enabled = true if (!e.dataTransfer) return - const files = [...e.dataTransfer.files] + const files = Array.from(e.dataTransfer.files) if (files.length === 0) return diff --git a/web/app/components/apps/index.spec.tsx b/web/app/components/apps/index.spec.tsx index c3dc39955d..c77c1bdb01 100644 --- a/web/app/components/apps/index.spec.tsx +++ b/web/app/components/apps/index.spec.tsx @@ -1,3 +1,5 @@ +import type { ReactNode } from 'react' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' import { render, screen } from '@testing-library/react' import * as React from 'react' @@ -22,6 +24,15 @@ vi.mock('@/app/education-apply/hooks', () => ({ }, })) +vi.mock('@/hooks/use-import-dsl', () => ({ + useImportDSL: () => ({ + handleImportDSL: vi.fn(), + handleImportDSLConfirm: vi.fn(), + versions: [], + isFetching: false, + }), +})) + // Mock List component vi.mock('./list', () => ({ default: () => { @@ -30,6 +41,25 @@ vi.mock('./list', () => ({ })) describe('Apps', () => { + const createQueryClient = () => new QueryClient({ + defaultOptions: { + queries: { + retry: false, + }, + }, + }) + + const renderWithClient = (ui: React.ReactElement) => { + const queryClient = createQueryClient() + const wrapper = ({ children }: { children: ReactNode }) => ( + {children} + ) + return { + queryClient, + ...render(ui, { wrapper }), + } + } + beforeEach(() => { vi.clearAllMocks() documentTitleCalls = [] @@ -38,17 +68,17 @@ describe('Apps', () => { describe('Rendering', () => { it('should render without crashing', () => { - render() + renderWithClient() expect(screen.getByTestId('apps-list')).toBeInTheDocument() }) it('should render List component', () => { - render() + renderWithClient() expect(screen.getByText('Apps List')).toBeInTheDocument() }) it('should have correct container structure', () => { - const { container } = render() + const { container } = renderWithClient() const wrapper = container.firstChild as HTMLElement expect(wrapper).toHaveClass('relative', 'flex', 'h-0', 'shrink-0', 'grow', 'flex-col') }) @@ -56,19 +86,19 @@ describe('Apps', () => { describe('Hooks', () => { it('should call useDocumentTitle with correct title', () => { - render() + renderWithClient() expect(documentTitleCalls).toContain('common.menus.apps') }) it('should call useEducationInit', () => { - render() + renderWithClient() expect(educationInitCalls).toBeGreaterThan(0) }) }) describe('Integration', () => { it('should render full component tree', () => { - render() + renderWithClient() // Verify container exists expect(screen.getByTestId('apps-list')).toBeInTheDocument() @@ -79,23 +109,32 @@ describe('Apps', () => { }) it('should handle multiple renders', () => { - const { rerender } = render() + const queryClient = createQueryClient() + const { rerender } = render( + + + , + ) expect(screen.getByTestId('apps-list')).toBeInTheDocument() - rerender() + rerender( + + + , + ) expect(screen.getByTestId('apps-list')).toBeInTheDocument() }) }) describe('Styling', () => { it('should have overflow-y-auto class', () => { - const { container } = render() + const { container } = renderWithClient() const wrapper = container.firstChild as HTMLElement 
expect(wrapper).toHaveClass('overflow-y-auto') }) it('should have background styling', () => { - const { container } = render() + const { container } = renderWithClient() const wrapper = container.firstChild as HTMLElement expect(wrapper).toHaveClass('bg-background-body') }) diff --git a/web/app/components/apps/index.tsx b/web/app/components/apps/index.tsx index b151df1e1f..255bfbf9c5 100644 --- a/web/app/components/apps/index.tsx +++ b/web/app/components/apps/index.tsx @@ -1,7 +1,17 @@ 'use client' +import type { CreateAppModalProps } from '../explore/create-app-modal' +import type { CurrentTryAppParams } from '@/context/explore-context' +import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import { useEducationInit } from '@/app/education-apply/hooks' +import AppListContext from '@/context/app-list-context' import useDocumentTitle from '@/hooks/use-document-title' +import { useImportDSL } from '@/hooks/use-import-dsl' +import { DSLImportMode } from '@/models/app' +import { fetchAppDetail } from '@/service/explore' +import DSLConfirmModal from '../app/create-from-dsl-modal/dsl-confirm-modal' +import CreateAppModal from '../explore/create-app-modal' +import TryApp from '../explore/try-app' import List from './list' const Apps = () => { @@ -10,10 +20,124 @@ const Apps = () => { useDocumentTitle(t('menus.apps', { ns: 'common' })) useEducationInit() + const [currentTryAppParams, setCurrentTryAppParams] = useState(undefined) + const currApp = currentTryAppParams?.app + const [isShowTryAppPanel, setIsShowTryAppPanel] = useState(false) + const hideTryAppPanel = useCallback(() => { + setIsShowTryAppPanel(false) + }, []) + const setShowTryAppPanel = (showTryAppPanel: boolean, params?: CurrentTryAppParams) => { + if (showTryAppPanel) + setCurrentTryAppParams(params) + else + setCurrentTryAppParams(undefined) + setIsShowTryAppPanel(showTryAppPanel) + } + const [isShowCreateModal, setIsShowCreateModal] = useState(false) + + const handleShowFromTryApp = useCallback(() => { + setIsShowCreateModal(true) + }, []) + + const [controlRefreshList, setControlRefreshList] = useState(0) + const [controlHideCreateFromTemplatePanel, setControlHideCreateFromTemplatePanel] = useState(0) + const onSuccess = useCallback(() => { + setControlRefreshList(prev => prev + 1) + setControlHideCreateFromTemplatePanel(prev => prev + 1) + }, []) + + const [showDSLConfirmModal, setShowDSLConfirmModal] = useState(false) + + const { + handleImportDSL, + handleImportDSLConfirm, + versions, + isFetching, + } = useImportDSL() + + const onConfirmDSL = useCallback(async () => { + await handleImportDSLConfirm({ + onSuccess, + }) + }, [handleImportDSLConfirm, onSuccess]) + + const onCreate: CreateAppModalProps['onConfirm'] = async ({ + name, + icon_type, + icon, + icon_background, + description, + }) => { + hideTryAppPanel() + + const { export_data } = await fetchAppDetail( + currApp?.app.id as string, + ) + const payload = { + mode: DSLImportMode.YAML_CONTENT, + yaml_content: export_data, + name, + icon_type, + icon, + icon_background, + description, + } + await handleImportDSL(payload, { + onSuccess: () => { + setIsShowCreateModal(false) + }, + onPending: () => { + setShowDSLConfirmModal(true) + }, + }) + } + return ( -
- -
+ +
+ + {isShowTryAppPanel && ( + + )} + + { + showDSLConfirmModal && ( + setShowDSLConfirmModal(false)} + onConfirm={onConfirmDSL} + confirmDisabled={isFetching} + /> + ) + } + + {isShowCreateModal && ( + setIsShowCreateModal(false)} + /> + )} +
+
) } diff --git a/web/app/components/apps/list.tsx b/web/app/components/apps/list.tsx index 8a236fe260..6bf79b7338 100644 --- a/web/app/components/apps/list.tsx +++ b/web/app/components/apps/list.tsx @@ -1,5 +1,6 @@ 'use client' +import type { FC } from 'react' import { RiApps2Line, RiDragDropLine, @@ -53,7 +54,12 @@ const CreateFromDSLModal = dynamic(() => import('@/app/components/app/create-fro ssr: false, }) -const List = () => { +type Props = { + controlRefreshList?: number +} +const List: FC = ({ + controlRefreshList = 0, +}) => { const { t } = useTranslation() const { systemFeatures } = useGlobalPublicStore() const router = useRouter() @@ -110,6 +116,13 @@ const List = () => { refetch, } = useInfiniteAppList(appListQueryParams, { enabled: !isCurrentWorkspaceDatasetOperator }) + useEffect(() => { + if (controlRefreshList > 0) { + refetch() + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [controlRefreshList]) + const anchorRef = useRef(null) const options = [ { value: 'all', text: t('types.all', { ns: 'app' }), icon: }, diff --git a/web/app/components/apps/new-app-card.tsx b/web/app/components/apps/new-app-card.tsx index bfa7af3892..868da0dcb5 100644 --- a/web/app/components/apps/new-app-card.tsx +++ b/web/app/components/apps/new-app-card.tsx @@ -6,10 +6,12 @@ import { useSearchParams, } from 'next/navigation' import * as React from 'react' -import { useMemo, useState } from 'react' +import { useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' +import { useContextSelector } from 'use-context-selector' import { CreateFromDSLModalTab } from '@/app/components/app/create-from-dsl-modal' import { FileArrow01, FilePlus01, FilePlus02 } from '@/app/components/base/icons/src/vender/line/files' +import AppListContext from '@/context/app-list-context' import { useProviderContext } from '@/context/provider-context' import { cn } from '@/utils/classnames' @@ -55,6 +57,13 @@ const CreateAppCard = ({ return undefined }, [dslUrl]) + const controlHideCreateFromTemplatePanel = useContextSelector(AppListContext, ctx => ctx.controlHideCreateFromTemplatePanel) + useEffect(() => { + if (controlHideCreateFromTemplatePanel > 0) + // eslint-disable-next-line react-hooks-extra/no-direct-set-state-in-use-effect + setShowNewAppTemplateDialog(false) + }, [controlHideCreateFromTemplatePanel]) + return (
{ +const ActionButton = ({ className, size, state = ActionButtonState.Default, styleCss, children, ref, disabled, ...props }: ActionButtonProps) => { return ( + ) + }, +) +CarouselPrevious.displayName = 'CarouselPrevious' + +const CarouselNext = React.forwardRef( + ({ children, ...props }, ref) => { + const { scrollNext, canScrollNext } = useCarousel() + + return ( + + ) + }, +) +CarouselNext.displayName = 'CarouselNext' + +const CarouselDot = React.forwardRef( + ({ children, ...props }, ref) => { + const { api, selectedIndex } = useCarousel() + + return api?.slideNodes().map((_, index) => { + return ( + + ) + }) + }, +) +CarouselDot.displayName = 'CarouselDot' + +const CarouselPlugins = { + Autoplay, +} + +Carousel.Content = CarouselContent +Carousel.Item = CarouselItem +Carousel.Previous = CarouselPrevious +Carousel.Next = CarouselNext +Carousel.Dot = CarouselDot +Carousel.Plugin = CarouselPlugins + +export { Carousel, useCarousel } diff --git a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx index 25ff39370f..38a3f6c6b2 100644 --- a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx +++ b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx @@ -12,6 +12,7 @@ import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested import { Markdown } from '@/app/components/base/markdown' import { InputVarType } from '@/app/components/workflow/types' import { + AppSourceType, fetchSuggestedQuestions, getUrl, stopChatMessageResponding, @@ -52,6 +53,11 @@ const ChatWrapper = () => { initUserVariables, } = useChatWithHistoryContext() + const appSourceType = isInstalledApp ? AppSourceType.installedApp : AppSourceType.webApp + + // Semantic variable for better code readability + const isHistoryConversation = !!currentConversationId + const appConfig = useMemo(() => { const config = appParams || {} @@ -79,7 +85,7 @@ const ChatWrapper = () => { inputsForm: inputsForms, }, appPrevChatTree, - taskId => stopChatMessageResponding('', taskId, isInstalledApp, appId), + taskId => stopChatMessageResponding('', taskId, appSourceType, appId), clearChatList, setClearChatList, ) @@ -138,11 +144,11 @@ const ChatWrapper = () => { } handleSend( - getUrl('chat-messages', isInstalledApp, appId || ''), + getUrl('chat-messages', appSourceType, appId || ''), data, { - onGetSuggestedQuestions: responseItemId => fetchSuggestedQuestions(responseItemId, isInstalledApp, appId), - onConversationComplete: currentConversationId ? undefined : handleNewConversationCompleted, + onGetSuggestedQuestions: responseItemId => fetchSuggestedQuestions(responseItemId, appSourceType, appId), + onConversationComplete: isHistoryConversation ? 
undefined : handleNewConversationCompleted, isPublicAPI: !isInstalledApp, }, ) diff --git a/web/app/components/base/chat/chat-with-history/hooks.spec.tsx b/web/app/components/base/chat/chat-with-history/hooks.spec.tsx index f6a8f25cbb..399f16716d 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.spec.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.spec.tsx @@ -5,6 +5,7 @@ import { QueryClient, QueryClientProvider } from '@tanstack/react-query' import { act, renderHook, waitFor } from '@testing-library/react' import { ToastProvider } from '@/app/components/base/toast' import { + AppSourceType, fetchChatList, fetchConversations, generationConversationName, @@ -49,20 +50,24 @@ vi.mock('../utils', async () => { } }) -vi.mock('@/service/share', () => ({ - fetchChatList: vi.fn(), - fetchConversations: vi.fn(), - generationConversationName: vi.fn(), - fetchAppInfo: vi.fn(), - fetchAppMeta: vi.fn(), - fetchAppParams: vi.fn(), - getAppAccessModeByAppCode: vi.fn(), - delConversation: vi.fn(), - pinConversation: vi.fn(), - renameConversation: vi.fn(), - unpinConversation: vi.fn(), - updateFeedback: vi.fn(), -})) +vi.mock('@/service/share', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + fetchChatList: vi.fn(), + fetchConversations: vi.fn(), + generationConversationName: vi.fn(), + fetchAppInfo: vi.fn(), + fetchAppMeta: vi.fn(), + fetchAppParams: vi.fn(), + getAppAccessModeByAppCode: vi.fn(), + delConversation: vi.fn(), + pinConversation: vi.fn(), + renameConversation: vi.fn(), + unpinConversation: vi.fn(), + updateFeedback: vi.fn(), + } +}) const mockFetchConversations = vi.mocked(fetchConversations) const mockFetchChatList = vi.mocked(fetchChatList) @@ -162,13 +167,13 @@ describe('useChatWithHistory', () => { // Assert await waitFor(() => { - expect(mockFetchConversations).toHaveBeenCalledWith(false, 'app-1', undefined, true, 100) + expect(mockFetchConversations).toHaveBeenCalledWith(AppSourceType.webApp, 'app-1', undefined, true, 100) }) await waitFor(() => { - expect(mockFetchConversations).toHaveBeenCalledWith(false, 'app-1', undefined, false, 100) + expect(mockFetchConversations).toHaveBeenCalledWith(AppSourceType.webApp, 'app-1', undefined, false, 100) }) await waitFor(() => { - expect(mockFetchChatList).toHaveBeenCalledWith('conversation-1', false, 'app-1') + expect(mockFetchChatList).toHaveBeenCalledWith('conversation-1', AppSourceType.webApp, 'app-1') }) await waitFor(() => { expect(result.current.pinnedConversationList).toEqual(pinnedData.data) @@ -204,7 +209,7 @@ describe('useChatWithHistory', () => { // Assert await waitFor(() => { - expect(mockGenerationConversationName).toHaveBeenCalledWith(false, 'app-1', 'conversation-new') + expect(mockGenerationConversationName).toHaveBeenCalledWith(AppSourceType.webApp, 'app-1', 'conversation-new') }) await waitFor(() => { expect(result.current.conversationList[0]).toEqual(generatedConversation) diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index ed1981b530..ad1de38d07 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -27,6 +27,7 @@ import { useWebAppStore } from '@/context/web-app-context' import { useAppFavicon } from '@/hooks/use-app-favicon' import { changeLanguage } from '@/i18n-config/client' import { + AppSourceType, delConversation, pinConversation, renameConversation, @@ -72,6 +73,7 @@ function 
getFormattedChatList(messages: any[]) { export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { const isInstalledApp = useMemo(() => !!installedAppInfo, [installedAppInfo]) + const appSourceType = isInstalledApp ? AppSourceType.installedApp : AppSourceType.webApp const appInfo = useWebAppStore(s => s.appInfo) const appParams = useWebAppStore(s => s.appParams) const appMeta = useWebAppStore(s => s.appMeta) @@ -177,7 +179,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { }, [currentConversationId, newConversationId]) const { data: appPinnedConversationData } = useShareConversations({ - isInstalledApp, + appSourceType, appId, pinned: true, limit: 100, @@ -190,7 +192,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { data: appConversationData, isLoading: appConversationDataLoading, } = useShareConversations({ - isInstalledApp, + appSourceType, appId, pinned: false, limit: 100, @@ -204,7 +206,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { isLoading: appChatListDataLoading, } = useShareChatList({ conversationId: chatShouldReloadKey, - isInstalledApp, + appSourceType, appId, }, { enabled: !!chatShouldReloadKey, @@ -334,10 +336,11 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { const { data: newConversation } = useShareConversationName({ conversationId: newConversationId, - isInstalledApp, + appSourceType, appId, }, { refetchOnWindowFocus: false, + enabled: !!newConversationId, }) const [originConversationList, setOriginConversationList] = useState([]) useEffect(() => { @@ -462,16 +465,16 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { }, [invalidateShareConversations]) const handlePinConversation = useCallback(async (conversationId: string) => { - await pinConversation(isInstalledApp, appId, conversationId) + await pinConversation(appSourceType, appId, conversationId) notify({ type: 'success', message: t('api.success', { ns: 'common' }) }) handleUpdateConversationList() - }, [isInstalledApp, appId, notify, t, handleUpdateConversationList]) + }, [appSourceType, appId, notify, t, handleUpdateConversationList]) const handleUnpinConversation = useCallback(async (conversationId: string) => { - await unpinConversation(isInstalledApp, appId, conversationId) + await unpinConversation(appSourceType, appId, conversationId) notify({ type: 'success', message: t('api.success', { ns: 'common' }) }) handleUpdateConversationList() - }, [isInstalledApp, appId, notify, t, handleUpdateConversationList]) + }, [appSourceType, appId, notify, t, handleUpdateConversationList]) const [conversationDeleting, setConversationDeleting] = useState(false) const handleDeleteConversation = useCallback(async ( @@ -485,7 +488,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { try { setConversationDeleting(true) - await delConversation(isInstalledApp, appId, conversationId) + await delConversation(appSourceType, appId, conversationId) notify({ type: 'success', message: t('api.success', { ns: 'common' }) }) onSuccess() } @@ -520,7 +523,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { setConversationRenaming(true) try { - await renameConversation(isInstalledApp, appId, conversationId, newName) + await renameConversation(appSourceType, appId, conversationId, newName) notify({ type: 'success', @@ -550,9 +553,9 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { }, [handleConversationIdInfoChange, 
invalidateShareConversations]) const handleFeedback = useCallback(async (messageId: string, feedback: Feedback) => { - await updateFeedback({ url: `/messages/${messageId}/feedbacks`, body: { rating: feedback.rating, content: feedback.content } }, isInstalledApp, appId) + await updateFeedback({ url: `/messages/${messageId}/feedbacks`, body: { rating: feedback.rating, content: feedback.content } }, appSourceType, appId) notify({ type: 'success', message: t('api.success', { ns: 'common' }) }) - }, [isInstalledApp, appId, t, notify]) + }, [appSourceType, appId, t, notify]) return { isInstalledApp, diff --git a/web/app/components/base/chat/chat/answer/index.tsx b/web/app/components/base/chat/chat/answer/index.tsx index 9f1efa3ae0..da46f47c61 100644 --- a/web/app/components/base/chat/chat/answer/index.tsx +++ b/web/app/components/base/chat/chat/answer/index.tsx @@ -150,7 +150,7 @@ const Answer: FC = ({ data={workflowProcess} item={item} hideProcessDetail={hideProcessDetail} - readonly={hideProcessDetail && appData ? !appData.site.show_workflow_steps : undefined} + readonly={hideProcessDetail && appData ? !appData.site?.show_workflow_steps : undefined} /> ) } diff --git a/web/app/components/base/chat/chat/answer/suggested-questions.tsx b/web/app/components/base/chat/chat/answer/suggested-questions.tsx index 019ed78348..ce997a49b6 100644 --- a/web/app/components/base/chat/chat/answer/suggested-questions.tsx +++ b/web/app/components/base/chat/chat/answer/suggested-questions.tsx @@ -1,6 +1,7 @@ import type { FC } from 'react' import type { ChatItem } from '../../types' import { memo } from 'react' +import { cn } from '@/utils/classnames' import { useChatContext } from '../context' type SuggestedQuestionsProps = { @@ -9,7 +10,7 @@ type SuggestedQuestionsProps = { const SuggestedQuestions: FC = ({ item, }) => { - const { onSend } = useChatContext() + const { onSend, readonly } = useChatContext() const { isOpeningStatement, @@ -24,8 +25,11 @@ const SuggestedQuestions: FC = ({ {suggestedQuestions.filter(q => !!q && q.trim()).map((question, index) => (
onSend?.(question)} + className={cn( + 'system-sm-medium mr-1 mt-1 inline-flex max-w-full shrink-0 cursor-pointer flex-wrap rounded-lg border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg px-3.5 py-2 text-components-button-secondary-accent-text shadow-xs last:mr-0 hover:border-components-button-secondary-border-hover hover:bg-components-button-secondary-bg-hover', + readonly && 'pointer-events-none opacity-50', + )} + onClick={() => !readonly && onSend?.(question)} > {question}
diff --git a/web/app/components/base/chat/chat/chat-input-area/index.tsx b/web/app/components/base/chat/chat/chat-input-area/index.tsx index 192f46fb23..9de52cb18c 100644 --- a/web/app/components/base/chat/chat/chat-input-area/index.tsx +++ b/web/app/components/base/chat/chat/chat-input-area/index.tsx @@ -5,6 +5,7 @@ import type { } from '../../types' import type { InputForm } from '../type' import type { FileUpload } from '@/app/components/base/features/types' +import { noop } from 'es-toolkit/function' import { decode } from 'html-entities' import Recorder from 'js-audio-recorder' import { @@ -30,6 +31,7 @@ import { useTextAreaHeight } from './hooks' import Operation from './operation' type ChatInputAreaProps = { + readonly?: boolean botName?: string showFeatureBar?: boolean showFileUpload?: boolean @@ -45,6 +47,7 @@ type ChatInputAreaProps = { disabled?: boolean } const ChatInputArea = ({ + readonly, botName, showFeatureBar, showFileUpload, @@ -170,6 +173,7 @@ const ChatInputArea = ({ const operation = (
{ @@ -239,7 +244,14 @@ const ChatInputArea = ({ ) }
- {showFeatureBar && } + {showFeatureBar && ( + + )} ) } diff --git a/web/app/components/base/chat/chat/chat-input-area/operation.tsx b/web/app/components/base/chat/chat/chat-input-area/operation.tsx index 27e5bf6cad..5bce827754 100644 --- a/web/app/components/base/chat/chat/chat-input-area/operation.tsx +++ b/web/app/components/base/chat/chat/chat-input-area/operation.tsx @@ -8,6 +8,7 @@ import { RiMicLine, RiSendPlane2Fill, } from '@remixicon/react' +import { noop } from 'es-toolkit/function' import { memo } from 'react' import ActionButton from '@/app/components/base/action-button' import Button from '@/app/components/base/button' @@ -15,6 +16,7 @@ import { FileUploaderInChatInput } from '@/app/components/base/file-uploader' import { cn } from '@/utils/classnames' type OperationProps = { + readonly?: boolean fileConfig?: FileUpload speechToTextConfig?: EnableType onShowVoiceInput?: () => void @@ -23,6 +25,7 @@ type OperationProps = { ref?: Ref } const Operation: FC = ({ + readonly, ref, fileConfig, speechToTextConfig, @@ -41,11 +44,12 @@ const Operation: FC = ({ ref={ref} >
- {fileConfig?.enabled && } + {fileConfig?.enabled && } { speechToTextConfig?.enabled && ( @@ -56,7 +60,7 @@ const Operation: FC = ({ + { + !hideEditEntrance && ( + + ) + }
)}
diff --git a/web/app/components/base/file-uploader/file-uploader-in-chat-input/index.tsx b/web/app/components/base/file-uploader/file-uploader-in-chat-input/index.tsx index 1ae328d67a..08bb8b45d1 100644 --- a/web/app/components/base/file-uploader/file-uploader-in-chat-input/index.tsx +++ b/web/app/components/base/file-uploader/file-uploader-in-chat-input/index.tsx @@ -13,21 +13,27 @@ import FileFromLinkOrLocal from '../file-from-link-or-local' type FileUploaderInChatInputProps = { fileConfig: FileUpload + readonly?: boolean } const FileUploaderInChatInput = ({ fileConfig, + readonly, }: FileUploaderInChatInputProps) => { const renderTrigger = useCallback((open: boolean) => { return ( ) }, []) + if (readonly) + return renderTrigger(false) + return ( = ({ type TextGenerationImageUploaderProps = { settings: VisionSettings onFilesChange: (files: ImageFile[]) => void + disabled?: boolean } const TextGenerationImageUploader: FC = ({ settings, onFilesChange, + disabled, }) => { const { t } = useTranslation() @@ -93,7 +95,7 @@ const TextGenerationImageUploader: FC = ({ const localUpload = ( = settings.number_limits} + disabled={files.length >= settings.number_limits || disabled} limit={+settings.image_file_size_limit!} > { @@ -115,7 +117,7 @@ const TextGenerationImageUploader: FC = ({ const urlUpload = ( = settings.number_limits} + disabled={files.length >= settings.number_limits || disabled} /> ) diff --git a/web/app/components/base/markdown/react-markdown-wrapper.spec.tsx b/web/app/components/base/markdown/react-markdown-wrapper.spec.tsx new file mode 100644 index 0000000000..735222011b --- /dev/null +++ b/web/app/components/base/markdown/react-markdown-wrapper.spec.tsx @@ -0,0 +1,109 @@ +import type { PropsWithChildren, ReactNode } from 'react' +import { render, screen } from '@testing-library/react' +import { ReactMarkdownWrapper } from './react-markdown-wrapper' + +vi.mock('@/app/components/base/markdown-blocks', () => ({ + AudioBlock: ({ children }: PropsWithChildren) =>
{children}
, + Img: ({ alt }: { alt?: string }) => {alt}, + Link: ({ children, href }: { children?: ReactNode, href?: string }) => {children}, + MarkdownButton: ({ children }: PropsWithChildren) => , + MarkdownForm: ({ children }: PropsWithChildren) =>
{children}
, + Paragraph: ({ children }: PropsWithChildren) =>
{children}
, + PluginImg: ({ alt }: { alt?: string }) => {alt}, + PluginParagraph: ({ children }: PropsWithChildren) =>
{children}
, + ScriptBlock: () => null, + ThinkBlock: ({ children }: PropsWithChildren) =>
{children}
, + VideoBlock: ({ children }: PropsWithChildren) =>
{children}
, +})) + +vi.mock('@/app/components/base/markdown-blocks/code-block', () => ({ + default: ({ children }: PropsWithChildren) => {children}, +})) + +describe('ReactMarkdownWrapper', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Strikethrough rendering', () => { + it('should NOT render single tilde as strikethrough', () => { + // Arrange - single tilde should be rendered as literal text + const content = 'Range: 0.3~8mm' + + // Act + render() + + // Assert - check that ~ is rendered as text, not as strikethrough (del element) + // The content should contain the tilde as literal text + expect(screen.getByText(/0\.3~8mm/)).toBeInTheDocument() + expect(document.querySelector('del')).toBeNull() + }) + + it('should render double tildes as strikethrough', () => { + // Arrange - double tildes should create strikethrough + const content = 'This is ~~strikethrough~~ text' + + // Act + render() + + // Assert - del element should be present for double tildes + const delElement = document.querySelector('del') + expect(delElement).not.toBeNull() + expect(delElement?.textContent).toBe('strikethrough') + }) + + it('should handle mixed content with single and double tildes correctly', () => { + // Arrange - real-world example from issue #31391 + const content = 'PCB thickness: 0.3~8mm and ~~removed feature~~ text' + + // Act + render() + + // Assert + // Only double tildes should create strikethrough + const delElements = document.querySelectorAll('del') + expect(delElements).toHaveLength(1) + expect(delElements[0].textContent).toBe('removed feature') + + // Single tilde should remain as literal text + expect(screen.getByText(/0\.3~8mm/)).toBeInTheDocument() + }) + }) + + describe('Basic rendering', () => { + it('should render plain text content', () => { + // Arrange + const content = 'Hello World' + + // Act + render() + + // Assert + expect(screen.getByText('Hello World')).toBeInTheDocument() + }) + + it('should render bold text', () => { + // Arrange + const content = '**bold text**' + + // Act + render() + + // Assert + expect(screen.getByText('bold text')).toBeInTheDocument() + expect(document.querySelector('strong')).not.toBeNull() + }) + + it('should render italic text', () => { + // Arrange + const content = '*italic text*' + + // Act + render() + + // Assert + expect(screen.getByText('italic text')).toBeInTheDocument() + expect(document.querySelector('em')).not.toBeNull() + }) + }) +}) diff --git a/web/app/components/base/markdown/react-markdown-wrapper.tsx b/web/app/components/base/markdown/react-markdown-wrapper.tsx index ef735b5e76..ed9e93e8b3 100644 --- a/web/app/components/base/markdown/react-markdown-wrapper.tsx +++ b/web/app/components/base/markdown/react-markdown-wrapper.tsx @@ -30,7 +30,7 @@ export const ReactMarkdownWrapper: FC = (props) => { return ( void } @@ -23,6 +25,8 @@ const TabHeader: FC = ({ items, value, itemClassName, + itemWrapClassName, + activeItemClassName, onChange, }) => { const renderItem = ({ id, name, icon, extra, disabled }: Item) => ( @@ -30,8 +34,9 @@ const TabHeader: FC = ({ key={id} className={cn( 'system-md-semibold relative flex cursor-pointer items-center border-b-2 border-transparent pb-2 pt-2.5', - id === value ? 'border-components-tab-active text-text-primary' : 'text-text-tertiary', + id === value ? 
cn('border-components-tab-active text-text-primary', activeItemClassName) : 'text-text-tertiary', disabled && 'cursor-not-allowed opacity-30', + itemWrapClassName, )} onClick={() => !disabled && onChange(id)} > diff --git a/web/app/components/base/voice-input/index.tsx b/web/app/components/base/voice-input/index.tsx index 4fa2c774f4..52e3c754f8 100644 --- a/web/app/components/base/voice-input/index.tsx +++ b/web/app/components/base/voice-input/index.tsx @@ -8,7 +8,7 @@ import { useParams, usePathname } from 'next/navigation' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { StopCircle } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices' -import { audioToText } from '@/service/share' +import { AppSourceType, audioToText } from '@/service/share' import { cn } from '@/utils/classnames' import s from './index.module.css' import { convertToMp3 } from './utils' @@ -108,7 +108,7 @@ const VoiceInput = ({ } try { - const audioResponse = await audioToText(url, isPublic, formData) + const audioResponse = await audioToText(url, isPublic ? AppSourceType.webApp : AppSourceType.installedApp, formData) onConverted(audioResponse.text) onCancel() } diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/uploader.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/uploader.tsx index 2f5130ecce..3fa940c60d 100644 --- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/uploader.tsx +++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/uploader.tsx @@ -54,7 +54,7 @@ const Uploader: FC = ({ setDragging(false) if (!e.dataTransfer) return - const files = [...e.dataTransfer.files] + const files = Array.from(e.dataTransfer.files) if (files.length > 1) { notify({ type: 'error', message: t('stepOne.uploader.validation.count', { ns: 'datasetCreation' }) }) return diff --git a/web/app/components/datasets/create/file-uploader/index.tsx b/web/app/components/datasets/create/file-uploader/index.tsx index e9c6693e52..781b97200a 100644 --- a/web/app/components/datasets/create/file-uploader/index.tsx +++ b/web/app/components/datasets/create/file-uploader/index.tsx @@ -278,7 +278,7 @@ const FileUploader = ({ onFileListUpdate?.([...fileListRef.current]) } const fileChangeHandle = useCallback((e: React.ChangeEvent) => { - let files = [...(e.target.files ?? [])] as File[] + let files = Array.from(e.target.files ?? 
[]) as File[] files = files.slice(0, fileUploadConfig.batch_count_limit) initialUpload(files.filter(isValid)) }, [isValid, initialUpload, fileUploadConfig]) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx index a5c03b671a..d02d5927f2 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx @@ -230,7 +230,7 @@ const LocalFile = ({ if (!e.dataTransfer) return - let files = [...e.dataTransfer.files] as File[] + let files = Array.from(e.dataTransfer.files) as File[] if (!supportBatchUpload) files = files.slice(0, 1) @@ -251,7 +251,7 @@ const LocalFile = ({ updateFileList([...fileListRef.current]) } const fileChangeHandle = useCallback((e: React.ChangeEvent) => { - let files = [...(e.target.files ?? [])] as File[] + let files = Array.from(e.target.files ?? []) as File[] files = files.slice(0, fileUploadConfig.batch_count_limit) initialUpload(files.filter(isValid)) }, [isValid, initialUpload, fileUploadConfig.batch_count_limit]) diff --git a/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx b/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx index 0ca404a26e..f3a86e910d 100644 --- a/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx +++ b/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx @@ -126,7 +126,7 @@ const CSVUploader: FC = ({ setDragging(false) if (!e.dataTransfer) return - const files = [...e.dataTransfer.files] + const files = Array.from(e.dataTransfer.files) if (files.length > 1) { notify({ type: 'error', message: t('stepOne.uploader.validation.count', { ns: 'datasetCreation' }) }) return diff --git a/web/app/components/explore/app-card/index.spec.tsx b/web/app/components/explore/app-card/index.spec.tsx index 769b317929..152eab92a9 100644 --- a/web/app/components/explore/app-card/index.spec.tsx +++ b/web/app/components/explore/app-card/index.spec.tsx @@ -10,6 +10,7 @@ vi.mock('../../app/type-selector', () => ({ })) const createApp = (overrides?: Partial): App => ({ + can_trial: true, app_id: 'app-id', description: 'App description', copyright: '2024', diff --git a/web/app/components/explore/app-card/index.tsx b/web/app/components/explore/app-card/index.tsx index 0b6cd9920d..5d82ab65cc 100644 --- a/web/app/components/explore/app-card/index.tsx +++ b/web/app/components/explore/app-card/index.tsx @@ -1,8 +1,13 @@ 'use client' import type { App } from '@/models/explore' import { PlusIcon } from '@heroicons/react/20/solid' +import { RiInformation2Line } from '@remixicon/react' +import { useCallback } from 'react' import { useTranslation } from 'react-i18next' +import { useContextSelector } from 'use-context-selector' import AppIcon from '@/app/components/base/app-icon' +import ExploreContext from '@/context/explore-context' +import { useGlobalPublicStore } from '@/context/global-public-context' import { AppModeEnum } from '@/types/app' import { cn } from '@/utils/classnames' import { AppTypeIcon } from '../../app/type-selector' @@ -23,8 +28,17 @@ const AppCard = ({ }: AppCardProps) => { const { t } = useTranslation() const { app: appBasicInfo } = app + const { systemFeatures } = useGlobalPublicStore() + const isTrialApp = app.can_trial && systemFeatures.enable_trial_app + const setShowTryAppPanel = 
useContextSelector(ExploreContext, ctx => ctx.setShowTryAppPanel) + const showTryAPPPanel = useCallback((appId: string) => { + return () => { + setShowTryAppPanel?.(true, { appId, app }) + } + }, [setShowTryAppPanel, app]) + return ( -
+
- {isExplore && canCreate && ( + {isExplore && (canCreate || isTrialApp) && ( )} diff --git a/web/app/components/explore/app-list/index.spec.tsx b/web/app/components/explore/app-list/index.spec.tsx index a9e4feeba8..a87d5a2363 100644 --- a/web/app/components/explore/app-list/index.spec.tsx +++ b/web/app/components/explore/app-list/index.spec.tsx @@ -16,9 +16,13 @@ let mockIsError = false const mockHandleImportDSL = vi.fn() const mockHandleImportDSLConfirm = vi.fn() -vi.mock('nuqs', () => ({ - useQueryState: () => [mockTabValue, mockSetTab], -})) +vi.mock('nuqs', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + useQueryState: () => [mockTabValue, mockSetTab], + } +}) vi.mock('ahooks', async () => { const actual = await vi.importActual('ahooks') @@ -102,6 +106,7 @@ const createApp = (overrides: Partial = {}): App => ({ description: overrides.app?.description ?? 'Alpha description', use_icon_as_answer_icon: overrides.app?.use_icon_as_answer_icon ?? false, }, + can_trial: true, app_id: overrides.app_id ?? 'app-1', description: overrides.description ?? 'Alpha description', copyright: overrides.copyright ?? '', @@ -127,6 +132,8 @@ const renderWithContext = (hasEditPermission = false, onSuccess?: () => void) => setInstalledApps: vi.fn(), isFetchingInstalledApps: false, setIsFetchingInstalledApps: vi.fn(), + isShowTryAppPanel: false, + setShowTryAppPanel: vi.fn(), }} > diff --git a/web/app/components/explore/app-list/index.tsx b/web/app/components/explore/app-list/index.tsx index 5b318b780b..1749bde76a 100644 --- a/web/app/components/explore/app-list/index.tsx +++ b/web/app/components/explore/app-list/index.tsx @@ -7,14 +7,17 @@ import { useQueryState } from 'nuqs' import * as React from 'react' import { useCallback, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' +import { useContext, useContextSelector } from 'use-context-selector' import DSLConfirmModal from '@/app/components/app/create-from-dsl-modal/dsl-confirm-modal' +import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Loading from '@/app/components/base/loading' import AppCard from '@/app/components/explore/app-card' +import Banner from '@/app/components/explore/banner/banner' import Category from '@/app/components/explore/category' import CreateAppModal from '@/app/components/explore/create-app-modal' import ExploreContext from '@/context/explore-context' +import { useGlobalPublicStore } from '@/context/global-public-context' import { useImportDSL } from '@/hooks/use-import-dsl' import { DSLImportMode, @@ -22,6 +25,7 @@ import { import { fetchAppDetail } from '@/service/explore' import { useExploreAppList } from '@/service/use-explore' import { cn } from '@/utils/classnames' +import TryApp from '../try-app' import s from './style.module.css' type AppsProps = { @@ -32,12 +36,19 @@ const Apps = ({ onSuccess, }: AppsProps) => { const { t } = useTranslation() + const { systemFeatures } = useGlobalPublicStore() const { hasEditPermission } = useContext(ExploreContext) const allCategoriesEn = t('apps.allCategories', { ns: 'explore', lng: 'en' }) const [keywords, setKeywords] = useState('') const [searchKeywords, setSearchKeywords] = useState('') + const hasFilterCondition = !!keywords + const handleResetFilter = useCallback(() => { + setKeywords('') + setSearchKeywords('') + }, []) + const { run: handleSearch } = useDebounceFn(() => { setSearchKeywords(keywords) }, { 
    wait: 500 })
@@ -84,6 +95,18 @@ const Apps = ({
     isFetching,
   } = useImportDSL()
   const [showDSLConfirmModal, setShowDSLConfirmModal] = useState(false)
+
+  const isShowTryAppPanel = useContextSelector(ExploreContext, ctx => ctx.isShowTryAppPanel)
+  const setShowTryAppPanel = useContextSelector(ExploreContext, ctx => ctx.setShowTryAppPanel)
+  const hideTryAppPanel = useCallback(() => {
+    setShowTryAppPanel(false)
+  }, [setShowTryAppPanel])
+  const appParams = useContextSelector(ExploreContext, ctx => ctx.currentApp)
+  const handleShowFromTryApp = useCallback(() => {
+    setCurrApp(appParams?.app || null)
+    setIsShowCreateModal(true)
+  }, [appParams?.app])
+
   const onCreate: CreateAppModalProps['onConfirm'] = async ({
     name,
     icon_type,
@@ -91,6 +114,8 @@ const Apps = ({
     icon_background,
     description,
   }) => {
+    hideTryAppPanel()
+
     const { export_data } = await fetchAppDetail(
       currApp?.app.id as string,
     )
@@ -137,22 +162,24 @@ const Apps = ({
        'flex h-full flex-col border-l-[0.5px] border-divider-regular',
      )}
    >
-      [removed header markup, tags lost in extraction: it rendered t('apps.title', { ns: 'explore' }) and t('apps.description', { ns: 'explore' })]
+      [replacement markup, tags lost in extraction: it renders the imported Banner component behind systemFeatures.enable_explore_banner, a heading that shows t('apps.title', { ns: 'explore' }) or t('apps.resultNum', { num: searchFilteredList.length, ns: 'explore' }) when hasFilterCondition, a reset Button wired to handleResetFilter, and the imported TryApp panel guarded by isShowTryAppPanel]
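// Aside: a minimal sketch of the use-context-selector pattern used above for the
// try-app panel. A component subscribes to a single slice of the context value,
// so it re-renders only when that slice changes. PanelContextValue and the field
// names are illustrative, not taken from this patch.
import * as React from 'react'
import { createContext, useContextSelector } from 'use-context-selector'

type PanelContextValue = {
  isPanelOpen: boolean
  setPanelOpen: (open: boolean) => void
}

const PanelContext = createContext<PanelContextValue>({
  isPanelOpen: false,
  setPanelOpen: () => {},
})

const PanelToggle = () => {
  // Only changes to these two fields re-render this component, even if other
  // fields on the context value change.
  const isPanelOpen = useContextSelector(PanelContext, ctx => ctx.isPanelOpen)
  const setPanelOpen = useContextSelector(PanelContext, ctx => ctx.setPanelOpen)
  return (
    <button onClick={() => setPanelOpen(!isPanelOpen)}>
      {isPanelOpen ? 'Hide panel' : 'Show panel'}
    </button>
  )
}

export default PanelToggle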
) } diff --git a/web/app/components/explore/banner/banner-item.tsx b/web/app/components/explore/banner/banner-item.tsx new file mode 100644 index 0000000000..5ce810bafb --- /dev/null +++ b/web/app/components/explore/banner/banner-item.tsx @@ -0,0 +1,187 @@ +/* eslint-disable react-hooks-extra/no-direct-set-state-in-use-effect */ +import type { FC } from 'react' +import type { Banner } from '@/models/app' +import { RiArrowRightLine } from '@remixicon/react' +import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { useCarousel } from '@/app/components/base/carousel' +import { cn } from '@/utils/classnames' +import { IndicatorButton } from './indicator-button' + +type BannerItemProps = { + banner: Banner + autoplayDelay: number + isPaused?: boolean +} + +const RESPONSIVE_BREAKPOINT = 1200 +const MAX_RESPONSIVE_WIDTH = 600 +const INDICATOR_WIDTH = 20 +const INDICATOR_GAP = 8 +const MIN_VIEW_MORE_WIDTH = 480 + +export const BannerItem: FC = ({ banner, autoplayDelay, isPaused = false }) => { + const { t } = useTranslation() + const { api, selectedIndex } = useCarousel() + const { category, title, description, 'img-src': imgSrc } = banner.content + + const [resetKey, setResetKey] = useState(0) + const textAreaRef = useRef(null) + const [maxWidth, setMaxWidth] = useState(undefined) + + const slideInfo = useMemo(() => { + const slides = api?.slideNodes() ?? [] + const totalSlides = slides.length + const nextIndex = totalSlides > 0 ? (selectedIndex + 1) % totalSlides : 0 + return { slides, totalSlides, nextIndex } + }, [api, selectedIndex]) + + const indicatorsWidth = useMemo(() => { + const count = slideInfo.totalSlides + if (count === 0) + return 0 + // Calculate: indicator buttons + gaps + extra spacing (3 * 20px for divider and padding) + return (count + 2) * INDICATOR_WIDTH + (count - 1) * INDICATOR_GAP + }, [slideInfo.totalSlides]) + + const viewMoreStyle = useMemo(() => { + if (!maxWidth) + return undefined + return { + maxWidth: `${maxWidth}px`, + minWidth: indicatorsWidth ? `${Math.min(maxWidth - indicatorsWidth, MIN_VIEW_MORE_WIDTH)}px` : undefined, + } + }, [maxWidth, indicatorsWidth]) + + const responsiveStyle = useMemo( + () => (maxWidth !== undefined ? { maxWidth: `${maxWidth}px` } : undefined), + [maxWidth], + ) + + const incrementResetKey = useCallback(() => setResetKey(prev => prev + 1), []) + + useEffect(() => { + const updateMaxWidth = () => { + if (window.innerWidth < RESPONSIVE_BREAKPOINT && textAreaRef.current) { + const textAreaWidth = textAreaRef.current.offsetWidth + setMaxWidth(Math.min(textAreaWidth, MAX_RESPONSIVE_WIDTH)) + } + else { + setMaxWidth(undefined) + } + } + + updateMaxWidth() + + const resizeObserver = new ResizeObserver(updateMaxWidth) + if (textAreaRef.current) + resizeObserver.observe(textAreaRef.current) + + window.addEventListener('resize', updateMaxWidth) + + return () => { + resizeObserver.disconnect() + window.removeEventListener('resize', updateMaxWidth) + } + }, []) + + useEffect(() => { + incrementResetKey() + }, [selectedIndex, incrementResetKey]) + + const handleBannerClick = useCallback(() => { + incrementResetKey() + if (banner.link) + window.open(banner.link, '_blank', 'noopener,noreferrer') + }, [banner.link, incrementResetKey]) + + const handleIndicatorClick = useCallback((index: number) => { + incrementResetKey() + api?.scrollTo(index) + }, [api, incrementResetKey]) + + return ( +
+    [BannerItem markup, tags lost in extraction: a left content area with the text
+     section ({category}, {title}, {description}), an actions section with a
+     "view more" control labelled t('banner.viewMore', { ns: 'explore' }), slide
+     indicators rendered via slideInfo.slides.map((_: unknown, index: number) =>
+     <IndicatorButton ... onClick={() => handleIndicatorClick(index)} />), and a
+     right image area whose alt text is {title}]
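// Aside: a minimal sketch of the requestAnimationFrame loop that drives the slide
// indicators above: progress runs from 0 to 100 over the autoplay delay and
// restarts whenever the slide changes or the loop is un-paused. The hook name and
// signature are illustrative, not taken from this patch.
import { useEffect, useState } from 'react'

function useAutoplayProgress(durationMs: number, running: boolean, resetKey: number) {
  const [progress, setProgress] = useState(0) // 0..100, drives the indicator ring

  useEffect(() => {
    setProgress(0)
    if (!running)
      return
    const start = Date.now()
    let frameId = 0
    const tick = () => {
      const elapsed = Date.now() - start
      const next = Math.min((elapsed / durationMs) * 100, 100)
      setProgress(next)
      if (next < 100)
        frameId = requestAnimationFrame(tick)
    }
    frameId = requestAnimationFrame(tick)
    return () => cancelAnimationFrame(frameId)
  }, [durationMs, running, resetKey])

  return progress
}

export default useAutoplayProgress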
+ ) +} diff --git a/web/app/components/explore/banner/banner.tsx b/web/app/components/explore/banner/banner.tsx new file mode 100644 index 0000000000..4ec0efdb2b --- /dev/null +++ b/web/app/components/explore/banner/banner.tsx @@ -0,0 +1,94 @@ +import type { FC } from 'react' +import * as React from 'react' +import { useEffect, useMemo, useRef, useState } from 'react' +import { Carousel } from '@/app/components/base/carousel' +import { useLocale } from '@/context/i18n' +import { useGetBanners } from '@/service/use-explore' +import Loading from '../../base/loading' +import { BannerItem } from './banner-item' + +const AUTOPLAY_DELAY = 5000 +const MIN_LOADING_HEIGHT = 168 +const RESIZE_DEBOUNCE_DELAY = 50 + +const LoadingState: FC = () => ( +
+ +
+) + +const Banner: FC = () => { + const locale = useLocale() + const { data: banners, isLoading, isError } = useGetBanners(locale) + const [isHovered, setIsHovered] = useState(false) + const [isResizing, setIsResizing] = useState(false) + const resizeTimerRef = useRef(null) + + const enabledBanners = useMemo( + () => banners?.filter(banner => banner.status === 'enabled') ?? [], + [banners], + ) + + const isPaused = isHovered || isResizing + + // Handle window resize to pause animation + useEffect(() => { + const handleResize = () => { + setIsResizing(true) + + if (resizeTimerRef.current) + clearTimeout(resizeTimerRef.current) + + resizeTimerRef.current = setTimeout(() => { + setIsResizing(false) + }, RESIZE_DEBOUNCE_DELAY) + } + + window.addEventListener('resize', handleResize) + + return () => { + window.removeEventListener('resize', handleResize) + if (resizeTimerRef.current) + clearTimeout(resizeTimerRef.current) + } + }, []) + + if (isLoading) + return + + if (isError || enabledBanners.length === 0) + return null + + return ( + setIsHovered(true)} + onMouseLeave={() => setIsHovered(false)} + > + + {enabledBanners.map(banner => ( + + + + ))} + + + ) +} + +export default React.memo(Banner) diff --git a/web/app/components/explore/banner/indicator-button.tsx b/web/app/components/explore/banner/indicator-button.tsx new file mode 100644 index 0000000000..332dae53ba --- /dev/null +++ b/web/app/components/explore/banner/indicator-button.tsx @@ -0,0 +1,112 @@ +/* eslint-disable react-hooks-extra/no-direct-set-state-in-use-effect */ +import type { FC } from 'react' +import { useCallback, useEffect, useRef, useState } from 'react' +import { cn } from '@/utils/classnames' + +type IndicatorButtonProps = { + index: number + selectedIndex: number + isNextSlide: boolean + autoplayDelay: number + resetKey: number + isPaused?: boolean + onClick: () => void +} + +const PROGRESS_MAX = 100 +const DEGREES_PER_PERCENT = 3.6 + +export const IndicatorButton: FC = ({ + index, + selectedIndex, + isNextSlide, + autoplayDelay, + resetKey, + isPaused = false, + onClick, +}) => { + const [progress, setProgress] = useState(0) + const frameIdRef = useRef(undefined) + const startTimeRef = useRef(0) + + const isActive = index === selectedIndex + const shouldAnimate = !document.hidden && !isPaused + + useEffect(() => { + if (!isNextSlide) { + setProgress(0) + if (frameIdRef.current) + cancelAnimationFrame(frameIdRef.current) + return + } + + setProgress(0) + startTimeRef.current = Date.now() + + const animate = () => { + if (!document.hidden && !isPaused) { + const elapsed = Date.now() - startTimeRef.current + const newProgress = Math.min((elapsed / autoplayDelay) * PROGRESS_MAX, PROGRESS_MAX) + setProgress(newProgress) + + if (newProgress < PROGRESS_MAX) + frameIdRef.current = requestAnimationFrame(animate) + } + else { + frameIdRef.current = requestAnimationFrame(animate) + } + } + + if (shouldAnimate) + frameIdRef.current = requestAnimationFrame(animate) + + return () => { + if (frameIdRef.current) + cancelAnimationFrame(frameIdRef.current) + } + }, [isNextSlide, autoplayDelay, resetKey, isPaused]) + + const handleClick = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + onClick() + }, [onClick]) + + const progressDegrees = progress * DEGREES_PER_PERCENT + + return ( + + ) +} diff --git a/web/app/components/explore/category.tsx b/web/app/components/explore/category.tsx index 97a9ca92b3..47c2a4e3a7 100644 --- a/web/app/components/explore/category.tsx +++ b/web/app/components/explore/category.tsx @@ -29,7 
+29,7 @@ const Category: FC = ({ const isAllCategories = !list.includes(value as AppCategory) || value === allCategoriesEn const itemClassName = (isSelected: boolean) => cn( - 'flex h-[32px] cursor-pointer items-center rounded-lg border-[0.5px] border-transparent px-3 py-[7px] font-medium leading-[18px] text-text-tertiary hover:bg-components-main-nav-nav-button-bg-active', + 'system-sm-medium flex h-7 cursor-pointer items-center rounded-lg border border-transparent px-3 text-text-tertiary hover:bg-components-main-nav-nav-button-bg-active', isSelected && 'border-components-main-nav-nav-button-border bg-components-main-nav-nav-button-bg-active text-components-main-nav-nav-button-text-active shadow-xs', ) diff --git a/web/app/components/explore/index.tsx b/web/app/components/explore/index.tsx index 30132eea66..0b5e18a1de 100644 --- a/web/app/components/explore/index.tsx +++ b/web/app/components/explore/index.tsx @@ -1,5 +1,6 @@ 'use client' import type { FC } from 'react' +import type { CurrentTryAppParams } from '@/context/explore-context' import type { InstalledApp } from '@/models/explore' import { useRouter } from 'next/navigation' import * as React from 'react' @@ -41,6 +42,16 @@ const Explore: FC = ({ return router.replace('/datasets') }, [isCurrentWorkspaceDatasetOperator]) + const [currentTryAppParams, setCurrentTryAppParams] = useState(undefined) + const [isShowTryAppPanel, setIsShowTryAppPanel] = useState(false) + const setShowTryAppPanel = (showTryAppPanel: boolean, params?: CurrentTryAppParams) => { + if (showTryAppPanel) + setCurrentTryAppParams(params) + else + setCurrentTryAppParams(undefined) + setIsShowTryAppPanel(showTryAppPanel) + } + return (
= ({ setInstalledApps, isFetchingInstalledApps, setIsFetchingInstalledApps, + currentApp: currentTryAppParams, + isShowTryAppPanel, + setShowTryAppPanel, } } > diff --git a/web/app/components/explore/installed-app/index.tsx b/web/app/components/explore/installed-app/index.tsx index def66c0260..7366057445 100644 --- a/web/app/components/explore/installed-app/index.tsx +++ b/web/app/components/explore/installed-app/index.tsx @@ -1,5 +1,6 @@ 'use client' import type { FC } from 'react' +import type { AccessMode } from '@/models/access-control' import type { AppData } from '@/models/share' import * as React from 'react' import { useEffect } from 'react' @@ -62,8 +63,8 @@ const InstalledApp: FC = ({ if (appMeta) updateWebAppMeta(appMeta) if (webAppAccessMode) - updateWebAppAccessMode(webAppAccessMode.accessMode) - updateUserCanAccessApp(Boolean(userCanAccessApp && userCanAccessApp?.result)) + updateWebAppAccessMode((webAppAccessMode as { accessMode: AccessMode }).accessMode) + updateUserCanAccessApp(Boolean(userCanAccessApp && (userCanAccessApp as { result: boolean })?.result)) }, [installedApp, appMeta, appParams, updateAppInfo, updateAppParams, updateUserCanAccessApp, updateWebAppMeta, userCanAccessApp, webAppAccessMode, updateWebAppAccessMode]) if (appParamsError) { diff --git a/web/app/components/explore/sidebar/app-nav-item/index.tsx b/web/app/components/explore/sidebar/app-nav-item/index.tsx index 3347efeb3f..08558578f6 100644 --- a/web/app/components/explore/sidebar/app-nav-item/index.tsx +++ b/web/app/components/explore/sidebar/app-nav-item/index.tsx @@ -56,7 +56,7 @@ export default function AppNavItem({ <>
-
{name}
+
{name}
e.stopPropagation()}> { setInstalledApps: vi.fn(), isFetchingInstalledApps: false, setIsFetchingInstalledApps: vi.fn(), - }} + } as unknown as IExplore} > , @@ -97,8 +98,8 @@ describe('SideBar', () => { renderWithContext(mockInstalledApps) // Assert - expect(screen.getByText('explore.sidebar.discovery')).toBeInTheDocument() - expect(screen.getByText('explore.sidebar.workspace')).toBeInTheDocument() + expect(screen.getByText('explore.sidebar.title')).toBeInTheDocument() + expect(screen.getByText('explore.sidebar.webApps')).toBeInTheDocument() expect(screen.getByText('My App')).toBeInTheDocument() }) }) diff --git a/web/app/components/explore/sidebar/index.tsx b/web/app/components/explore/sidebar/index.tsx index 1257886165..3e9b664580 100644 --- a/web/app/components/explore/sidebar/index.tsx +++ b/web/app/components/explore/sidebar/index.tsx @@ -1,5 +1,7 @@ 'use client' import type { FC } from 'react' +import { RiAppsFill, RiExpandRightLine, RiLayoutLeft2Line } from '@remixicon/react' +import { useBoolean } from 'ahooks' import Link from 'next/link' import { useSelectedLayoutSegments } from 'next/navigation' import * as React from 'react' @@ -14,18 +16,7 @@ import { useGetInstalledApps, useUninstallApp, useUpdateAppPinStatus } from '@/s import { cn } from '@/utils/classnames' import Toast from '../../base/toast' import Item from './app-nav-item' - -const SelectedDiscoveryIcon = () => ( - - - -) - -const DiscoveryIcon = () => ( - - - -) +import NoApps from './no-apps' export type IExploreSideBarProps = { controlUpdateInstalledApps: number @@ -45,6 +36,9 @@ const SideBar: FC = ({ const media = useBreakpoints() const isMobile = media === MediaType.mobile + const [isFold, { + toggle: toggleIsFold, + }] = useBoolean(false) const [showConfirm, setShowConfirm] = useState(false) const [currId, setCurrId] = useState('') @@ -84,22 +78,31 @@ const SideBar: FC = ({ const pinnedAppsCount = installedApps.filter(({ is_pinned }) => is_pinned).length return ( -
+
- {isDiscoverySelected ? : } - {!isMobile &&
{t('sidebar.discovery', { ns: 'explore' })}
} +
+ +
+ {!isMobile && !isFold &&
{t('sidebar.title', { ns: 'explore' })}
}
+ + {installedApps.length === 0 && !isMobile && !isFold + && ( +
+ +
+ )} + {installedApps.length > 0 && ( -
-

{t('sidebar.workspace', { ns: 'explore' })}

+
+ {!isMobile && !isFold &&

{t('sidebar.webApps', { ns: 'explore' })}

}
= ({ {installedApps.map(({ id, is_pinned, uninstallable, app: { name, icon_type, icon, icon_url, icon_background } }, index) => ( = ({
)} + + {!isMobile && ( +
+ {isFold + ? + : ( + + )} +
+ )} + {showConfirm && ( { + const { t } = useTranslation() + const { theme } = useTheme() + return ( +
+
+
{t(`${i18nPrefix}.title`, { ns: 'explore' })}
+
{t(`${i18nPrefix}.description`, { ns: 'explore' })}
+ {t(`${i18nPrefix}.learnMore`, { ns: 'explore' })} +
+ ) +} +export default React.memo(NoApps) diff --git a/web/app/components/explore/sidebar/no-apps/no-web-apps-dark.png b/web/app/components/explore/sidebar/no-apps/no-web-apps-dark.png new file mode 100644 index 0000000000..e153686fcd Binary files /dev/null and b/web/app/components/explore/sidebar/no-apps/no-web-apps-dark.png differ diff --git a/web/app/components/explore/sidebar/no-apps/no-web-apps-light.png b/web/app/components/explore/sidebar/no-apps/no-web-apps-light.png new file mode 100644 index 0000000000..2416b957d2 Binary files /dev/null and b/web/app/components/explore/sidebar/no-apps/no-web-apps-light.png differ diff --git a/web/app/components/explore/sidebar/no-apps/style.module.css b/web/app/components/explore/sidebar/no-apps/style.module.css new file mode 100644 index 0000000000..ad3787ce2b --- /dev/null +++ b/web/app/components/explore/sidebar/no-apps/style.module.css @@ -0,0 +1,7 @@ +.light { + background-image: url('./no-web-apps-light.png'); +} + +.dark { + background-image: url('./no-web-apps-dark.png'); +} diff --git a/web/app/components/explore/try-app/app-info/index.tsx b/web/app/components/explore/try-app/app-info/index.tsx new file mode 100644 index 0000000000..eab265bd04 --- /dev/null +++ b/web/app/components/explore/try-app/app-info/index.tsx @@ -0,0 +1,95 @@ +'use client' +import type { FC } from 'react' +import type { TryAppInfo } from '@/service/try-app' +import { RiAddLine } from '@remixicon/react' +import * as React from 'react' +import { useTranslation } from 'react-i18next' +import { AppTypeIcon } from '@/app/components/app/type-selector' +import AppIcon from '@/app/components/base/app-icon' +import Button from '@/app/components/base/button' +import { cn } from '@/utils/classnames' +import useGetRequirements from './use-get-requirements' + +type Props = { + appId: string + appDetail: TryAppInfo + category?: string + className?: string + onCreate: () => void +} + +const headerClassName = 'system-sm-semibold-uppercase text-text-secondary mb-3' + +const AppInfo: FC = ({ + appId, + className, + category, + appDetail, + onCreate, +}) => { + const { t } = useTranslation() + const mode = appDetail?.mode + const { requirements } = useGetRequirements({ appDetail, appId }) + return ( +
+    [AppInfo markup, tags lost in extraction: the app icon and {appDetail.name}, a
+     mode badge taken from t('types.advanced' / 'chatbot' / 'agent' / 'workflow' /
+     'completion', { ns: 'app' }).toUpperCase() according to mode, the optional
+     {appDetail.description}, a category block labelled t('tryApp.category',
+     { ns: 'explore' }) showing {category}, a requirements block labelled
+     t('tryApp.requirements', { ns: 'explore' }) listing requirements.map(item =>
+     ... {item.name} ...), and an action area presumably built from the imported
+     Button / RiAddLine wired to onCreate]
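// Aside: the requirements list rendered above is produced by the
// use-get-requirements hook in the next hunk, which splits plugin identifiers
// such as 'langgenius/openai' to build a marketplace icon URL and then
// de-duplicates entries by name with es-toolkit's uniqBy. A compact sketch of
// that shape; ICON_PREFIX and the sample names are illustrative.
import { uniqBy } from 'es-toolkit/compat'

type RequirementItem = {
  name: string
  iconUrl: string
}

const ICON_PREFIX = 'https://marketplace.example.com/api/v1/plugins'

const toRequirement = (providerId: string, name: string): RequirementItem => {
  const [org, plugin] = providerId.split('/')
  return { name, iconUrl: `${ICON_PREFIX}/${org}/${plugin}/icon` }
}

const requirements: RequirementItem[] = uniqBy(
  [
    toRequirement('langgenius/openai', 'gpt-4o'),
    toRequirement('langgenius/tavily', 'tavily_search'),
    toRequirement('langgenius/openai', 'gpt-4o'), // duplicate name, dropped by uniqBy
  ],
  'name',
)

console.log(requirements.map(r => r.name)) // ['gpt-4o', 'tavily_search']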
+ ) +} +export default React.memo(AppInfo) diff --git a/web/app/components/explore/try-app/app-info/use-get-requirements.ts b/web/app/components/explore/try-app/app-info/use-get-requirements.ts new file mode 100644 index 0000000000..976989be73 --- /dev/null +++ b/web/app/components/explore/try-app/app-info/use-get-requirements.ts @@ -0,0 +1,78 @@ +import type { LLMNodeType } from '@/app/components/workflow/nodes/llm/types' +import type { ToolNodeType } from '@/app/components/workflow/nodes/tool/types' +import type { TryAppInfo } from '@/service/try-app' +import type { AgentTool } from '@/types/app' +import { uniqBy } from 'es-toolkit/compat' +import { BlockEnum } from '@/app/components/workflow/types' +import { MARKETPLACE_API_PREFIX } from '@/config' +import { useGetTryAppFlowPreview } from '@/service/use-try-app' + +type Params = { + appDetail: TryAppInfo + appId: string +} + +type RequirementItem = { + name: string + iconUrl: string +} +const getIconUrl = (provider: string, tool: string) => { + return `${MARKETPLACE_API_PREFIX}/plugins/${provider}/${tool}/icon` +} + +const useGetRequirements = ({ appDetail, appId }: Params) => { + const isBasic = ['chat', 'completion', 'agent-chat'].includes(appDetail.mode) + const isAgent = appDetail.mode === 'agent-chat' + const isAdvanced = !isBasic + const { data: flowData } = useGetTryAppFlowPreview(appId, isBasic) + + const requirements: RequirementItem[] = [] + if (isBasic) { + const modelProviderAndName = appDetail.model_config.model.provider.split('/') + const name = appDetail.model_config.model.provider.split('/').pop() || '' + requirements.push({ + name, + iconUrl: getIconUrl(modelProviderAndName[0], modelProviderAndName[1]), + }) + } + if (isAgent) { + requirements.push(...appDetail.model_config.agent_mode.tools.filter(data => (data as AgentTool).enabled).map((data) => { + const tool = data as AgentTool + const modelProviderAndName = tool.provider_id.split('/') + return { + name: tool.tool_label, + iconUrl: getIconUrl(modelProviderAndName[0], modelProviderAndName[1]), + } + })) + } + if (isAdvanced && flowData && flowData?.graph?.nodes?.length > 0) { + const nodes = flowData.graph.nodes + const llmNodes = nodes.filter(node => node.data.type === BlockEnum.LLM) + requirements.push(...llmNodes.map((node) => { + const data = node.data as LLMNodeType + const modelProviderAndName = data.model.provider.split('/') + return { + name: data.model.name, + iconUrl: getIconUrl(modelProviderAndName[0], modelProviderAndName[1]), + } + })) + + const toolNodes = nodes.filter(node => node.data.type === BlockEnum.Tool) + requirements.push(...toolNodes.map((node) => { + const data = node.data as ToolNodeType + const toolProviderAndName = data.provider_id.split('/') + return { + name: data.tool_label, + iconUrl: getIconUrl(toolProviderAndName[0], toolProviderAndName[1]), + } + })) + } + + const uniqueRequirements = uniqBy(requirements, 'name') + + return { + requirements: uniqueRequirements, + } +} + +export default useGetRequirements diff --git a/web/app/components/explore/try-app/app/chat.tsx b/web/app/components/explore/try-app/app/chat.tsx new file mode 100644 index 0000000000..b6b4a76ad5 --- /dev/null +++ b/web/app/components/explore/try-app/app/chat.tsx @@ -0,0 +1,104 @@ +'use client' +import type { FC } from 'react' +import type { + EmbeddedChatbotContextValue, +} from '@/app/components/base/chat/embedded-chatbot/context' +import type { TryAppInfo } from '@/service/try-app' +import { RiResetLeftLine } from '@remixicon/react' +import { useBoolean } from 
'ahooks' +import * as React from 'react' +import { useEffect } from 'react' +import { useTranslation } from 'react-i18next' +import ActionButton from '@/app/components/base/action-button' +import Alert from '@/app/components/base/alert' +import AppIcon from '@/app/components/base/app-icon' +import ChatWrapper from '@/app/components/base/chat/embedded-chatbot/chat-wrapper' +import { + EmbeddedChatbotContext, +} from '@/app/components/base/chat/embedded-chatbot/context' +import { + useEmbeddedChatbot, +} from '@/app/components/base/chat/embedded-chatbot/hooks' +import ViewFormDropdown from '@/app/components/base/chat/embedded-chatbot/inputs-form/view-form-dropdown' +import Tooltip from '@/app/components/base/tooltip' +import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' +import { AppSourceType } from '@/service/share' +import { cn } from '@/utils/classnames' +import { useThemeContext } from '../../../base/chat/embedded-chatbot/theme/theme-context' + +type Props = { + appId: string + appDetail: TryAppInfo + className: string +} + +const TryApp: FC = ({ + appId, + appDetail, + className, +}) => { + const { t } = useTranslation() + const media = useBreakpoints() + const isMobile = media === MediaType.mobile + const themeBuilder = useThemeContext() + const { removeConversationIdInfo, ...chatData } = useEmbeddedChatbot(AppSourceType.tryApp, appId) + const currentConversationId = chatData.currentConversationId + const inputsForms = chatData.inputsForms + useEffect(() => { + if (appId) + removeConversationIdInfo(appId) + }, [appId]) + const [isHideTryNotice, { + setTrue: hideTryNotice, + }] = useBoolean(false) + + const handleNewConversation = () => { + removeConversationIdInfo(appId) + chatData.handleNewConversation() + } + return ( + +
+
+
+ +
{appDetail.name}
+
+
+ {currentConversationId && ( + + + + + + )} + {currentConversationId && inputsForms.length > 0 && ( + + )} +
+
+
+ {!isHideTryNotice && ( + + )} + +
+
+
+ ) +} +export default React.memo(TryApp) diff --git a/web/app/components/explore/try-app/app/index.tsx b/web/app/components/explore/try-app/app/index.tsx new file mode 100644 index 0000000000..f5dc14510d --- /dev/null +++ b/web/app/components/explore/try-app/app/index.tsx @@ -0,0 +1,44 @@ +'use client' +import type { FC } from 'react' +import type { AppData } from '@/models/share' +import type { TryAppInfo } from '@/service/try-app' +import * as React from 'react' +import useDocumentTitle from '@/hooks/use-document-title' +import Chat from './chat' +import TextGeneration from './text-generation' + +type Props = { + appId: string + appDetail: TryAppInfo +} + +const TryApp: FC = ({ + appId, + appDetail, +}) => { + const mode = appDetail?.mode + const isChat = ['chat', 'advanced-chat', 'agent-chat'].includes(mode!) + const isCompletion = !isChat + + useDocumentTitle(appDetail?.site?.title || '') + return ( +
+ {isChat && ( + + )} + {isCompletion && ( + + )} +
+ ) +} +export default React.memo(TryApp) diff --git a/web/app/components/explore/try-app/app/text-generation.tsx b/web/app/components/explore/try-app/app/text-generation.tsx new file mode 100644 index 0000000000..3189e621e9 --- /dev/null +++ b/web/app/components/explore/try-app/app/text-generation.tsx @@ -0,0 +1,262 @@ +'use client' +import type { FC } from 'react' +import type { InputValueTypes, Task } from '../../../share/text-generation/types' +import type { MoreLikeThisConfig, PromptConfig, TextToSpeechConfig } from '@/models/debug' +import type { AppData, CustomConfigValueType, SiteInfo } from '@/models/share' +import type { VisionFile, VisionSettings } from '@/types/app' +import { useBoolean } from 'ahooks' +import { noop } from 'es-toolkit/function' +import * as React from 'react' +import { useCallback, useEffect, useRef, useState } from 'react' +import { useTranslation } from 'react-i18next' +import Alert from '@/app/components/base/alert' +import AppIcon from '@/app/components/base/app-icon' +import Loading from '@/app/components/base/loading' +import Res from '@/app/components/share/text-generation/result' +import { TaskStatus } from '@/app/components/share/text-generation/types' +import { appDefaultIconBackground } from '@/config' +import { useWebAppStore } from '@/context/web-app-context' +import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' +import { AppSourceType } from '@/service/share' +import { useGetTryAppParams } from '@/service/use-try-app' +import { Resolution, TransferMethod } from '@/types/app' +import { cn } from '@/utils/classnames' +import { userInputsFormToPromptVariables } from '@/utils/model-config' +import RunOnce from '../../../share/text-generation/run-once' + +type Props = { + appId: string + className?: string + isWorkflow?: boolean + appData: AppData | null +} + +const TextGeneration: FC = ({ + appId, + className, + isWorkflow, + appData, +}) => { + const { t } = useTranslation() + const media = useBreakpoints() + const isPC = media === MediaType.pc + + const [inputs, doSetInputs] = useState>({}) + const inputsRef = useRef>(inputs) + const setInputs = useCallback((newInputs: Record) => { + doSetInputs(newInputs) + inputsRef.current = newInputs + }, []) + + const updateAppInfo = useWebAppStore(s => s.updateAppInfo) + const { data: tryAppParams } = useGetTryAppParams(appId) + + const updateAppParams = useWebAppStore(s => s.updateAppParams) + const appParams = useWebAppStore(s => s.appParams) + const [siteInfo, setSiteInfo] = useState(null) + const [promptConfig, setPromptConfig] = useState(null) + const [customConfig, setCustomConfig] = useState | null>(null) + const [moreLikeThisConfig, setMoreLikeThisConfig] = useState(null) + const [textToSpeechConfig, setTextToSpeechConfig] = useState(null) + const [controlSend, setControlSend] = useState(0) + const [visionConfig, setVisionConfig] = useState({ + enabled: false, + number_limits: 2, + detail: Resolution.low, + transfer_methods: [TransferMethod.local_file], + }) + const [completionFiles, setCompletionFiles] = useState([]) + const [isShowResultPanel, { setTrue: doShowResultPanel, setFalse: hideResultPanel }] = useBoolean(false) + const showResultPanel = () => { + // fix: useClickAway hideResSidebar will close sidebar + setTimeout(() => { + doShowResultPanel() + }, 0) + } + + const handleSend = () => { + setControlSend(Date.now()) + showResultPanel() + } + + const [resultExisted, setResultExisted] = useState(false) + + useEffect(() => { + if (!appData) + return + updateAppInfo(appData) + }, 
[appData, updateAppInfo]) + + useEffect(() => { + if (!tryAppParams) + return + updateAppParams(tryAppParams) + }, [tryAppParams, updateAppParams]) + + useEffect(() => { + (async () => { + if (!appData || !appParams) + return + const { site: siteInfo, custom_config } = appData + setSiteInfo(siteInfo as SiteInfo) + setCustomConfig(custom_config) + + const { user_input_form, more_like_this, file_upload, text_to_speech } = appParams + setVisionConfig({ + // legacy of image upload compatible + ...file_upload, + transfer_methods: file_upload?.allowed_file_upload_methods || file_upload?.allowed_upload_methods, + // legacy of image upload compatible + image_file_size_limit: appParams?.system_parameters.image_file_size_limit, + fileUploadConfig: appParams?.system_parameters, + // eslint-disable-next-line ts/no-explicit-any + } as any) + const prompt_variables = userInputsFormToPromptVariables(user_input_form) + setPromptConfig({ + prompt_template: '', // placeholder for future + prompt_variables, + } as PromptConfig) + setMoreLikeThisConfig(more_like_this) + setTextToSpeechConfig(text_to_speech) + })() + }, [appData, appParams]) + + const [isCompleted, setIsCompleted] = useState(false) + const handleCompleted = useCallback(() => { + setIsCompleted(true) + }, []) + const [isHideTryNotice, { + setTrue: hideTryNotice, + }] = useBoolean(false) + + const renderRes = (task?: Task) => ( + setResultExisted(true)} + /> + ) + + const renderResWrap = ( +
+    [result wrapper markup, tags lost in extraction: presumably the imported Alert
+     with the trial notice, shown when isCompleted && !isHideTryNotice, followed by
+     {renderRes()}]
+  )
+
+  if (!siteInfo || !promptConfig) {
+    return (
+      [loading markup lost in extraction]
+    )
+  }
+
+  return (
+    [page markup, tags lost in extraction: a left panel with a header (AppIcon,
+     {siteInfo.title}, optional {siteInfo.description}) and the form area
+     (presumably the imported RunOnce), plus a result area that is toggled on
+     mobile via isShowResultPanel / showResultPanel / hideResultPanel and renders
+     {renderResWrap}]
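// Aside: TextGeneration above mirrors the form inputs into a ref alongside state,
// so async send handlers always read the latest values while the UI still
// re-renders. A minimal sketch of that state-plus-ref pattern; the hook name is
// illustrative, not taken from this patch.
import { useCallback, useRef, useState } from 'react'

function useStateWithRef<T>(initial: T) {
  const [value, setValue] = useState<T>(initial)
  const valueRef = useRef<T>(initial)

  const set = useCallback((next: T) => {
    valueRef.current = next // keep the ref in sync for async readers
    setValue(next) // trigger a re-render for the UI
  }, [])

  return [value, set, valueRef] as const
}

export default useStateWithRef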
+ ) +} + +export default React.memo(TextGeneration) diff --git a/web/app/components/explore/try-app/index.tsx b/web/app/components/explore/try-app/index.tsx new file mode 100644 index 0000000000..b2e2b72140 --- /dev/null +++ b/web/app/components/explore/try-app/index.tsx @@ -0,0 +1,74 @@ +/* eslint-disable style/multiline-ternary */ +'use client' +import type { FC } from 'react' +import { RiCloseLine } from '@remixicon/react' +import * as React from 'react' +import { useState } from 'react' +import Loading from '@/app/components/base/loading' +import Modal from '@/app/components/base/modal/index' +import { useGetTryAppInfo } from '@/service/use-try-app' +import Button from '../../base/button' +import App from './app' +import AppInfo from './app-info' +import Preview from './preview' +import Tab, { TypeEnum } from './tab' + +type Props = { + appId: string + category?: string + onClose: () => void + onCreate: () => void +} + +const TryApp: FC = ({ + appId, + category, + onClose, + onCreate, +}) => { + const [type, setType] = useState(TypeEnum.TRY) + const { data: appDetail, isLoading } = useGetTryAppInfo(appId) + + return ( + + {isLoading ? ( +
+ +
+ ) : ( +
+
+ + +
+ {/* Main content */} +
+ {type === TypeEnum.TRY ? : } + +
+
+ )} +
+ ) +} +export default React.memo(TryApp) diff --git a/web/app/components/explore/try-app/preview/basic-app-preview.tsx b/web/app/components/explore/try-app/preview/basic-app-preview.tsx new file mode 100644 index 0000000000..6954546b2e --- /dev/null +++ b/web/app/components/explore/try-app/preview/basic-app-preview.tsx @@ -0,0 +1,367 @@ +/* eslint-disable ts/no-explicit-any */ +'use client' +import type { FC } from 'react' +import type { Features as FeaturesData, FileUpload } from '@/app/components/base/features/types' +import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations' +import type { ModelConfig } from '@/models/debug' +import type { ModelConfig as BackendModelConfig, PromptVariable } from '@/types/app' +import { noop } from 'es-toolkit/function' +import { clone } from 'es-toolkit/object' +import * as React from 'react' +import { useMemo, useState } from 'react' +import Config from '@/app/components/app/configuration/config' +import Debug from '@/app/components/app/configuration/debug' +import { FeaturesProvider } from '@/app/components/base/features' +import Loading from '@/app/components/base/loading' +import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants' +import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { SupportUploadFileTypes } from '@/app/components/workflow/types' +import { ANNOTATION_DEFAULT, DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/config' +import ConfigContext from '@/context/debug-configuration' +import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' +import { PromptMode } from '@/models/debug' +import { useAllToolProviders } from '@/service/use-tools' +import { useGetTryAppDataSets, useGetTryAppInfo } from '@/service/use-try-app' +import { ModelModeType, Resolution, TransferMethod } from '@/types/app' +import { correctModelProvider, correctToolProvider } from '@/utils' +import { userInputsFormToPromptVariables } from '@/utils/model-config' +import { basePath } from '@/utils/var' +import { useTextGenerationCurrentProviderAndModelAndModelList } from '../../../header/account-setting/model-provider-page/hooks' + +type Props = { + appId: string +} + +const defaultModelConfig = { + provider: 'langgenius/openai/openai', + model_id: 'gpt-3.5-turbo', + mode: ModelModeType.unset, + configs: { + prompt_template: '', + prompt_variables: [] as PromptVariable[], + }, + more_like_this: null, + opening_statement: '', + suggested_questions: [], + sensitive_word_avoidance: null, + speech_to_text: null, + text_to_speech: null, + file_upload: null, + suggested_questions_after_answer: null, + retriever_resource: null, + annotation_reply: null, + dataSets: [], + agentConfig: DEFAULT_AGENT_SETTING, +} +const BasicAppPreview: FC = ({ + appId, +}) => { + const media = useBreakpoints() + const isMobile = media === MediaType.mobile + + const { data: appDetail, isLoading: isLoadingAppDetail } = useGetTryAppInfo(appId) + const { data: collectionListFromServer, isLoading: isLoadingToolProviders } = useAllToolProviders() + const collectionList = collectionListFromServer?.map((item) => { + return { + ...item, + icon: basePath && typeof item.icon == 'string' && !item.icon.includes(basePath) ? 
`${basePath}${item.icon}` : item.icon, + } + }) + const datasetIds = (() => { + if (isLoadingAppDetail) + return [] + const modelConfig = appDetail?.model_config + if (!modelConfig) + return [] + let datasets: any = null + + if (modelConfig.agent_mode?.tools?.find(({ dataset }: any) => dataset?.enabled)) + datasets = modelConfig.agent_mode?.tools.filter(({ dataset }: any) => dataset?.enabled) + // new dataset struct + else if (modelConfig.dataset_configs.datasets?.datasets?.length > 0) + datasets = modelConfig.dataset_configs?.datasets?.datasets + + if (datasets?.length && datasets?.length > 0) + return datasets.map(({ dataset }: any) => dataset.id) + + return [] + })() + const { data: dataSetData, isLoading: isLoadingDatasets } = useGetTryAppDataSets(appId, datasetIds) + const dataSets = dataSetData?.data || [] + const isLoading = isLoadingAppDetail || isLoadingDatasets || isLoadingToolProviders + + const modelConfig: ModelConfig = ((modelConfig?: BackendModelConfig) => { + if (isLoading || !modelConfig) + return defaultModelConfig + + const model = modelConfig.model + + const newModelConfig = { + provider: correctModelProvider(model.provider), + model_id: model.name, + mode: model.mode, + configs: { + prompt_template: modelConfig.pre_prompt || '', + prompt_variables: userInputsFormToPromptVariables( + [ + ...(modelConfig.user_input_form as any), + ...( + modelConfig.external_data_tools?.length + ? modelConfig.external_data_tools.map((item) => { + return { + external_data_tool: { + variable: item.variable as string, + label: item.label as string, + enabled: item.enabled, + type: item.type as string, + config: item.config, + required: true, + icon: item.icon, + icon_background: item.icon_background, + }, + } + }) + : [] + ), + ], + modelConfig.dataset_query_variable, + ), + }, + more_like_this: modelConfig.more_like_this, + opening_statement: modelConfig.opening_statement, + suggested_questions: modelConfig.suggested_questions, + sensitive_word_avoidance: modelConfig.sensitive_word_avoidance, + speech_to_text: modelConfig.speech_to_text, + text_to_speech: modelConfig.text_to_speech, + file_upload: modelConfig.file_upload, + suggested_questions_after_answer: modelConfig.suggested_questions_after_answer, + retriever_resource: modelConfig.retriever_resource, + annotation_reply: modelConfig.annotation_reply, + external_data_tools: modelConfig.external_data_tools, + dataSets, + agentConfig: appDetail?.mode === 'agent-chat' + // eslint-disable-next-line style/multiline-ternary + ? ({ + max_iteration: DEFAULT_AGENT_SETTING.max_iteration, + ...modelConfig.agent_mode, + // remove dataset + enabled: true, // modelConfig.agent_mode?.enabled is not correct. old app: the value of app with dataset's is always true + tools: modelConfig.agent_mode?.tools.filter((tool: any) => { + return !tool.dataset + }).map((tool: any) => { + const toolInCollectionList = collectionList?.find(c => tool.provider_id === c.id) + return { + ...tool, + isDeleted: appDetail?.deleted_tools?.some((deletedTool: any) => deletedTool.id === tool.id && deletedTool.tool_name === tool.tool_name), + notAuthor: toolInCollectionList?.is_team_authorization === false, + ...(tool.provider_type === 'builtin' + ? 
{ + provider_id: correctToolProvider(tool.provider_name, !!toolInCollectionList), + provider_name: correctToolProvider(tool.provider_name, !!toolInCollectionList), + } + : {}), + } + }), + }) : DEFAULT_AGENT_SETTING, + } + return (newModelConfig as any) + })(appDetail?.model_config) + const mode = appDetail?.mode + // const isChatApp = ['chat', 'advanced-chat', 'agent-chat'].includes(mode!) + + // chat configuration + const promptMode = modelConfig?.prompt_type === PromptMode.advanced ? PromptMode.advanced : PromptMode.simple + const isAdvancedMode = promptMode === PromptMode.advanced + const isAgent = mode === 'agent-chat' + const chatPromptConfig = isAdvancedMode ? (modelConfig?.chat_prompt_config || clone(DEFAULT_CHAT_PROMPT_CONFIG)) : undefined + const suggestedQuestions = modelConfig?.suggested_questions || [] + const moreLikeThisConfig = modelConfig?.more_like_this || { enabled: false } + const suggestedQuestionsAfterAnswerConfig = modelConfig?.suggested_questions_after_answer || { enabled: false } + const speechToTextConfig = modelConfig?.speech_to_text || { enabled: false } + const textToSpeechConfig = modelConfig?.text_to_speech || { enabled: false, voice: '', language: '' } + const citationConfig = modelConfig?.retriever_resource || { enabled: false } + const annotationConfig = modelConfig?.annotation_reply || { + id: '', + enabled: false, + score_threshold: ANNOTATION_DEFAULT.score_threshold, + embedding_model: { + embedding_provider_name: '', + embedding_model_name: '', + }, + } + const moderationConfig = modelConfig?.sensitive_word_avoidance || { enabled: false } + // completion configuration + const completionPromptConfig = modelConfig?.completion_prompt_config || clone(DEFAULT_COMPLETION_PROMPT_CONFIG) as any + + // prompt & model config + const inputs = {} + const query = '' + const completionParams = useState({}) + + const { + currentModel: currModel, + } = useTextGenerationCurrentProviderAndModelAndModelList( + { + provider: modelConfig.provider, + model: modelConfig.model_id, + }, + ) + + const isShowVisionConfig = !!currModel?.features?.includes(ModelFeatureEnum.vision) + const isShowDocumentConfig = !!currModel?.features?.includes(ModelFeatureEnum.document) + const isShowAudioConfig = !!currModel?.features?.includes(ModelFeatureEnum.audio) + const isAllowVideoUpload = !!currModel?.features?.includes(ModelFeatureEnum.video) + const visionConfig = { + enabled: false, + number_limits: 2, + detail: Resolution.low, + transfer_methods: [TransferMethod.local_file], + } + + const featuresData: FeaturesData = useMemo(() => { + return { + moreLikeThis: modelConfig.more_like_this || { enabled: false }, + opening: { + enabled: !!modelConfig.opening_statement, + opening_statement: modelConfig.opening_statement || '', + suggested_questions: modelConfig.suggested_questions || [], + }, + moderation: modelConfig.sensitive_word_avoidance || { enabled: false }, + speech2text: modelConfig.speech_to_text || { enabled: false }, + text2speech: modelConfig.text_to_speech || { enabled: false }, + file: { + image: { + detail: modelConfig.file_upload?.image?.detail || Resolution.high, + enabled: !!modelConfig.file_upload?.image?.enabled, + number_limits: modelConfig.file_upload?.image?.number_limits || 3, + transfer_methods: modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], + }, + enabled: !!(modelConfig.file_upload?.enabled || modelConfig.file_upload?.image?.enabled), + allowed_file_types: modelConfig.file_upload?.allowed_file_types || [], + 
allowed_file_extensions: modelConfig.file_upload?.allowed_file_extensions || [...FILE_EXTS[SupportUploadFileTypes.image], ...FILE_EXTS[SupportUploadFileTypes.video]].map(ext => `.${ext}`), + allowed_file_upload_methods: modelConfig.file_upload?.allowed_file_upload_methods || modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], + number_limits: modelConfig.file_upload?.number_limits || modelConfig.file_upload?.image?.number_limits || 3, + fileUploadConfig: {}, + } as FileUpload, + suggested: modelConfig.suggested_questions_after_answer || { enabled: false }, + citation: modelConfig.retriever_resource || { enabled: false }, + annotationReply: modelConfig.annotation_reply || { enabled: false }, + } + }, [modelConfig]) + + if (isLoading) { + return ( +
+ +
+ ) + } + const value = { + readonly: true, + appId, + isAPIKeySet: true, + isTrailFinished: false, + mode, + modelModeType: '', + promptMode, + isAdvancedMode, + isAgent, + isOpenAI: false, + isFunctionCall: false, + collectionList: [], + setPromptMode: noop, + canReturnToSimpleMode: false, + setCanReturnToSimpleMode: noop, + chatPromptConfig, + completionPromptConfig, + currentAdvancedPrompt: '', + setCurrentAdvancedPrompt: noop, + conversationHistoriesRole: completionPromptConfig.conversation_histories_role, + showHistoryModal: false, + setConversationHistoriesRole: noop, + hasSetBlockStatus: true, + conversationId: '', + introduction: '', + setIntroduction: noop, + suggestedQuestions, + setSuggestedQuestions: noop, + setConversationId: noop, + controlClearChatMessage: false, + setControlClearChatMessage: noop, + prevPromptConfig: {}, + setPrevPromptConfig: noop, + moreLikeThisConfig, + setMoreLikeThisConfig: noop, + suggestedQuestionsAfterAnswerConfig, + setSuggestedQuestionsAfterAnswerConfig: noop, + speechToTextConfig, + setSpeechToTextConfig: noop, + textToSpeechConfig, + setTextToSpeechConfig: noop, + citationConfig, + setCitationConfig: noop, + annotationConfig, + setAnnotationConfig: noop, + moderationConfig, + setModerationConfig: noop, + externalDataToolsConfig: {}, + setExternalDataToolsConfig: noop, + formattingChanged: false, + setFormattingChanged: noop, + inputs, + setInputs: noop, + query, + setQuery: noop, + completionParams, + setCompletionParams: noop, + modelConfig, + setModelConfig: noop, + showSelectDataSet: noop, + dataSets, + setDataSets: noop, + datasetConfigs: [], + datasetConfigsRef: {}, + setDatasetConfigs: noop, + hasSetContextVar: true, + isShowVisionConfig, + visionConfig, + setVisionConfig: noop, + isAllowVideoUpload, + isShowDocumentConfig, + isShowAudioConfig, + rerankSettingModalOpen: false, + setRerankSettingModalOpen: noop, + } + return ( + + +
+
+
+ +
+ {!isMobile && ( +
+
+ +
+
+ )} +
+
+
+
+ ) +} +export default React.memo(BasicAppPreview) diff --git a/web/app/components/explore/try-app/preview/flow-app-preview.tsx b/web/app/components/explore/try-app/preview/flow-app-preview.tsx new file mode 100644 index 0000000000..ba64aecfba --- /dev/null +++ b/web/app/components/explore/try-app/preview/flow-app-preview.tsx @@ -0,0 +1,39 @@ +'use client' +import type { FC } from 'react' +import * as React from 'react' +import Loading from '@/app/components/base/loading' +import WorkflowPreview from '@/app/components/workflow/workflow-preview' +import { useGetTryAppFlowPreview } from '@/service/use-try-app' +import { cn } from '@/utils/classnames' + +type Props = { + appId: string + className?: string +} + +const FlowAppPreview: FC = ({ + appId, + className, +}) => { + const { data, isLoading } = useGetTryAppFlowPreview(appId) + + if (isLoading) { + return ( +
+ +
+ ) + } + if (!data) + return null + return ( +
+ +
+ ) +} +export default React.memo(FlowAppPreview) diff --git a/web/app/components/explore/try-app/preview/index.tsx b/web/app/components/explore/try-app/preview/index.tsx new file mode 100644 index 0000000000..a0c5fdc594 --- /dev/null +++ b/web/app/components/explore/try-app/preview/index.tsx @@ -0,0 +1,25 @@ +'use client' +import type { FC } from 'react' +import type { TryAppInfo } from '@/service/try-app' +import * as React from 'react' +import BasicAppPreview from './basic-app-preview' +import FlowAppPreview from './flow-app-preview' + +type Props = { + appId: string + appDetail: TryAppInfo +} + +const Preview: FC = ({ + appId, + appDetail, +}) => { + const isBasicApp = ['agent-chat', 'chat', 'completion'].includes(appDetail.mode) + + return ( +
+ {isBasicApp ? : } +
+ ) +} +export default React.memo(Preview) diff --git a/web/app/components/explore/try-app/tab.tsx b/web/app/components/explore/try-app/tab.tsx new file mode 100644 index 0000000000..75ba402204 --- /dev/null +++ b/web/app/components/explore/try-app/tab.tsx @@ -0,0 +1,37 @@ +'use client' +import type { FC } from 'react' +import * as React from 'react' +import { useTranslation } from 'react-i18next' +import TabHeader from '../../base/tab-header' + +export enum TypeEnum { + TRY = 'try', + DETAIL = 'detail', +} + +type Props = { + value: TypeEnum + onChange: (value: TypeEnum) => void +} + +const Tab: FC = ({ + value, + onChange, +}) => { + const { t } = useTranslation() + const tabs = [ + { id: TypeEnum.TRY, name: t('tryApp.tabHeader.try', { ns: 'explore' }) }, + { id: TypeEnum.DETAIL, name: t('tryApp.tabHeader.detail', { ns: 'explore' }) }, + ] + return ( + void} + itemClassName="ml-0 system-md-semibold-uppercase" + itemWrapClassName="pt-2" + activeItemClassName="border-util-colors-blue-brand-blue-brand-500" + /> + ) +} +export default React.memo(Tab) diff --git a/web/app/components/plugins/plugin-page/use-uploader.ts b/web/app/components/plugins/plugin-page/use-uploader.ts index 8c8b4a68c2..7df1cb95e3 100644 --- a/web/app/components/plugins/plugin-page/use-uploader.ts +++ b/web/app/components/plugins/plugin-page/use-uploader.ts @@ -36,7 +36,7 @@ export const useUploader = ({ onFileChange, containerRef, enabled = true }: Uplo setDragging(false) if (!e.dataTransfer) return - const files = [...e.dataTransfer.files] + const files = Array.from(e.dataTransfer.files) if (files.length > 0) onFileChange(files[0]) } diff --git a/web/app/components/provider/serwist.tsx b/web/app/components/provider/serwist.tsx index 39a80f5ac2..2eef43a7d6 100644 --- a/web/app/components/provider/serwist.tsx +++ b/web/app/components/provider/serwist.tsx @@ -1,3 +1,42 @@ 'use client' -export { SerwistProvider } from '@serwist/turbopack/react' +import { SerwistProvider } from '@serwist/turbopack/react' +import { useEffect } from 'react' +import { IS_DEV } from '@/config' +import { isClient } from '@/utils/client' + +export function PWAProvider({ children }: { children: React.ReactNode }) { + if (IS_DEV) { + return {children} + } + + const basePath = process.env.NEXT_PUBLIC_BASE_PATH || '' + const swUrl = `${basePath}/serwist/sw.js` + + return ( + + {children} + + ) +} + +function DisabledPWAProvider({ children }: { children: React.ReactNode }) { + useEffect(() => { + if (isClient && 'serviceWorker' in navigator) { + navigator.serviceWorker.getRegistrations() + .then((registrations) => { + registrations.forEach((registration) => { + registration.unregister() + .catch((error) => { + console.error('Error unregistering service worker:', error) + }) + }) + }) + .catch((error) => { + console.error('Error unregistering service workers:', error) + }) + } + }, []) + + return <>{children} +} diff --git a/web/app/components/share/text-generation/index.tsx b/web/app/components/share/text-generation/index.tsx index 509687e245..90a2fb9277 100644 --- a/web/app/components/share/text-generation/index.tsx +++ b/web/app/components/share/text-generation/index.tsx @@ -34,7 +34,7 @@ import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useDocumentTitle from '@/hooks/use-document-title' import { changeLanguage } from '@/i18n-config/client' import { AccessMode } from '@/models/access-control' -import { fetchSavedMessage as doFetchSavedMessage, removeMessage, saveMessage } from '@/service/share' +import { AppSourceType, 
fetchSavedMessage as doFetchSavedMessage, removeMessage, saveMessage } from '@/service/share' import { Resolution, TransferMethod } from '@/types/app' import { cn } from '@/utils/classnames' import { userInputsFormToPromptVariables } from '@/utils/model-config' @@ -69,10 +69,10 @@ export type IMainProps = { const TextGeneration: FC = ({ isInstalledApp = false, - installedAppInfo, isWorkflow = false, }) => { const { notify } = Toast + const appSourceType = isInstalledApp ? AppSourceType.installedApp : AppSourceType.webApp const { t } = useTranslation() const media = useBreakpoints() @@ -102,16 +102,18 @@ const TextGeneration: FC = ({ // save message const [savedMessages, setSavedMessages] = useState([]) const fetchSavedMessage = useCallback(async () => { - const res: any = await doFetchSavedMessage(isInstalledApp, appId) + if (!appId) + return + const res: any = await doFetchSavedMessage(appSourceType, appId) setSavedMessages(res.data) - }, [isInstalledApp, appId]) + }, [appSourceType, appId]) const handleSaveMessage = async (messageId: string) => { - await saveMessage(messageId, isInstalledApp, appId) + await saveMessage(messageId, appSourceType, appId) notify({ type: 'success', message: t('api.saved', { ns: 'common' }) }) fetchSavedMessage() } const handleRemoveSavedMessage = async (messageId: string) => { - await removeMessage(messageId, isInstalledApp, appId) + await removeMessage(messageId, appSourceType, appId) notify({ type: 'success', message: t('api.remove', { ns: 'common' }) }) fetchSavedMessage() } @@ -423,9 +425,8 @@ const TextGeneration: FC = ({ isCallBatchAPI={isCallBatchAPI} isPC={isPC} isMobile={!isPC} - isInstalledApp={isInstalledApp} + appSourceType={isInstalledApp ? AppSourceType.installedApp : AppSourceType.webApp} appId={appId} - installedAppInfo={installedAppInfo} isError={task?.status === TaskStatus.failed} promptConfig={promptConfig} moreLikeThisEnabled={!!moreLikeThisConfig?.enabled} diff --git a/web/app/components/share/text-generation/result/index.tsx b/web/app/components/share/text-generation/result/index.tsx index a0ffb31b06..fe518c6d25 100644 --- a/web/app/components/share/text-generation/result/index.tsx +++ b/web/app/components/share/text-generation/result/index.tsx @@ -4,8 +4,8 @@ import type { FeedbackType } from '@/app/components/base/chat/chat/type' import type { WorkflowProcess } from '@/app/components/base/chat/types' import type { FileEntity } from '@/app/components/base/file-uploader/types' import type { PromptConfig } from '@/models/debug' -import type { InstalledApp } from '@/models/explore' import type { SiteInfo } from '@/models/share' +import type { AppSourceType } from '@/service/share' import type { VisionFile, VisionSettings } from '@/types/app' import { RiLoader2Line } from '@remixicon/react' import { useBoolean } from 'ahooks' @@ -35,9 +35,8 @@ export type IResultProps = { isCallBatchAPI: boolean isPC: boolean isMobile: boolean - isInstalledApp: boolean - appId: string - installedAppInfo?: InstalledApp + appSourceType: AppSourceType + appId?: string isError: boolean isShowTextToSpeech: boolean promptConfig: PromptConfig | null @@ -63,9 +62,8 @@ const Result: FC = ({ isCallBatchAPI, isPC, isMobile, - isInstalledApp, + appSourceType, appId, - installedAppInfo, isError, isShowTextToSpeech, promptConfig, @@ -133,7 +131,7 @@ const Result: FC = ({ }) const handleFeedback = async (feedback: FeedbackType) => { - await updateFeedback({ url: `/messages/${messageId}/feedbacks`, body: { rating: feedback.rating, content: feedback.content } }, 
isInstalledApp, installedAppInfo?.id) + await updateFeedback({ url: `/messages/${messageId}/feedbacks`, body: { rating: feedback.rating, content: feedback.content } }, appSourceType, appId) setFeedback(feedback) } @@ -147,9 +145,9 @@ const Result: FC = ({ setIsStopping(true) try { if (isWorkflow) - await stopWorkflowMessage(appId, currentTaskId, isInstalledApp, installedAppInfo?.id || '') + await stopWorkflowMessage(appId!, currentTaskId, appSourceType, appId || '') else - await stopChatMessageResponding(appId, currentTaskId, isInstalledApp, installedAppInfo?.id || '') + await stopChatMessageResponding(appId!, currentTaskId, appSourceType, appId || '') abortControllerRef.current?.abort() } catch (error) { @@ -159,7 +157,7 @@ const Result: FC = ({ finally { setIsStopping(false) } - }, [appId, currentTaskId, installedAppInfo?.id, isInstalledApp, isStopping, isWorkflow, notify]) + }, [appId, currentTaskId, appSourceType, appId, isStopping, isWorkflow, notify]) useEffect(() => { if (!onRunControlChange) @@ -468,8 +466,8 @@ const Result: FC = ({ })) }, }, - isInstalledApp, - installedAppInfo?.id, + appSourceType, + appId, ).catch((error) => { setRespondingFalse() resetRunState() @@ -514,7 +512,7 @@ const Result: FC = ({ getAbortController: (abortController) => { abortControllerRef.current = abortController }, - }, isInstalledApp, installedAppInfo?.id) + }, appSourceType, appId) } } @@ -562,8 +560,8 @@ const Result: FC = ({ feedback={feedback} onSave={handleSaveMessage} isMobile={isMobile} - isInstalledApp={isInstalledApp} - installedAppId={installedAppInfo?.id} + appSourceType={appSourceType} + installedAppId={appId} isLoading={isCallBatchAPI ? (!completionRes && isResponding) : false} taskId={isCallBatchAPI ? ((taskId as number) < 10 ? `0${taskId}` : `${taskId}`) : undefined} controlClearMoreLikeThis={controlClearMoreLikeThis} diff --git a/web/app/components/share/text-generation/run-once/index.tsx b/web/app/components/share/text-generation/run-once/index.tsx index ca29ce1a98..4531ff8beb 100644 --- a/web/app/components/share/text-generation/run-once/index.tsx +++ b/web/app/components/share/text-generation/run-once/index.tsx @@ -1,4 +1,5 @@ import type { ChangeEvent, FC, FormEvent } from 'react' +import type { InputValueTypes } from '../types' import type { PromptConfig } from '@/models/debug' import type { SiteInfo } from '@/models/share' import type { VisionFile, VisionSettings } from '@/types/app' @@ -25,9 +26,9 @@ import { cn } from '@/utils/classnames' export type IRunOnceProps = { siteInfo: SiteInfo promptConfig: PromptConfig - inputs: Record - inputsRef: React.RefObject> - onInputsChange: (inputs: Record) => void + inputs: Record + inputsRef: React.RefObject> + onInputsChange: (inputs: Record) => void onSend: () => void visionConfig: VisionSettings onVisionFilesChange: (files: VisionFile[]) => void @@ -52,7 +53,7 @@ const RunOnce: FC = ({ const [isInitialized, setIsInitialized] = useState(false) const onClear = () => { - const newInputs: Record = {} + const newInputs: Record = {} promptConfig.prompt_variables.forEach((item) => { if (item.type === 'string' || item.type === 'paragraph') newInputs[item.key] = '' @@ -127,7 +128,7 @@ const RunOnce: FC = ({ {item.type === 'select' && ( ) => { handleInputsChange({ ...inputsRef.current, [item.key]: e.target.value }) }} maxLength={item.max_length} /> @@ -146,7 +147,7 @@ const RunOnce: FC = ({