Mirror of https://github.com/langgenius/dify.git (synced 2026-01-21 20:45:22 +08:00)

Compare commits: refactor/u ... 1.11.1 (65 commits)
| SHA1 |
|---|
| 2058186f22 |
| bece2f101c |
| 04d09c2d77 |
| db42f467c8 |
| ac40309850 |
| 12e39365fa |
| d48300d08c |
| 761f8c8043 |
| 05f63c88c6 |
| 8daf9ce98d |
| 61ee1b9094 |
| 87c4b4c576 |
| 193c8e2362 |
| 4d57460356 |
| 063b39ada5 |
| 6419ce02c7 |
| 1a877bb4d0 |
| 281e9d4f51 |
| a195b410d1 |
| 91e5db3e83 |
| f20a2d1586 |
| 6e802a343e |
| a30cbe3c95 |
| 7344adf65e |
| fcadee9413 |
| 69a22af1c9 |
| aac6f44562 |
| 2e1efd62e1 |
| 1847609926 |
| 91f6d25dae |
| acdbcdb6f8 |
| a9627ba60a |
| 266d1c70ac |
| d152d63e7d |
| b4afc7e435 |
| 2d496e7e08 |
| 693877e5e4 |
| 8cab3e5a1e |
| 18082752a0 |
| 813a734f27 |
| 94244ed8f6 |
| ec3a52f012 |
| ea063a1139 |
| 784008997b |
| 0c2a354115 |
| e477e6c928 |
| bafd093fa9 |
| 88b20bc6d0 |
| 12d019cd31 |
| b49e2646ff |
| e8720de9ad |
| 0867c1800b |
| 681c06186e |
| f722fdfa6d |
| c033030d8c |
| 51330c0ee6 |
| 7df360a292 |
| e205182e1f |
| 4a88c8fd19 |
| 1b9165624f |
| 56f8bdd724 |
| efa1b452da |
| bcbc07e99c |
| d79d0a47a7 |
| f5d676f3f1 |
.github/CODEOWNERS (vendored): 8 lines changed
@@ -9,6 +9,14 @@
 # Backend (default owner, more specific rules below will override)
 api/ @QuantumGhost

+# Backend - MCP
+api/core/mcp/ @Nov1c444
+api/core/entities/mcp_provider.py @Nov1c444
+api/services/tools/mcp_tools_manage_service.py @Nov1c444
+api/controllers/mcp/ @Nov1c444
+api/controllers/console/app/mcp_server.py @Nov1c444
+api/tests/**/*mcp* @Nov1c444
+
 # Backend - Workflow - Engine (Core graph execution engine)
 api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
 api/core/workflow/runtime/ @laipz8200 @QuantumGhost
.github/ISSUE_TEMPLATE/refactor.yml (vendored): 14 lines changed
@@ -1,8 +1,6 @@
-name: "✨ Refactor"
-description: Refactor existing code for improved readability and maintainability.
-title: "[Chore/Refactor] "
-labels:
-  - refactor
+name: "✨ Refactor or Chore"
+description: Refactor existing code or perform maintenance chores to improve readability and reliability.
+title: "[Refactor/Chore] "
 body:
   - type: checkboxes
     attributes:
@@ -11,7 +9,7 @@ body:
       options:
         - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
          required: true
-        - label: This is only for refactoring, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
+        - label: This is only for refactors or chores; if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
          required: true
        - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
          required: true
@@ -25,14 +23,14 @@ body:
    id: description
    attributes:
      label: Description
-      placeholder: "Describe the refactor you are proposing."
+      placeholder: "Describe the refactor or chore you are proposing."
    validations:
      required: true
  - type: textarea
    id: motivation
    attributes:
      label: Motivation
-      placeholder: "Explain why this refactor is necessary."
+      placeholder: "Explain why this refactor or chore is necessary."
    validations:
      required: false
  - type: textarea
.github/ISSUE_TEMPLATE/tracker.yml (vendored): deleted (13 lines)
@@ -1,13 +0,0 @@
-name: "👾 Tracker"
-description: For inner usages, please do not use this template.
-title: "[Tracker] "
-labels:
-  - tracker
-body:
-  - type: textarea
-    id: content
-    attributes:
-      label: Blockers
-      placeholder: "- [ ] ..."
-    validations:
-      required: true
.github/workflows/semantic-pull-request.yml (vendored): new file (21 lines)
@@ -0,0 +1,21 @@
+name: Semantic Pull Request
+
+on:
+  pull_request:
+    types:
+      - opened
+      - edited
+      - reopened
+      - synchronize
+
+jobs:
+  lint:
+    name: Validate PR title
+    permissions:
+      pull-requests: read
+    runs-on: ubuntu-latest
+    steps:
+      - name: Check title
+        uses: amannn/action-semantic-pull-request@v6.1.1
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
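The action above validates PR titles against the Conventional Commits format; since this workflow passes no custom configuration, the action's defaults apply, so a title such as `fix: accept blank conversation IDs` or `refactor(api): batch hit-count updates` would pass, while `Update retrieval service` would be rejected. The four trigger types (opened, edited, reopened, synchronize) cover every event that can change a title, and `pull-requests: read` is the minimal permission the check needs.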
@@ -114,7 +114,7 @@ class AppTriggersApi(Resource):

 @console_ns.route("/apps/<uuid:app_id>/trigger-enable")
 class AppTriggerEnableApi(Resource):
-    @console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
+    @console_ns.expect(console_ns.models[ParserEnable.__name__])
     @setup_required
     @login_required
     @account_initialization_required
@@ -422,7 +422,6 @@ class DatasetApi(Resource):
            raise NotFound("Dataset not found.")

        payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
-       payload_data = payload.model_dump(exclude_unset=True)
        current_user, current_tenant_id = current_account_with_tenant()
        # check embedding model setting
        if (
@@ -434,6 +433,7 @@ class DatasetApi(Resource):
                dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model
            )
            payload.is_multimodal = is_multimodal
+       payload_data = payload.model_dump(exclude_unset=True)
        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
        DatasetPermissionService.check_permission(
            current_user, dataset, payload.permission, payload.partial_member_list
@@ -26,7 +26,7 @@ console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=D

 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
 class DataSourceContentPreviewApi(Resource):
-    @console_ns.expect(console_ns.models[Parser.__name__], validate=True)
+    @console_ns.expect(console_ns.models[Parser.__name__])
     @setup_required
     @login_required
     @account_initialization_required
@@ -2,7 +2,7 @@ import logging
 from typing import Any, Literal
 from uuid import UUID

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
 from werkzeug.exceptions import InternalServerError, NotFound

 import services
@@ -52,10 +52,24 @@ class ChatMessagePayload(BaseModel):
     inputs: dict[str, Any]
     query: str
     files: list[dict[str, Any]] | None = None
-    conversation_id: UUID | None = None
-    parent_message_id: UUID | None = None
+    conversation_id: str | None = None
+    parent_message_id: str | None = None
     retriever_from: str = Field(default="explore_app")

+    @field_validator("conversation_id", "parent_message_id", mode="before")
+    @classmethod
+    def normalize_uuid(cls, value: str | UUID | None) -> str | None:
+        """
+        Accept blank IDs and validate UUID format when provided.
+        """
+        if not value:
+            return None
+
+        try:
+            return helper.uuid_value(value)
+        except ValueError as exc:
+            raise ValueError("must be a valid UUID") from exc
+

 register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
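The pattern above (typing the IDs as `str | None` plus a `mode="before"` validator) lets clients send `""` or omit the field for a fresh conversation while still rejecting malformed IDs. A minimal self-contained sketch of the same idea, with the repository's `helper.uuid_value` swapped for a plain `UUID` round-trip:

```python
from uuid import UUID

from pydantic import BaseModel, ValidationError, field_validator


class Payload(BaseModel):
    # Illustrative stand-in for ChatMessagePayload.
    conversation_id: str | None = None

    @field_validator("conversation_id", mode="before")
    @classmethod
    def normalize_uuid(cls, value: str | UUID | None) -> str | None:
        # Treat "" and None as "no conversation yet" instead of failing validation.
        if not value:
            return None
        # Validate the UUID format, then keep the value as a plain string.
        try:
            return str(UUID(str(value)))
        except ValueError as exc:
            raise ValueError("must be a valid UUID") from exc


print(Payload(conversation_id="").conversation_id)  # None, not a validation error
print(Payload(conversation_id="8c9f2f0e-2b8f-4a31-9d62-0d5a9e2f3c11").conversation_id)
try:
    Payload(conversation_id="not-a-uuid")
except ValidationError as exc:
    print(exc.errors()[0]["msg"])  # Value error, must be a valid UUID
```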
@@ -3,7 +3,7 @@ from uuid import UUID

 from flask import request
 from flask_restx import marshal_with
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound

@@ -30,9 +30,16 @@ class ConversationListQuery(BaseModel):


 class ConversationRenamePayload(BaseModel):
-    name: str
+    name: str | None = None
     auto_generate: bool = False

+    @model_validator(mode="after")
+    def validate_name_requirement(self):
+        if not self.auto_generate:
+            if self.name is None or not self.name.strip():
+                raise ValueError("name is required when auto_generate is false")
+        return self
+

 register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)
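Because the rename payload now allows `name` to be omitted when `auto_generate` is set, the cross-field rule moves into an after-mode model validator. A runnable sketch of the behavior (class name is illustrative):

```python
from pydantic import BaseModel, ValidationError, model_validator


class RenamePayload(BaseModel):
    name: str | None = None
    auto_generate: bool = False

    @model_validator(mode="after")
    def validate_name_requirement(self):
        # A name may only be omitted (or blank) when auto-generation is requested.
        if not self.auto_generate:
            if self.name is None or not self.name.strip():
                raise ValueError("name is required when auto_generate is false")
        return self


print(RenamePayload(auto_generate=True))       # ok without a name
print(RenamePayload(name="Quarterly report"))  # ok with an explicit name
try:
    RenamePayload(name="   ")                  # blank name, auto_generate false
except ValidationError as exc:
    print(exc.errors()[0]["msg"])
```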
@@ -230,7 +230,7 @@ class ModelProviderModelApi(Resource):

         return {"result": "success"}, 200

-    @console_ns.expect(console_ns.models[ParserDeleteModels.__name__], validate=True)
+    @console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
     @setup_required
     @login_required
     @is_admin_or_owner_required
@@ -282,9 +282,10 @@ class ModelProviderModelCredentialApi(Resource):
                tenant_id=tenant_id, provider_name=provider
            )
        else:
-           model_type = args.model_type
+           # Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
+           normalized_model_type = args.model_type.to_origin_model_type()
            available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
-               tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args.model
+               tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model
            )

        return jsonable_encoder(
@@ -4,7 +4,7 @@ from uuid import UUID

 from flask import request
 from flask_restx import Resource
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
 from werkzeug.exceptions import BadRequest, InternalServerError, NotFound

 import services
@@ -52,11 +52,23 @@ class ChatRequestPayload(BaseModel):
     query: str
     files: list[dict[str, Any]] | None = None
     response_mode: Literal["blocking", "streaming"] | None = None
-    conversation_id: UUID | None = None
+    conversation_id: str | None = Field(default=None, description="Conversation UUID")
     retriever_from: str = Field(default="dev")
     auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
     workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")

+    @field_validator("conversation_id", mode="before")
+    @classmethod
+    def normalize_conversation_id(cls, value: str | UUID | None) -> str | None:
+        """Allow missing or blank conversation IDs; enforce UUID format when provided."""
+        if not value:
+            return None
+
+        try:
+            return helper.uuid_value(value)
+        except ValueError as exc:
+            raise ValueError("conversation_id must be a valid UUID") from exc
+

 register_schema_models(service_api_ns, CompletionRequestPayload, ChatRequestPayload)
@@ -4,7 +4,7 @@ from uuid import UUID
 from flask import request
 from flask_restx import Resource
 from flask_restx._http import HTTPStatus
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import BadRequest, NotFound

@@ -37,9 +37,16 @@ class ConversationListQuery(BaseModel):


 class ConversationRenamePayload(BaseModel):
-    name: str = Field(description="New conversation name")
+    name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)")
     auto_generate: bool = Field(default=False, description="Auto-generate conversation name")

+    @model_validator(mode="after")
+    def validate_name_requirement(self):
+        if not self.auto_generate:
+            if self.name is None or not self.name.strip():
+                raise ValueError("name is required when auto_generate is false")
+        return self
+

 class ConversationVariablesQuery(BaseModel):
     last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
@@ -62,8 +62,7 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
 from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.ops.entities.trace_entity import TraceTaskName
-from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
+from core.ops.ops_trace_manager import TraceQueueManager
 from core.workflow.enums import WorkflowExecutionStatus
 from core.workflow.nodes import NodeType
 from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
@@ -73,7 +72,7 @@ from extensions.ext_database import db
 from libs.datetime_utils import naive_utc_now
 from models import Account, Conversation, EndUser, Message, MessageFile
 from models.enums import CreatorUserRole
-from models.workflow import Workflow, WorkflowNodeExecutionModel
+from models.workflow import Workflow

 logger = logging.getLogger(__name__)

@@ -581,7 +580,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):

                 with self._database_session() as session:
                     # Save message
-                    self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
+                    self._save_message(session=session, graph_runtime_state=resolved_state)

                 yield workflow_finish_resp
             elif event.stopped_by in (
@@ -591,7 +590,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
                 # When hitting input-moderation or annotation-reply, the workflow will not start
                 with self._database_session() as session:
                     # Save message
-                    self._save_message(session=session, trace_manager=trace_manager)
+                    self._save_message(session=session)

                 yield self._message_end_to_stream_response()

@@ -600,7 +599,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         event: QueueAdvancedChatMessageEndEvent,
         *,
         graph_runtime_state: GraphRuntimeState | None = None,
-        trace_manager: TraceQueueManager | None = None,
         **kwargs,
     ) -> Generator[StreamResponse, None, None]:
         """Handle advanced chat message end events."""
@@ -618,7 +616,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):

         # Save message
         with self._database_session() as session:
-            self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
+            self._save_message(session=session, graph_runtime_state=resolved_state)

         yield self._message_end_to_stream_response()

@@ -772,13 +770,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         if self._conversation_name_generate_thread:
             logger.debug("Conversation name generation running as daemon thread")

-    def _save_message(
-        self,
-        *,
-        session: Session,
-        graph_runtime_state: GraphRuntimeState | None = None,
-        trace_manager: TraceQueueManager | None = None,
-    ):
+    def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
         message = self._get_message(session=session)

         # If there are assistant files, remove markdown image links from answer
@@ -817,14 +809,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):

         metadata = self._task_state.metadata.model_dump()
         message.message_metadata = json.dumps(jsonable_encoder(metadata))

-        # Extract model provider and model_id from workflow node executions for tracing
-        if message.workflow_run_id:
-            model_info = self._extract_model_info_from_workflow(session, message.workflow_run_id)
-            if model_info:
-                message.model_provider = model_info.get("provider")
-                message.model_id = model_info.get("model")
-
         message_files = [
             MessageFile(
                 message_id=message.id,
@@ -842,68 +826,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         ]
         session.add_all(message_files)

-        # Trigger MESSAGE_TRACE for tracing integrations
-        if trace_manager:
-            trace_manager.add_trace_task(
-                TraceTask(
-                    TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
-                )
-            )
-
-    def _extract_model_info_from_workflow(self, session: Session, workflow_run_id: str) -> dict[str, str] | None:
-        """
-        Extract model provider and model_id from workflow node executions.
-        Returns dict with 'provider' and 'model' keys, or None if not found.
-        """
-        try:
-            # Query workflow node executions for LLM or Agent nodes
-            stmt = (
-                select(WorkflowNodeExecutionModel)
-                .where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
-                .where(WorkflowNodeExecutionModel.node_type.in_(["llm", "agent"]))
-                .order_by(WorkflowNodeExecutionModel.created_at.desc())
-                .limit(1)
-            )
-            node_execution = session.scalar(stmt)
-
-            if not node_execution:
-                return None
-
-            # Try to extract from execution_metadata for agent nodes
-            if node_execution.execution_metadata:
-                try:
-                    metadata = json.loads(node_execution.execution_metadata)
-                    agent_log = metadata.get("agent_log", [])
-                    # Look for the first agent thought with provider info
-                    for log_entry in agent_log:
-                        entry_metadata = log_entry.get("metadata", {})
-                        provider_str = entry_metadata.get("provider")
-                        if provider_str:
-                            # Parse format like "langgenius/deepseek/deepseek"
-                            parts = provider_str.split("/")
-                            if len(parts) >= 3:
-                                return {"provider": parts[1], "model": parts[2]}
-                            elif len(parts) == 2:
-                                return {"provider": parts[0], "model": parts[1]}
-                except (json.JSONDecodeError, KeyError, AttributeError) as e:
-                    logger.debug("Failed to parse execution_metadata: %s", e)
-
-            # Try to extract from process_data for llm nodes
-            if node_execution.process_data:
-                try:
-                    process_data = json.loads(node_execution.process_data)
-                    provider = process_data.get("model_provider")
-                    model = process_data.get("model_name")
-                    if provider and model:
-                        return {"provider": provider, "model": model}
-                except (json.JSONDecodeError, KeyError) as e:
-                    logger.debug("Failed to parse process_data: %s", e)
-
-            return None
-        except Exception as e:
-            logger.warning("Failed to extract model info from workflow: %s", e)
-            return None
-
     def _seed_graph_runtime_state_from_queue_manager(self) -> None:
         """Bootstrap the cached runtime state from the queue manager when present."""
         candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
@@ -40,9 +40,6 @@ class EasyUITaskState(TaskState):
     """

     llm_result: LLMResult
-    first_token_time: float | None = None
-    last_token_time: float | None = None
-    is_streaming_response: bool = False


 class WorkflowTaskState(TaskState):
@@ -332,12 +332,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
                 if not self._task_state.llm_result.prompt_messages:
                     self._task_state.llm_result.prompt_messages = chunk.prompt_messages

-                # Track streaming response times
-                if self._task_state.first_token_time is None:
-                    self._task_state.first_token_time = time.perf_counter()
-                    self._task_state.is_streaming_response = True
-                self._task_state.last_token_time = time.perf_counter()
-
                 # handle output moderation chunk
                 should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text))
                 if should_direct_answer:
@@ -404,18 +398,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
        message.total_price = usage.total_price
        message.currency = usage.currency
        self._task_state.llm_result.usage.latency = message.provider_response_latency

-        # Add streaming metrics to usage if available
-        if self._task_state.is_streaming_response and self._task_state.first_token_time:
-            start_time = self.start_at
-            first_token_time = self._task_state.first_token_time
-            last_token_time = self._task_state.last_token_time or first_token_time
-            usage.time_to_first_token = round(first_token_time - start_time, 3)
-            usage.time_to_generate = round(last_token_time - first_token_time, 3)
-
        # Update metadata with the complete usage info
        self._task_state.metadata.usage = usage

        message.message_metadata = self._task_state.metadata.model_dump_json()

        if trace_manager:
@@ -1,4 +1,4 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, Field


 class PreviewDetail(BaseModel):
@@ -20,7 +20,7 @@ class IndexingEstimate(BaseModel):
 class PipelineDataset(BaseModel):
     id: str
     name: str
-    description: str
+    description: str | None = Field(default="", description="knowledge dataset description")
     chunk_structure: str
@@ -213,12 +213,23 @@ class MCPProviderEntity(BaseModel):
         return None

     def retrieve_tokens(self) -> OAuthTokens | None:
-        """OAuth tokens if available"""
+        """Retrieve OAuth tokens if authentication is complete.
+
+        Returns:
+            OAuthTokens if the provider has been authenticated, None otherwise.
+        """
         if not self.credentials:
             return None
         credentials = self.decrypt_credentials()
+        access_token = credentials.get("access_token", "")
+        # Return None if access_token is empty to avoid generating invalid "Authorization: Bearer " header.
+        # Note: We don't check for whitespace-only strings here because:
+        # 1. OAuth servers don't return whitespace-only access tokens in practice
+        # 2. Even if they did, the server would return 401, triggering the OAuth flow correctly
+        if not access_token:
+            return None
         return OAuthTokens(
-            access_token=credentials.get("access_token", ""),
+            access_token=access_token,
             token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE),
             expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN),
             refresh_token=credentials.get("refresh_token", ""),
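The empty-token guard matters because an empty credential would still produce an `Authorization` header. A small illustrative helper (not part of the diff) showing the failure mode it prevents:

```python
def build_auth_headers(access_token: str | None) -> dict[str, str]:
    # An empty token would otherwise yield "Authorization: Bearer " (no credential),
    # which servers tend to reject with confusing errors instead of a clean 401.
    if not access_token:
        return {}  # unauthenticated; the caller can start the OAuth flow instead
    return {"Authorization": f"Bearer {access_token}"}


print(build_auth_headers(""))        # {}
print(build_auth_headers("abc123"))  # {'Authorization': 'Bearer abc123'}
```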
@@ -554,11 +554,16 @@ class LLMGenerator:
                 prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
             )

-            generated_raw = cast(str, response.message.content)
+            generated_raw = response.message.get_text_content()
             first_brace = generated_raw.find("{")
             last_brace = generated_raw.rfind("}")
-            return {**json.loads(generated_raw[first_brace : last_brace + 1])}
+            if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
+                raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
+            json_str = generated_raw[first_brace : last_brace + 1]
+            data = json_repair.loads(json_str)
+            if not isinstance(data, dict):
+                raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
+            return data
         except InvokeError as e:
             error = str(e)
             return {"error": f"Failed to generate code. Error: {error}"}
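The replacement tolerates the noise LLMs wrap around JSON (prose, code fences, trailing commas) by slicing the outermost braces and handing the result to `json_repair`. A standalone sketch of the same extraction:

```python
import json_repair  # pip install json-repair


def extract_json_object(generated_raw: str) -> dict:
    # Slice from the first "{" to the last "}" and let json_repair fix
    # trailing commas, single quotes, and similar artifacts.
    first_brace = generated_raw.find("{")
    last_brace = generated_raw.rfind("}")
    if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
        raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
    data = json_repair.loads(generated_raw[first_brace : last_brace + 1])
    if not isinstance(data, dict):
        raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
    return data


# Tolerates the markdown fence and trailing comma in this sample response:
print(extract_json_object('Sure!\n```json\n{"name": "demo", "args": [1, 2,]}\n```'))
```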
@@ -222,59 +222,6 @@ class TencentSpanBuilder:
             links=links,
         )

-    @staticmethod
-    def build_message_llm_span(
-        trace_info: MessageTraceInfo, trace_id: int, parent_span_id: int, user_id: str
-    ) -> SpanData:
-        """Build LLM span for message traces with detailed LLM attributes."""
-        status = Status(StatusCode.OK)
-        if trace_info.error:
-            status = Status(StatusCode.ERROR, trace_info.error)
-
-        # Extract model information from `metadata`` or `message_data`
-        trace_metadata = trace_info.metadata or {}
-        message_data = trace_info.message_data or {}
-
-        model_provider = trace_metadata.get("ls_provider") or (
-            message_data.get("model_provider", "") if isinstance(message_data, dict) else ""
-        )
-        model_name = trace_metadata.get("ls_model_name") or (
-            message_data.get("model_id", "") if isinstance(message_data, dict) else ""
-        )
-
-        inputs_str = str(trace_info.inputs or "")
-        outputs_str = str(trace_info.outputs or "")
-
-        attributes = {
-            GEN_AI_SESSION_ID: trace_metadata.get("conversation_id", ""),
-            GEN_AI_USER_ID: str(user_id),
-            GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
-            GEN_AI_FRAMEWORK: "dify",
-            GEN_AI_MODEL_NAME: str(model_name),
-            GEN_AI_PROVIDER: str(model_provider),
-            GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens or 0),
-            GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens or 0),
-            GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens or 0),
-            GEN_AI_PROMPT: inputs_str,
-            GEN_AI_COMPLETION: outputs_str,
-            INPUT_VALUE: inputs_str,
-            OUTPUT_VALUE: outputs_str,
-        }
-
-        if trace_info.is_streaming_request:
-            attributes[GEN_AI_IS_STREAMING_REQUEST] = "true"
-
-        return SpanData(
-            trace_id=trace_id,
-            parent_span_id=parent_span_id,
-            span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "llm"),
-            name="GENERATION",
-            start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
-            end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
-            attributes=attributes,
-            status=status,
-        )
-
     @staticmethod
     def build_tool_span(trace_info: ToolTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
         """Build tool span."""

@@ -107,12 +107,8 @@ class TencentDataTrace:
         links.append(TencentTraceUtils.create_link(trace_info.trace_id))

         message_span = TencentSpanBuilder.build_message_span(trace_info, trace_id, str(user_id), links)
-        self.trace_client.add_span(message_span)
-
-        # Add LLM child span with detailed attributes
-        parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
-        llm_span = TencentSpanBuilder.build_message_llm_span(trace_info, trace_id, parent_span_id, str(user_id))
-        self.trace_client.add_span(llm_span)
+        self.trace_client.add_span(message_span)

         self._record_message_llm_metrics(trace_info)
@@ -371,7 +371,7 @@ class RetrievalService:
        include_segment_ids = set()
        segment_child_map = {}
        segment_file_map = {}
-       with Session(db.engine) as session:
+       with Session(bind=db.engine, expire_on_commit=False) as session:
            # Process documents
            for document in documents:
                segment_id = None
@@ -395,7 +395,7 @@ class RetrievalService:
                        session,
                    )
                    if attachment_info_dict:
-                       attachment_info = attachment_info_dict["attchment_info"]
+                       attachment_info = attachment_info_dict["attachment_info"]
                        segment_id = attachment_info_dict["segment_id"]
                else:
                    child_index_node_id = document.metadata.get("doc_id")
@@ -417,13 +417,6 @@ class RetrievalService:
                            DocumentSegment.status == "completed",
                            DocumentSegment.id == segment_id,
                        )
-                       .options(
-                           load_only(
-                               DocumentSegment.id,
-                               DocumentSegment.content,
-                               DocumentSegment.answer,
-                           )
-                       )
                        .first()
                    )

@@ -458,12 +451,21 @@ class RetrievalService:
                            "position": child_chunk.position,
                            "score": document.metadata.get("score", 0.0),
                        }
-                       segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
-                       segment_child_map[segment.id]["max_score"] = max(
-                           segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
-                       )
+                       if segment.id in segment_child_map:
+                           segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
+                           segment_child_map[segment.id]["max_score"] = max(
+                               segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
+                           )
+                       else:
+                           segment_child_map[segment.id] = {
+                               "max_score": document.metadata.get("score", 0.0),
+                               "child_chunks": [child_chunk_detail],
+                           }
                        if attachment_info:
-                           segment_file_map[segment.id].append(attachment_info)
+                           if segment.id in segment_file_map:
+                               segment_file_map[segment.id].append(attachment_info)
+                           else:
+                               segment_file_map[segment.id] = [attachment_info]
                else:
                    # Handle normal documents
                    segment = None
@@ -475,7 +477,7 @@ class RetrievalService:
                        session,
                    )
                    if attachment_info_dict:
-                       attachment_info = attachment_info_dict["attchment_info"]
+                       attachment_info = attachment_info_dict["attachment_info"]
                        segment_id = attachment_info_dict["segment_id"]
                        document_segment_stmt = select(DocumentSegment).where(
                            DocumentSegment.dataset_id == dataset_document.dataset_id,
@@ -483,7 +485,7 @@ class RetrievalService:
                            DocumentSegment.status == "completed",
                            DocumentSegment.id == segment_id,
                        )
-                       segment = db.session.scalar(document_segment_stmt)
+                       segment = session.scalar(document_segment_stmt)
                        if segment:
                            segment_file_map[segment.id] = [attachment_info]
                    else:
@@ -496,7 +498,7 @@ class RetrievalService:
                            DocumentSegment.status == "completed",
                            DocumentSegment.index_node_id == index_node_id,
                        )
-                       segment = db.session.scalar(document_segment_stmt)
+                       segment = session.scalar(document_segment_stmt)

                        if not segment:
                            continue
@@ -684,7 +686,7 @@ class RetrievalService:
            .first()
        )
        if attachment_binding:
-           attchment_info = {
+           attachment_info = {
                "id": upload_file.id,
                "name": upload_file.name,
                "extension": "." + upload_file.extension,
@@ -692,5 +694,5 @@ class RetrievalService:
                "source_url": sign_upload_file(upload_file.id, upload_file.extension),
                "size": upload_file.size,
            }
-           return {"attchment_info": attchment_info, "segment_id": attachment_binding.segment_id}
+           return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id}
        return None
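The `if segment.id in ...` branches above fix a `KeyError` on the first occurrence of a segment. `dict.setdefault` expresses the same first-seen initialization more compactly, as this standalone sketch shows (sample data is illustrative):

```python
segment_file_map: dict[str, list[dict]] = {}
segment_child_map: dict[str, dict] = {}


def add_attachment(segment_id: str, attachment_info: dict) -> None:
    # Creates the list the first time a segment is seen, then appends.
    segment_file_map.setdefault(segment_id, []).append(attachment_info)


def add_child_chunk(segment_id: str, detail: dict, score: float) -> None:
    entry = segment_child_map.setdefault(segment_id, {"max_score": score, "child_chunks": []})
    entry["child_chunks"].append(detail)
    entry["max_score"] = max(entry["max_score"], score)


add_attachment("seg-1", {"id": "file-1"})
add_attachment("seg-1", {"id": "file-2"})
add_child_chunk("seg-1", {"position": 0}, 0.42)
print(segment_file_map, segment_child_map)
```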
@@ -1,7 +1,7 @@
 """Abstract interface for document loader implementations."""

 import os
-from typing import cast
+from typing import TypedDict

 import pandas as pd
 from openpyxl import load_workbook
@@ -10,6 +10,12 @@ from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document


+class Candidate(TypedDict):
+    idx: int
+    count: int
+    map: dict[int, str]
+
+
 class ExcelExtractor(BaseExtractor):
     """Load Excel files.

@@ -30,32 +36,38 @@ class ExcelExtractor(BaseExtractor):
         file_extension = os.path.splitext(self._file_path)[-1].lower()

         if file_extension == ".xlsx":
-            wb = load_workbook(self._file_path, data_only=True)
-            for sheet_name in wb.sheetnames:
-                sheet = wb[sheet_name]
-                data = sheet.values
-                cols = next(data, None)
-                if cols is None:
-                    continue
-                df = pd.DataFrame(data, columns=cols)
-
-                df.dropna(how="all", inplace=True)
-
-                for index, row in df.iterrows():
-                    page_content = []
-                    for col_index, (k, v) in enumerate(row.items()):
-                        if pd.notna(v):
-                            cell = sheet.cell(
-                                row=cast(int, index) + 2, column=col_index + 1
-                            )  # +2 to account for header and 1-based index
-                            if cell.hyperlink:
-                                value = f"[{v}]({cell.hyperlink.target})"
-                                page_content.append(f'"{k}":"{value}"')
-                            else:
-                                page_content.append(f'"{k}":"{v}"')
-                    documents.append(
-                        Document(page_content=";".join(page_content), metadata={"source": self._file_path})
-                    )
+            wb = load_workbook(self._file_path, read_only=True, data_only=True)
+            try:
+                for sheet_name in wb.sheetnames:
+                    sheet = wb[sheet_name]
+                    header_row_idx, column_map, max_col_idx = self._find_header_and_columns(sheet)
+                    if not column_map:
+                        continue
+                    start_row = header_row_idx + 1
+                    for row in sheet.iter_rows(min_row=start_row, max_col=max_col_idx, values_only=False):
+                        if all(cell.value is None for cell in row):
+                            continue
+                        page_content = []
+                        for col_idx, cell in enumerate(row):
+                            value = cell.value
+                            if col_idx in column_map:
+                                col_name = column_map[col_idx]
+                                if hasattr(cell, "hyperlink") and cell.hyperlink:
+                                    target = getattr(cell.hyperlink, "target", None)
+                                    if target:
+                                        value = f"[{value}]({target})"
+                                if value is None:
+                                    value = ""
+                                elif not isinstance(value, str):
+                                    value = str(value)
+                                value = value.strip().replace('"', '\\"')
+                                page_content.append(f'"{col_name}":"{value}"')
+                        if page_content:
+                            documents.append(
+                                Document(page_content=";".join(page_content), metadata={"source": self._file_path})
+                            )
+            finally:
+                wb.close()

         elif file_extension == ".xls":
             excel_file = pd.ExcelFile(self._file_path, engine="xlrd")
@@ -63,9 +75,9 @@ class ExcelExtractor(BaseExtractor):
                 df = excel_file.parse(sheet_name=excel_sheet_name)
                 df.dropna(how="all", inplace=True)

-                for _, row in df.iterrows():
+                for _, series_row in df.iterrows():
                     page_content = []
-                    for k, v in row.items():
+                    for k, v in series_row.items():
                         if pd.notna(v):
                             page_content.append(f'"{k}":"{v}"')
                     documents.append(
@@ -75,3 +87,61 @@ class ExcelExtractor(BaseExtractor):
             raise ValueError(f"Unsupported file extension: {file_extension}")

         return documents
+
+    def _find_header_and_columns(self, sheet, scan_rows=10) -> tuple[int, dict[int, str], int]:
+        """
+        Scan first N rows to find the most likely header row.
+        Returns:
+            header_row_idx: 1-based index of the header row
+            column_map: Dict mapping 0-based column index to column name
+            max_col_idx: 1-based index of the last valid column (for iter_rows boundary)
+        """
+        # Store potential candidates: (row_index, non_empty_count, column_map)
+        candidates: list[Candidate] = []
+
+        # Limit scan to avoid performance issues on huge files
+        # We iterate manually to control the read scope
+        for current_row_idx, row in enumerate(sheet.iter_rows(min_row=1, max_row=scan_rows, values_only=True), start=1):
+            # Filter out empty cells and build a temp map for this row
+            # col_idx is 0-based
+            row_map = {}
+            for col_idx, cell_value in enumerate(row):
+                if cell_value is not None and str(cell_value).strip():
+                    row_map[col_idx] = str(cell_value).strip().replace('"', '\\"')
+
+            if not row_map:
+                continue
+
+            non_empty_count = len(row_map)
+
+            # Header selection heuristic (implemented):
+            # - Prefer the first row with at least 2 non-empty columns.
+            # - Fallback: choose the row with the most non-empty columns
+            #   (tie-breaker: smaller row index).
+            candidates.append({"idx": current_row_idx, "count": non_empty_count, "map": row_map})
+
+        if not candidates:
+            return 0, {}, 0
+
+        # Choose the best candidate header row.
+        best_candidate: Candidate | None = None
+
+        # Strategy: prefer the first row with >= 2 non-empty columns; otherwise fallback.
+        for cand in candidates:
+            if cand["count"] >= 2:
+                best_candidate = cand
+                break
+
+        # Fallback: if no row has >= 2 columns, or all have 1, just take the one with max columns
+        if not best_candidate:
+            # Sort by count desc, then index asc
+            candidates.sort(key=lambda x: (-x["count"], x["idx"]))
+            best_candidate = candidates[0]
+
+        # Determine max_col_idx (1-based for openpyxl)
+        # It is the index of the last valid column in our map + 1
+        max_col_idx = max(best_candidate["map"].keys()) + 1
+
+        return best_candidate["idx"], best_candidate["map"], max_col_idx
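The header heuristic is easiest to see on plain rows. This standalone sketch mirrors the selection logic above with made-up sheet data (1-based rows, 0-based columns):

```python
rows = [
    ("Quarterly export", None, None),  # title row: only one non-empty cell
    ("Name", "Email", "Score"),        # first row with >= 2 columns -> header
    ("Ada", "ada@example.com", 99),
]

candidates = []
for row_idx, row in enumerate(rows, start=1):
    row_map = {i: str(v).strip() for i, v in enumerate(row) if v is not None and str(v).strip()}
    if row_map:
        candidates.append({"idx": row_idx, "count": len(row_map), "map": row_map})

# Prefer the first row with at least two non-empty columns ...
best = next((c for c in candidates if c["count"] >= 2), None)
# ... otherwise fall back to the densest row, earliest index winning ties.
if best is None:
    best = min(candidates, key=lambda c: (-c["count"], c["idx"]))

print(best["idx"], best["map"])  # 2 {0: 'Name', 1: 'Email', 2: 'Score'}
```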
@@ -84,22 +84,46 @@ class WordExtractor(BaseExtractor):
        image_count = 0
        image_map = {}

-       for rel in doc.part.rels.values():
+       for rId, rel in doc.part.rels.items():
            if "image" in rel.target_ref:
                image_count += 1
                if rel.is_external:
                    url = rel.target_ref
-                   response = ssrf_proxy.get(url)
+                   if not self._is_valid_url(url):
+                       continue
+                   try:
+                       response = ssrf_proxy.get(url)
+                   except Exception as e:
+                       logger.warning("Failed to download image from URL: %s: %s", url, str(e))
+                       continue
                    if response.status_code == 200:
-                       image_ext = mimetypes.guess_extension(response.headers["Content-Type"])
+                       image_ext = mimetypes.guess_extension(response.headers.get("Content-Type", ""))
                        if image_ext is None:
                            continue
                        file_uuid = str(uuid.uuid4())
-                       file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
+                       file_key = "image_files/" + self.tenant_id + "/" + file_uuid + image_ext
                        mime_type, _ = mimetypes.guess_type(file_key)
                        storage.save(file_key, response.content)
                    else:
                        continue
+                   # save file to db
+                   upload_file = UploadFile(
+                       tenant_id=self.tenant_id,
+                       storage_type=dify_config.STORAGE_TYPE,
+                       key=file_key,
+                       name=file_key,
+                       size=0,
+                       extension=str(image_ext),
+                       mime_type=mime_type or "",
+                       created_by=self.user_id,
+                       created_by_role=CreatorUserRole.ACCOUNT,
+                       created_at=naive_utc_now(),
+                       used=True,
+                       used_by=self.user_id,
+                       used_at=naive_utc_now(),
+                   )
+                   db.session.add(upload_file)
+                   db.session.commit()
+                   # Use rId as key for external images since target_part is undefined
+                   image_map[rId] = f""
                else:
                    image_ext = rel.target_ref.split(".")[-1]
                    if image_ext is None:
@@ -110,26 +134,28 @@ class WordExtractor(BaseExtractor):
                    mime_type, _ = mimetypes.guess_type(file_key)

                    storage.save(file_key, rel.target_part.blob)
-                   # save file to db
-                   upload_file = UploadFile(
-                       tenant_id=self.tenant_id,
-                       storage_type=dify_config.STORAGE_TYPE,
-                       key=file_key,
-                       name=file_key,
-                       size=0,
-                       extension=str(image_ext),
-                       mime_type=mime_type or "",
-                       created_by=self.user_id,
-                       created_by_role=CreatorUserRole.ACCOUNT,
-                       created_at=naive_utc_now(),
-                       used=True,
-                       used_by=self.user_id,
-                       used_at=naive_utc_now(),
-                   )
-
-                   db.session.add(upload_file)
-                   db.session.commit()
-                   image_map[rel.target_part] = f""
+                   # save file to db
+                   upload_file = UploadFile(
+                       tenant_id=self.tenant_id,
+                       storage_type=dify_config.STORAGE_TYPE,
+                       key=file_key,
+                       name=file_key,
+                       size=0,
+                       extension=str(image_ext),
+                       mime_type=mime_type or "",
+                       created_by=self.user_id,
+                       created_by_role=CreatorUserRole.ACCOUNT,
+                       created_at=naive_utc_now(),
+                       used=True,
+                       used_by=self.user_id,
+                       used_at=naive_utc_now(),
+                   )
+                   db.session.add(upload_file)
+                   db.session.commit()
+                   # Use target_part as key for internal images
+                   image_map[rel.target_part] = (
+                       f""
+                   )

        return image_map

@@ -186,11 +212,17 @@ class WordExtractor(BaseExtractor):
                    image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
                    if not image_id:
                        continue
-                   image_part = paragraph.part.rels[image_id].target_part
-
-                   if image_part in image_map:
-                       image_link = image_map[image_part]
-                       paragraph_content.append(image_link)
+                   rel = paragraph.part.rels.get(image_id)
+                   if rel is None:
+                       continue
+                   # For external images, use image_id as key; for internal, use target_part
+                   if rel.is_external:
+                       if image_id in image_map:
+                           paragraph_content.append(image_map[image_id])
+                   else:
+                       image_part = rel.target_part
+                       if image_part in image_map:
+                           paragraph_content.append(image_map[image_part])
            else:
                paragraph_content.append(run.text)
        return "".join(paragraph_content).strip()

@@ -227,6 +259,18 @@ class WordExtractor(BaseExtractor):

        def parse_paragraph(paragraph):
            paragraph_content = []

+           def append_image_link(image_id, has_drawing):
+               """Helper to append image link from image_map based on relationship type."""
+               rel = doc.part.rels[image_id]
+               if rel.is_external:
+                   if image_id in image_map and not has_drawing:
+                       paragraph_content.append(image_map[image_id])
+               else:
+                   image_part = rel.target_part
+                   if image_part in image_map and not has_drawing:
+                       paragraph_content.append(image_map[image_part])
+
            for run in paragraph.runs:
                if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
                    # Process drawing type images
@@ -243,10 +287,18 @@ class WordExtractor(BaseExtractor):
                                "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed"
                            )
                            if embed_id:
-                               image_part = doc.part.related_parts.get(embed_id)
-                               if image_part in image_map:
-                                   has_drawing = True
-                                   paragraph_content.append(image_map[image_part])
+                               rel = doc.part.rels.get(embed_id)
+                               if rel is not None and rel.is_external:
+                                   # External image: use embed_id as key
+                                   if embed_id in image_map:
+                                       has_drawing = True
+                                       paragraph_content.append(image_map[embed_id])
+                               else:
+                                   # Internal image: use target_part as key
+                                   image_part = doc.part.related_parts.get(embed_id)
+                                   if image_part in image_map:
+                                       has_drawing = True
+                                       paragraph_content.append(image_map[image_part])
                    # Process pict type images
                    shape_elements = run.element.findall(
                        ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"
@@ -261,9 +313,7 @@ class WordExtractor(BaseExtractor):
                                "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
                            )
                            if image_id and image_id in doc.part.rels:
-                               image_part = doc.part.rels[image_id].target_part
-                               if image_part in image_map and not has_drawing:
-                                   paragraph_content.append(image_map[image_part])
+                               append_image_link(image_id, has_drawing)
                    # Find imagedata element in VML
                    image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
                    if image_data is not None:
@@ -271,9 +321,7 @@ class WordExtractor(BaseExtractor):
                                "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
                            )
                            if image_id and image_id in doc.part.rels:
-                               image_part = doc.part.rels[image_id].target_part
-                               if image_part in image_map and not has_drawing:
-                                   paragraph_content.append(image_map[image_part])
+                               append_image_link(image_id, has_drawing)
                    if run.text.strip():
                        paragraph_content.append(run.text.strip())
            return "".join(paragraph_content) if paragraph_content else ""
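Note the dual keying convention introduced above: external images are stored under their relationship ID (`rId`), internal ones under the package part object, because external relationships have no `target_part`. A toy sketch of the lookup; the markdown link values here are placeholders, since the real link text inside the `f""` strings was lost in this capture:

```python
class Part:
    """Stand-in for a python-docx package part."""


internal_part = Part()
image_map = {
    "rId7": "![image](https://example.com/logo.png)",    # external: keyed by rId
    internal_part: "![image](/files/123/file-preview)",  # internal: keyed by part
}


def lookup(rel_is_external: bool, rid: str, target_part: Part | None) -> str | None:
    return image_map.get(rid) if rel_is_external else image_map.get(target_part)


print(lookup(True, "rId7", None))
print(lookup(False, "rId9", internal_part))
```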
@@ -209,7 +209,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
         if dataset.indexing_technique == "high_quality":
             vector = Vector(dataset)
             vector.create(documents)
-            if all_multimodal_documents:
+            if all_multimodal_documents and dataset.is_multimodal:
                 vector.create_multimodal(all_multimodal_documents)
         elif dataset.indexing_technique == "economy":
             keyword = Keyword(dataset)

@@ -312,7 +312,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
             vector = Vector(dataset)
             if all_child_documents:
                 vector.create(all_child_documents)
-            if all_multimodal_documents:
+            if all_multimodal_documents and dataset.is_multimodal:
                 vector.create_multimodal(all_multimodal_documents)

     def format_preview(self, chunks: Any) -> Mapping[str, Any]:
@@ -266,7 +266,7 @@ class DatasetRetrieval:
                ).all()
                if attachments_with_bindings:
                    for _, upload_file in attachments_with_bindings:
-                       attchment_info = File(
+                       attachment_info = File(
                            id=upload_file.id,
                            filename=upload_file.name,
                            extension="." + upload_file.extension,
@@ -280,7 +280,7 @@ class DatasetRetrieval:
                            storage_key=upload_file.key,
                            url=sign_upload_file(upload_file.id, upload_file.extension),
                        )
-                       context_files.append(attchment_info)
+                       context_files.append(attachment_info)
                if show_retrieve_source:
                    for record in records:
                        segment = record.segment
@@ -592,111 +592,116 @@ class DatasetRetrieval:
        """Handle retrieval end."""
        with flask_app.app_context():
            dify_documents = [document for document in documents if document.provider == "dify"]
-           segment_ids = []
-           segment_index_node_ids = []
+           if not dify_documents:
+               self._send_trace_task(message_id, documents, timer)
+               return
+
            with Session(db.engine) as session:
-               for document in dify_documents:
-                   if document.metadata is not None:
-                       dataset_document_stmt = select(DatasetDocument).where(
-                           DatasetDocument.id == document.metadata["document_id"]
-                       )
-                       dataset_document = session.scalar(dataset_document_stmt)
-                       if dataset_document:
-                           if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
-                               segment_id = None
-                               if (
-                                   "doc_type" not in document.metadata
-                                   or document.metadata.get("doc_type") == DocType.TEXT
-                               ):
-                                   child_chunk_stmt = select(ChildChunk).where(
-                                       ChildChunk.index_node_id == document.metadata["doc_id"],
-                                       ChildChunk.dataset_id == dataset_document.dataset_id,
-                                       ChildChunk.document_id == dataset_document.id,
-                                   )
-                                   child_chunk = session.scalar(child_chunk_stmt)
-                                   if child_chunk:
-                                       segment_id = child_chunk.segment_id
-                               elif (
-                                   "doc_type" in document.metadata
-                                   and document.metadata.get("doc_type") == DocType.IMAGE
-                               ):
-                                   attachment_info_dict = RetrievalService.get_segment_attachment_info(
-                                       dataset_document.dataset_id,
-                                       dataset_document.tenant_id,
-                                       document.metadata.get("doc_id") or "",
-                                       session,
-                                   )
-                                   if attachment_info_dict:
-                                       segment_id = attachment_info_dict["segment_id"]
+               # Collect all document_ids and batch fetch DatasetDocuments
+               document_ids = {
+                   doc.metadata["document_id"]
+                   for doc in dify_documents
+                   if doc.metadata and "document_id" in doc.metadata
+               }
+               if not document_ids:
+                   self._send_trace_task(message_id, documents, timer)
+                   return
+
+               dataset_docs_stmt = select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))
+               dataset_docs = session.scalars(dataset_docs_stmt).all()
+               dataset_doc_map = {str(doc.id): doc for doc in dataset_docs}
+
+               # Categorize documents by type and collect necessary IDs
+               parent_child_text_docs: list[tuple[Document, DatasetDocument]] = []
+               parent_child_image_docs: list[tuple[Document, DatasetDocument]] = []
+               normal_text_docs: list[tuple[Document, DatasetDocument]] = []
+               normal_image_docs: list[tuple[Document, DatasetDocument]] = []
+
+               for doc in dify_documents:
+                   if not doc.metadata or "document_id" not in doc.metadata:
+                       continue
+                   dataset_doc = dataset_doc_map.get(doc.metadata["document_id"])
+                   if not dataset_doc:
+                       continue
+
+                   is_image = doc.metadata.get("doc_type") == DocType.IMAGE
+                   is_parent_child = dataset_doc.doc_form == IndexStructureType.PARENT_CHILD_INDEX
+
+                   if is_parent_child:
+                       if is_image:
+                           parent_child_image_docs.append((doc, dataset_doc))
+                       else:
+                           parent_child_text_docs.append((doc, dataset_doc))
+                   else:
+                       if is_image:
+                           normal_image_docs.append((doc, dataset_doc))
+                       else:
+                           normal_text_docs.append((doc, dataset_doc))
+
+               segment_ids_to_update: set[str] = set()
+
+               # Process PARENT_CHILD_INDEX text documents - batch fetch ChildChunks
+               if parent_child_text_docs:
+                   index_node_ids = [doc.metadata["doc_id"] for doc, _ in parent_child_text_docs if doc.metadata]
+                   if index_node_ids:
+                       child_chunks_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(index_node_ids))
+                       child_chunks = session.scalars(child_chunks_stmt).all()
+                       child_chunk_map = {chunk.index_node_id: chunk.segment_id for chunk in child_chunks}
+                       for doc, _ in parent_child_text_docs:
+                           if doc.metadata:
+                               segment_id = child_chunk_map.get(doc.metadata["doc_id"])
+                               if segment_id:
-                               if segment_id not in segment_ids:
-                                   segment_ids.append(segment_id)
-                                   _ = (
-                                       session.query(DocumentSegment)
-                                       .where(DocumentSegment.id == segment_id)
-                                       .update(
-                                           {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
-                                           synchronize_session=False,
-                                       )
-                                   )
-                           else:
-                               query = None
-                               if (
-                                   "doc_type" not in document.metadata
-                                   or document.metadata.get("doc_type") == DocType.TEXT
-                               ):
-                                   if document.metadata["doc_id"] not in segment_index_node_ids:
-                                       segment = (
-                                           session.query(DocumentSegment)
-                                           .where(DocumentSegment.index_node_id == document.metadata["doc_id"])
-                                           .first()
-                                       )
-                                       if segment:
-                                           segment_index_node_ids.append(document.metadata["doc_id"])
-                                           segment_ids.append(segment.id)
-                                           query = session.query(DocumentSegment).where(
-                                               DocumentSegment.id == segment.id
-                                           )
-                               elif (
-                                   "doc_type" in document.metadata
-                                   and document.metadata.get("doc_type") == DocType.IMAGE
-                               ):
-                                   attachment_info_dict = RetrievalService.get_segment_attachment_info(
-                                       dataset_document.dataset_id,
-                                       dataset_document.tenant_id,
-                                       document.metadata.get("doc_id") or "",
-                                       session,
-                                   )
-                                   if attachment_info_dict:
-                                       segment_id = attachment_info_dict["segment_id"]
-                                       if segment_id not in segment_ids:
-                                           segment_ids.append(segment_id)
-                                           query = session.query(DocumentSegment).where(DocumentSegment.id == segment_id)
-                               if query:
-                                   # if 'dataset_id' in document.metadata:
-                                   if "dataset_id" in document.metadata:
-                                       query = query.where(
-                                           DocumentSegment.dataset_id == document.metadata["dataset_id"]
-                                       )
+                                   segment_ids_to_update.add(str(segment_id))

-                                   # add hit count to document segment
-                                   query.update(
-                                       {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
-                                       synchronize_session=False,
-                                   )
+               # Process non-PARENT_CHILD_INDEX text documents - batch fetch DocumentSegments
+               if normal_text_docs:
+                   index_node_ids = [doc.metadata["doc_id"] for doc, _ in normal_text_docs if doc.metadata]
+                   if index_node_ids:
+                       segments_stmt = select(DocumentSegment).where(DocumentSegment.index_node_id.in_(index_node_ids))
+                       segments = session.scalars(segments_stmt).all()
+                       segment_map = {seg.index_node_id: seg.id for seg in segments}
+                       for doc, _ in normal_text_docs:
+                           if doc.metadata:
+                               segment_id = segment_map.get(doc.metadata["doc_id"])
+                               if segment_id:
+                                   segment_ids_to_update.add(str(segment_id))

-               db.session.commit()
+               # Process IMAGE documents - batch fetch SegmentAttachmentBindings
+               all_image_docs = parent_child_image_docs + normal_image_docs
+               if all_image_docs:
+                   attachment_ids = [
+                       doc.metadata["doc_id"]
+                       for doc, _ in all_image_docs
+                       if doc.metadata and doc.metadata.get("doc_id")
+                   ]
+                   if attachment_ids:
+                       bindings_stmt = select(SegmentAttachmentBinding).where(
+                           SegmentAttachmentBinding.attachment_id.in_(attachment_ids)
+                       )
+                       bindings = session.scalars(bindings_stmt).all()
+                       segment_ids_to_update.update(str(binding.segment_id) for binding in bindings)

-           # get tracing instance
-           trace_manager: TraceQueueManager | None = (
-               self.application_generate_entity.trace_manager if self.application_generate_entity else None
-           )
-           if trace_manager:
-               trace_manager.add_trace_task(
-                   TraceTask(
-                       TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
-                   )
-               )
+               # Batch update hit_count for all segments
+               if segment_ids_to_update:
+                   session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids_to_update)).update(
+                       {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
+                       synchronize_session=False,
+                   )
+                   session.commit()
+
+           self._send_trace_task(message_id, documents, timer)
+
+   def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict | None):
+       """Send trace task if trace manager is available."""
+       trace_manager: TraceQueueManager | None = (
+           self.application_generate_entity.trace_manager if self.application_generate_entity else None
+       )
+       if trace_manager:
+           trace_manager.add_trace_task(
+               TraceTask(
+                   TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
+               )
+           )

    def _on_query(
        self,
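The end state of that refactor is one set-based `UPDATE` instead of a query per retrieved document. A minimal runnable sketch of the batched hit-count bump, assuming a SQLAlchemy model with an integer `hit_count` column:

```python
from sqlalchemy import Integer, String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Segment(Base):
    __tablename__ = "segments"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    hit_count: Mapped[int] = mapped_column(Integer, default=0)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Segment(id=f"seg-{i}") for i in range(3)])
    session.commit()

    # One UPDATE ... WHERE id IN (...) replaces N per-row round trips.
    segment_ids_to_update = {"seg-0", "seg-2"}
    session.query(Segment).where(Segment.id.in_(segment_ids_to_update)).update(
        {Segment.hit_count: Segment.hit_count + 1},
        synchronize_session=False,
    )
    session.commit()
    print([(s.id, s.hit_count) for s in session.query(Segment).order_by(Segment.id)])
```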
@@ -13,5 +13,5 @@ def remove_leading_symbols(text: str) -> str:
     """
     # Match Unicode ranges for punctuation and symbols
     # FIXME this pattern is confused quick fix for #11868 maybe refactor it later
-    pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F\"#$%&'()*+,./:;<=>?@^_`~]+"
+    pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'
     return re.sub(pattern, "", text)
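The widened pattern is easiest to verify by example: it now also strips leading square brackets, while carving U+2026 (…) and U+3010 (【) out of the ranges the old pattern swallowed:

```python
import re

pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'

print(re.sub(pattern, "", '"#,answer'))               # answer
print(re.sub(pattern, "", "[]answer"))                # answer (brackets now stripped)
print(re.sub(pattern, "", "\u2026answer"))            # …answer (ellipsis preserved)
print(re.sub(pattern, "", "\u3010note\u3011answer"))  # 【note】answer (【 preserved)
```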
@@ -221,7 +221,7 @@ class WorkflowToolProviderController(ToolProviderController):
             session.query(WorkflowToolProvider)
             .where(
                 WorkflowToolProvider.tenant_id == tenant_id,
-                WorkflowToolProvider.app_id == self.provider_id,
+                WorkflowToolProvider.id == self.provider_id,
             )
             .first()
         )
@@ -59,7 +59,7 @@ class OutputVariableEntity(BaseModel):
     """

     variable: str
-    value_type: OutputVariableType
+    value_type: OutputVariableType = OutputVariableType.ANY
     value_selector: Sequence[str]

     @field_validator("value_type", mode="before")
@@ -412,16 +412,20 @@ class Executor:
body_string += f"--{boundary}\r\n"
body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
# decode content safely
try:
body_string += content.decode("utf-8")
except UnicodeDecodeError:
body_string += content.decode("utf-8", errors="replace")
body_string += "\r\n"
# Do not decode binary content; use a placeholder with file metadata instead.
# Includes filename, size, and MIME type for better logging context.
body_string += (
f"<file_content_binary: '{file_entry[1][0] or 'unknown'}', "
f"type='{file_entry[1][2] if len(file_entry[1]) > 2 else 'unknown'}', "
f"size={len(content)} bytes>\r\n"
)
body_string += f"--{boundary}--\r\n"
elif self.node_data.body:
if self.content:
# If content is bytes, do not decode it; show a placeholder with size.
# Provides content size information for binary data without exposing the raw bytes.
if isinstance(self.content, bytes):
body_string = self.content.decode("utf-8", errors="replace")
body_string = f"<binary_content: size={len(self.content)} bytes>"
else:
body_string = self.content
elif self.data and self.node_data.body.type == "x-www-form-urlencoded":

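The change above stops decoding raw bytes into the logged request body and emits a metadata placeholder instead. A self-contained sketch of that idea, with hypothetical names (`describe_part` is not a Dify function):

```python
# Never decode raw bytes for logging; describe them instead.
def describe_part(filename: str | None, mimetype: str | None, content: bytes) -> str:
    if not isinstance(content, bytes):
        return str(content)
    return (
        f"<file_content_binary: '{filename or 'unknown'}', "
        f"type='{mimetype or 'unknown'}', size={len(content)} bytes>"
    )

print(describe_part("logo.png", "image/png", b"\x89PNG\r\n"))
# <file_content_binary: 'logo.png', type='image/png', size=6 bytes>
```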
@@ -334,6 +334,7 @@ class LLMNode(Node[LLMNodeData]):
inputs=node_inputs,
process_data=process_data,
error_type=type(e).__name__,
llm_usage=usage,
)
)
except Exception as e:
@@ -344,6 +345,8 @@ class LLMNode(Node[LLMNodeData]):
error=str(e),
inputs=node_inputs,
process_data=process_data,
error_type=type(e).__name__,
llm_usage=usage,
)
)

@@ -694,7 +697,7 @@ class LLMNode(Node[LLMNodeData]):
).all()
if attachments_with_bindings:
for _, upload_file in attachments_with_bindings:
attchment_info = File(
attachment_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
@@ -708,7 +711,7 @@ class LLMNode(Node[LLMNodeData]):
storage_key=upload_file.key,
url=sign_upload_file(upload_file.id, upload_file.extension),
)
context_files.append(attchment_info)
context_files.append(attachment_info)
yield RunRetrieverResourceEvent(
retriever_resources=original_retriever_resource,
context=context_str.strip(),

@@ -221,6 +221,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,

@@ -107,7 +107,7 @@ def email(email):
EmailStr = Annotated[str, AfterValidator(email)]


def uuid_value(value):
def uuid_value(value: Any) -> str:
if value == "":
return str(value)

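For context, a plausible shape of the fully annotated validator; the body below is a hypothetical reconstruction (only the signature and the empty-string early return are confirmed by the hunk):

```python
import uuid
from typing import Any

def uuid_value(value: Any) -> str:
    if value == "":
        return str(value)
    try:
        return str(uuid.UUID(str(value)))  # normalises and validates
    except ValueError as e:
        raise ValueError(f"{value} is not a valid uuid") from e

print(uuid_value("1cc7fa12-3f7b-4f6a-9c8d-1234567890ab"))
```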
@@ -215,7 +215,11 @@ def generate_text_hash(text: str) -> str:

def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(jsonable_encoder(response)), status=200, mimetype="application/json")
return Response(
response=json.dumps(jsonable_encoder(response)),
status=200,
content_type="application/json; charset=utf-8",
)
else:

def generate() -> Generator:

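The fix above swaps `mimetype` for a `content_type` carrying an explicit charset, so non-ASCII JSON renders correctly in clients that default to another encoding. A minimal Flask sketch of the same response shape (route and payload are illustrative):

```python
import json
from flask import Flask, Response

app = Flask(__name__)

@app.get("/demo")
def demo() -> Response:
    payload = {"greeting": "こんにちは"}  # non-ASCII survives with explicit charset
    return Response(
        response=json.dumps(payload),
        status=200,
        content_type="application/json; charset=utf-8",
    )
```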
@@ -1,4 +1,4 @@
"""empty message
"""mysql adaptation

Revision ID: 09cfdda155d1
Revises: 669ffd70119c
@@ -97,11 +97,31 @@ def downgrade():
batch_op.alter_column('include_plugins',
existing_type=sa.JSON(),
type_=postgresql.ARRAY(sa.VARCHAR(length=255)),
existing_nullable=False)
existing_nullable=False,
postgresql_using="""
COALESCE(
regexp_replace(
replace(replace(include_plugins::text, '[', '{'), ']', '}'),
'"',
'',
'g'
)::varchar(255)[],
ARRAY[]::varchar(255)[]
)""")
batch_op.alter_column('exclude_plugins',
existing_type=sa.JSON(),
type_=postgresql.ARRAY(sa.VARCHAR(length=255)),
existing_nullable=False)
existing_nullable=False,
postgresql_using="""
COALESCE(
regexp_replace(
replace(replace(exclude_plugins::text, '[', '{'), ']', '}'),
'"',
'',
'g'
)::varchar(255)[],
ARRAY[]::varchar(255)[]
)""")

with op.batch_alter_table('external_knowledge_bindings', schema=None) as batch_op:
batch_op.alter_column('external_knowledge_id',

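The `postgresql_using` expression above rewrites a JSON array literal into a Postgres array literal before the cast, with `COALESCE` guarding NULLs. Traced in plain Python for clarity (this mirrors the `replace`/`regexp_replace` chain, it is not the migration itself):

```python
# What the USING clause computes, step by step.
def json_list_to_pg_array(text: str) -> str:
    text = text.replace("[", "{").replace("]", "}")  # JSON brackets -> PG braces
    return text.replace('"', "")                      # strip JSON double quotes

print(json_list_to_pg_array('["plugin-a", "plugin-b"]'))  # {plugin-a, plugin-b}
```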
@@ -78,7 +78,7 @@ class Dataset(Base):
pipeline_id = mapped_column(StringUUID, nullable=True)
chunk_structure = mapped_column(sa.String(255), nullable=True)
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
is_multimodal = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
is_multimodal = mapped_column(sa.Boolean, default=False, nullable=False, server_default=db.text("false"))

@property
def total_documents(self):

@@ -111,7 +111,11 @@ class App(Base):
else:
app_model_config = self.app_model_config
if app_model_config:
return app_model_config.pre_prompt
pre_prompt = app_model_config.pre_prompt or ""
# Truncate to 200 characters with ellipsis if using prompt as description
if len(pre_prompt) > 200:
return pre_prompt[:200] + "..."
return pre_prompt
else:
return ""

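The truncation rule above, isolated as a small helper for illustration (the function name is hypothetical):

```python
def prompt_as_description(pre_prompt: str | None) -> str:
    # Descriptions derived from prompts are capped at 200 chars plus "...".
    pre_prompt = pre_prompt or ""
    if len(pre_prompt) > 200:
        return pre_prompt[:200] + "..."
    return pre_prompt

assert len(prompt_as_description("x" * 500)) == 203
assert prompt_as_description(None) == ""
```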
@@ -259,7 +263,7 @@ class App(Base):
provider_id = tool.get("provider_id", "")

if provider_type == ToolProviderType.API:
if uuid.UUID(provider_id) not in existing_api_providers:
if provider_id not in existing_api_providers:
deleted_tools.append(
{
"type": ToolProviderType.API,
@@ -835,7 +839,29 @@ class Conversation(Base):

@property
def status_count(self):
messages = db.session.scalars(select(Message).where(Message.conversation_id == self.id)).all()
from models.workflow import WorkflowRun

# Get all messages with workflow_run_id for this conversation
messages = db.session.scalars(
select(Message).where(Message.conversation_id == self.id, Message.workflow_run_id.isnot(None))
).all()

if not messages:
return None

# Batch load all workflow runs in a single query, filtered by this conversation's app_id
workflow_run_ids = [msg.workflow_run_id for msg in messages if msg.workflow_run_id]
workflow_runs = {}

if workflow_run_ids:
workflow_runs_query = db.session.scalars(
select(WorkflowRun).where(
WorkflowRun.id.in_(workflow_run_ids),
WorkflowRun.app_id == self.app_id,  # Filter by this conversation's app_id
)
).all()
workflow_runs = {run.id: run for run in workflow_runs_query}

status_counts = {
WorkflowExecutionStatus.RUNNING: 0,
WorkflowExecutionStatus.SUCCEEDED: 0,
@@ -845,18 +871,24 @@ class Conversation(Base):
}

for message in messages:
if message.workflow_run:
status_counts[WorkflowExecutionStatus(message.workflow_run.status)] += 1
# Guard against None to satisfy type checker and avoid invalid dict lookups
if message.workflow_run_id is None:
continue
workflow_run = workflow_runs.get(message.workflow_run_id)
if not workflow_run:
continue

return (
{
"success": status_counts[WorkflowExecutionStatus.SUCCEEDED],
"failed": status_counts[WorkflowExecutionStatus.FAILED],
"partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED],
}
if messages
else None
)
try:
status_counts[WorkflowExecutionStatus(workflow_run.status)] += 1
except (ValueError, KeyError):
# Handle invalid status values gracefully
pass

return {
"success": status_counts[WorkflowExecutionStatus.SUCCEEDED],
"failed": status_counts[WorkflowExecutionStatus.FAILED],
"partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED],
}

@property
def first_message(self):
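The rewrite above removes an N+1: instead of touching `message.workflow_run` lazily per row, it collects the foreign keys, fetches the runs with one `IN` query, and resolves through a dict. The shape of that pattern, with generic stand-in names:

```python
from collections import Counter

def count_statuses(messages, fetch_runs_by_ids):
    """fetch_runs_by_ids issues a single SELECT ... WHERE id IN (:ids)."""
    run_ids = [m.workflow_run_id for m in messages if m.workflow_run_id]
    runs = {r.id: r for r in fetch_runs_by_ids(run_ids)}  # one query, dict lookup after
    counts: Counter[str] = Counter()
    for m in messages:
        run = runs.get(m.workflow_run_id)
        if run is not None:
            counts[run.status] += 1
    return counts
```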
@@ -1255,13 +1287,9 @@ class Message(Base):
"id": self.id,
"app_id": self.app_id,
"conversation_id": self.conversation_id,
"model_provider": self.model_provider,
"model_id": self.model_id,
"inputs": self.inputs,
"query": self.query,
"message_tokens": self.message_tokens,
"answer_tokens": self.answer_tokens,
"provider_response_latency": self.provider_response_latency,
"total_price": self.total_price,
"message": self.message,
"answer": self.answer,
@@ -1283,12 +1311,8 @@ class Message(Base):
id=data["id"],
app_id=data["app_id"],
conversation_id=data["conversation_id"],
model_provider=data.get("model_provider"),
model_id=data["model_id"],
inputs=data["inputs"],
message_tokens=data.get("message_tokens", 0),
answer_tokens=data.get("answer_tokens", 0),
provider_response_latency=data.get("provider_response_latency", 0.0),
total_price=data["total_price"],
query=data["query"],
message=data["message"],

@@ -907,19 +907,29 @@ class WorkflowNodeExecutionModel(Base):  # This model is expected to have `offlo
@property
def extras(self) -> dict[str, Any]:
from core.tools.tool_manager import ToolManager
from core.trigger.trigger_manager import TriggerManager

extras: dict[str, Any] = {}
if self.execution_metadata_dict:
if self.node_type == NodeType.TOOL and "tool_info" in self.execution_metadata_dict:
tool_info: dict[str, Any] = self.execution_metadata_dict["tool_info"]
execution_metadata = self.execution_metadata_dict
if execution_metadata:
if self.node_type == NodeType.TOOL and "tool_info" in execution_metadata:
tool_info: dict[str, Any] = execution_metadata["tool_info"]
extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self.tenant_id,
provider_type=tool_info["provider_type"],
provider_id=tool_info["provider_id"],
)
elif self.node_type == NodeType.DATASOURCE and "datasource_info" in self.execution_metadata_dict:
datasource_info = self.execution_metadata_dict["datasource_info"]
elif self.node_type == NodeType.DATASOURCE and "datasource_info" in execution_metadata:
datasource_info = execution_metadata["datasource_info"]
extras["icon"] = datasource_info.get("icon")
elif self.node_type == NodeType.TRIGGER_PLUGIN and "trigger_info" in execution_metadata:
trigger_info = execution_metadata["trigger_info"] or {}
provider_id = trigger_info.get("provider_id")
if provider_id:
extras["icon"] = TriggerManager.get_trigger_plugin_icon(
tenant_id=self.tenant_id,
provider_id=provider_id,
)
return extras

def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:

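The `extras` property above dispatches on node type to pick an icon resolver from the execution metadata, now including the new trigger-plugin branch. A generic sketch of that dispatch shape (the node-type strings and resolver outputs here are stand-ins, not Dify's real values):

```python
def resolve_icon(node_type: str, metadata: dict) -> str | None:
    # Pick an icon source per node type; missing keys fall through to None.
    if node_type == "tool" and "tool_info" in metadata:
        return f"tool:{metadata['tool_info']['provider_id']}"
    if node_type == "datasource" and "datasource_info" in metadata:
        return metadata["datasource_info"].get("icon")
    if node_type == "trigger-plugin" and "trigger_info" in metadata:
        provider_id = (metadata["trigger_info"] or {}).get("provider_id")
        return f"trigger:{provider_id}" if provider_id else None
    return None
```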
@@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.10.1"
version = "1.11.1"
requires-python = ">=3.11,<3.13"

dependencies = [
@@ -151,7 +151,7 @@ dev = [
"types-pywin32~=310.0.0",
"types-pyyaml~=6.0.12",
"types-regex~=2024.11.6",
"types-shapely~=2.0.0",
"types-shapely~=2.1.0",
"types-simplejson>=3.20.0",
"types-six>=1.17.0",
"types-tensorflow>=2.18.0",

@@ -118,7 +118,7 @@ class ConversationService:
app_model: App,
conversation_id: str,
user: Union[Account, EndUser] | None,
name: str,
name: str | None,
auto_generate: bool,
):
conversation = cls.get_conversation(app_model, conversation_id, user)

@@ -673,6 +673,8 @@ class DatasetService:
Returns:
str: Action to perform ('add', 'remove', 'update', or None)
"""
if "indexing_technique" not in data:
return None
if dataset.indexing_technique != data["indexing_technique"]:
if data["indexing_technique"] == "economy":
# Remove embedding model configuration for economy mode
@@ -1634,6 +1636,20 @@ class DocumentService:
return [], ""
db.session.add(dataset_process_rule)
db.session.flush()
else:
# Fallback when no process_rule provided in knowledge_config:
# 1) reuse dataset.latest_process_rule if present
# 2) otherwise create an automatic rule
dataset_process_rule = getattr(dataset, "latest_process_rule", None)
if not dataset_process_rule:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode="automatic",
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
created_by=account.id,
)
db.session.add(dataset_process_rule)
db.session.flush()
lock_name = f"add_document_lock_dataset_id_{dataset.id}"
try:
with redis_client.lock(lock_name, timeout=600):
@@ -1645,65 +1661,67 @@ class DocumentService:
if not knowledge_config.data_source.info_list.file_info_list:
raise ValueError("File source info is required")
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
files = (
db.session.query(UploadFile)
.where(
UploadFile.tenant_id == dataset.tenant_id,
UploadFile.id.in_(upload_file_list),
)
.all()
)
if len(files) != len(set(upload_file_list)):
raise FileNotExistsError("One or more files not found.")

# raise error if file not found
if not file:
raise FileNotExistsError()

file_name = file.name
file_names = [file.name for file in files]
db_documents = (
db.session.query(Document)
.where(
Document.dataset_id == dataset.id,
Document.tenant_id == current_user.current_tenant_id,
Document.data_source_type == "upload_file",
Document.enabled == True,
Document.name.in_(file_names),
)
.all()
)
documents_map = {document.name: document for document in db_documents}
for file in files:
data_source_info: dict[str, str | bool] = {
"upload_file_id": file_id,
"upload_file_id": file.id,
}
# check duplicate
if knowledge_config.duplicate:
document = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="upload_file",
enabled=True,
name=file_name,
)
.first()
document = documents_map.get(file.name)
if knowledge_config.duplicate and document:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
else:
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
file.name,
batch,
)
if document:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = naive_utc_now()
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
knowledge_config.data_source.info_list.data_source_type,
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
file_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list  # type: ignore
if not notion_info_list:

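The hunk above also trades a per-id `first()` loop for one `IN` query, then detects missing files by comparing counts. That validation trick in isolation, with generic names (`fetch_all_or_raise` is illustrative, not a Dify helper):

```python
def fetch_all_or_raise(run_in_query, ids: list[str]):
    """run_in_query issues a single SELECT ... WHERE id IN (:ids)."""
    rows = run_in_query(ids)
    # set() tolerates duplicate requested ids; a shortfall means something is missing.
    if len(rows) != len(set(ids)):
        raise FileNotFoundError("One or more files not found.")
    return rows
```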
@@ -178,8 +178,8 @@ class HitTestingService:

@classmethod
def hit_testing_args_check(cls, args):
query = args["query"]
attachment_ids = args["attachment_ids"]
query = args.get("query")
attachment_ids = args.get("attachment_ids")

if not attachment_ids and not query:
raise ValueError("Query or attachment_ids is required")

@@ -70,9 +70,28 @@ class ModelProviderService:
continue

provider_config = provider_configuration.custom_configuration.provider
model_config = provider_configuration.custom_configuration.models
models = provider_configuration.custom_configuration.models
can_added_models = provider_configuration.custom_configuration.can_added_models

# IMPORTANT: Never expose decrypted credentials in the provider list API.
# Sanitize custom model configurations by dropping the credentials payload.
sanitized_model_config = []
if models:
from core.entities.provider_entities import CustomModelConfiguration  # local import to avoid cycles

for model in models:
sanitized_model_config.append(
CustomModelConfiguration(
model=model.model,
model_type=model.model_type,
credentials=None,  # strip secrets from list view
current_credential_id=model.current_credential_id,
current_credential_name=model.current_credential_name,
available_model_credentials=model.available_model_credentials,
unadded_to_model_list=model.unadded_to_model_list,
)
)

provider_response = ProviderResponse(
tenant_id=tenant_id,
provider=provider_configuration.provider.provider,
@@ -95,7 +114,7 @@ class ModelProviderService:
current_credential_id=getattr(provider_config, "current_credential_id", None),
current_credential_name=getattr(provider_config, "current_credential_name", None),
available_credentials=getattr(provider_config, "available_credentials", []),
custom_models=model_config,
custom_models=sanitized_model_config,
can_added_models=can_added_models,
),
system_configuration=SystemConfigurationResponse(

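The sanitization above rebuilds each model config with the credentials payload forced to `None` before it reaches the list API. A generic version of that pattern using a plain dataclass (the types here are illustrative, not Dify's entities):

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class ModelConfig:  # stand-in for a credential-bearing DTO
    model: str
    credentials: dict | None

def sanitize(configs: list[ModelConfig]) -> list[ModelConfig]:
    # Rebuild each record with the secret field stripped; never mutate in place.
    return [replace(c, credentials=None) for c in configs]

assert sanitize([ModelConfig("gpt-x", {"key": "secret"})])[0].credentials is None
```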
api/tests/fixtures/workflow/end_node_without_value_type_field_workflow.yml (127 lines, vendored, new file)
@@ -0,0 +1,127 @@
app:
description: 'End node without value_type field reproduction'
icon: 🤖
icon_background: '#FFEAD5'
mode: workflow
name: end_node_without_value_type_field_reproduction
use_icon_as_answer_icon: false
dependencies: []
kind: app
version: 0.5.0
workflow:
conversation_variables: []
environment_variables: []
features:
file_upload:
allowed_file_extensions:
- .JPG
- .JPEG
- .PNG
- .GIF
- .WEBP
- .SVG
allowed_file_types:
- image
allowed_file_upload_methods:
- local_file
- remote_url
enabled: false
fileUploadConfig:
audio_file_size_limit: 50
batch_count_limit: 5
file_size_limit: 15
image_file_batch_limit: 10
image_file_size_limit: 10
single_chunk_attachment_limit: 10
video_file_size_limit: 100
workflow_file_upload_limit: 10
image:
enabled: false
number_limits: 3
transfer_methods:
- local_file
- remote_url
number_limits: 3
opening_statement: ''
retriever_resource:
enabled: true
sensitive_word_avoidance:
enabled: false
speech_to_text:
enabled: false
suggested_questions: []
suggested_questions_after_answer:
enabled: false
text_to_speech:
enabled: false
language: ''
voice: ''
graph:
edges:
- data:
isInIteration: false
isInLoop: false
sourceType: start
targetType: end
id: 1765423445456-source-1765423454810-target
source: '1765423445456'
sourceHandle: source
target: '1765423454810'
targetHandle: target
type: custom
zIndex: 0
nodes:
- data:
selected: false
title: User Input
type: start
variables:
- default: ''
hint: ''
label: query
max_length: 48
options: []
placeholder: ''
required: true
type: text-input
variable: query
height: 109
id: '1765423445456'
position:
x: -48
y: 261
positionAbsolute:
x: -48
y: 261
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
outputs:
- value_selector:
- '1765423445456'
- query
variable: query
selected: true
title: Output
type: end
height: 88
id: '1765423454810'
position:
x: 382
y: 282
positionAbsolute:
x: 382
y: 282
selected: true
sourcePosition: right
targetPosition: left
type: custom
width: 242
viewport:
x: 139
y: -135
zoom: 1
rag_pipeline_variables: []
@@ -0,0 +1 @@

api/tests/test_containers_integration_tests/trigger/conftest.py (182 lines, new file)
@@ -0,0 +1,182 @@
"""
|
||||
Fixtures for trigger integration tests.
|
||||
|
||||
This module provides fixtures for creating test data (tenant, account, app)
|
||||
and mock objects used across trigger-related tests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.model import App
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_and_account(db_session_with_containers: Session) -> Generator[tuple[Tenant, Account], None, None]:
|
||||
"""
|
||||
Create a tenant and account for testing.
|
||||
|
||||
This fixture creates a tenant, account, and their association,
|
||||
then cleans up after the test completes.
|
||||
|
||||
Yields:
|
||||
tuple[Tenant, Account]: The created tenant and account
|
||||
"""
|
||||
tenant = Tenant(name="trigger-e2e")
|
||||
account = Account(name="tester", email="tester@example.com", interface_language="en-US")
|
||||
db_session_with_containers.add_all([tenant, account])
|
||||
db_session_with_containers.commit()
|
||||
|
||||
join = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=TenantAccountRole.OWNER.value)
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
yield tenant, account
|
||||
|
||||
# Cleanup
|
||||
db_session_with_containers.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
|
||||
db_session_with_containers.query(Account).filter_by(id=account.id).delete()
|
||||
db_session_with_containers.query(Tenant).filter_by(id=tenant.id).delete()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_model(
|
||||
db_session_with_containers: Session, tenant_and_account: tuple[Tenant, Account]
|
||||
) -> Generator[App, None, None]:
|
||||
"""
|
||||
Create an app for testing.
|
||||
|
||||
This fixture creates a workflow app associated with the tenant and account,
|
||||
then cleans up after the test completes.
|
||||
|
||||
Yields:
|
||||
App: The created app
|
||||
"""
|
||||
tenant, account = tenant_and_account
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name="trigger-app",
|
||||
description="trigger e2e",
|
||||
mode="workflow",
|
||||
icon_type="emoji",
|
||||
icon="robot",
|
||||
icon_background="#FFEAD5",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
api_rpm=100,
|
||||
api_rph=1000,
|
||||
is_demo=False,
|
||||
is_public=False,
|
||||
is_universal=False,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(app)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
yield app
|
||||
|
||||
# Cleanup - delete related records first
|
||||
from models.trigger import (
|
||||
AppTrigger,
|
||||
TriggerSubscription,
|
||||
WorkflowPluginTrigger,
|
||||
WorkflowSchedulePlan,
|
||||
WorkflowTriggerLog,
|
||||
WorkflowWebhookTrigger,
|
||||
)
|
||||
from models.workflow import Workflow
|
||||
|
||||
db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app.id).delete()
|
||||
db_session_with_containers.query(WorkflowSchedulePlan).filter_by(app_id=app.id).delete()
|
||||
db_session_with_containers.query(WorkflowWebhookTrigger).filter_by(app_id=app.id).delete()
|
||||
db_session_with_containers.query(WorkflowPluginTrigger).filter_by(app_id=app.id).delete()
|
||||
db_session_with_containers.query(AppTrigger).filter_by(app_id=app.id).delete()
|
||||
db_session_with_containers.query(TriggerSubscription).filter_by(tenant_id=tenant.id).delete()
|
||||
db_session_with_containers.query(Workflow).filter_by(app_id=app.id).delete()
|
||||
db_session_with_containers.query(App).filter_by(id=app.id).delete()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
|
||||
class MockCeleryGroup:
|
||||
"""Mock for celery group() function that collects dispatched tasks."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.collected: list[dict[str, Any]] = []
|
||||
self._applied = False
|
||||
|
||||
def __call__(self, items: Any) -> MockCeleryGroup:
|
||||
self.collected = list(items)
|
||||
return self
|
||||
|
||||
def apply_async(self) -> None:
|
||||
self._applied = True
|
||||
|
||||
@property
|
||||
def applied(self) -> bool:
|
||||
return self._applied
|
||||
|
||||
|
||||
class MockCelerySignature:
|
||||
"""Mock for celery task signature that returns task info dict."""
|
||||
|
||||
def s(self, schedule_id: str) -> dict[str, str]:
|
||||
return {"schedule_id": schedule_id}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_celery_group() -> MockCeleryGroup:
|
||||
"""
|
||||
Provide a mock celery group for testing task dispatch.
|
||||
|
||||
Returns:
|
||||
MockCeleryGroup: Mock group that collects dispatched tasks
|
||||
"""
|
||||
return MockCeleryGroup()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_celery_signature() -> MockCelerySignature:
|
||||
"""
|
||||
Provide a mock celery signature for testing task dispatch.
|
||||
|
||||
Returns:
|
||||
MockCelerySignature: Mock signature generator
|
||||
"""
|
||||
return MockCelerySignature()
|
||||
|
||||
|
||||
class MockPluginSubscription:
|
||||
"""Mock plugin subscription for testing plugin triggers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
subscription_id: str = "sub-1",
|
||||
tenant_id: str = "tenant-1",
|
||||
provider_id: str = "provider-1",
|
||||
) -> None:
|
||||
self.id = subscription_id
|
||||
self.tenant_id = tenant_id
|
||||
self.provider_id = provider_id
|
||||
self.credentials: dict[str, str] = {"token": "secret"}
|
||||
self.credential_type = "api-key"
|
||||
|
||||
def to_entity(self) -> MockPluginSubscription:
|
||||
return self
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_plugin_subscription() -> MockPluginSubscription:
|
||||
"""
|
||||
Provide a mock plugin subscription for testing.
|
||||
|
||||
Returns:
|
||||
MockPluginSubscription: Mock subscription instance
|
||||
"""
|
||||
return MockPluginSubscription()
|
||||
@ -0,0 +1,911 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from flask import Flask, Response
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.plugin.entities.request import TriggerInvokeEventResponse
|
||||
from core.trigger.debug import event_selectors
|
||||
from core.trigger.debug.event_bus import TriggerDebugEventBus
|
||||
from core.trigger.debug.event_selectors import PluginTriggerDebugEventPoller, WebhookTriggerDebugEventPoller
|
||||
from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key
|
||||
from core.workflow.enums import NodeType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.account import Account, Tenant
|
||||
from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus
|
||||
from models.model import App
|
||||
from models.trigger import (
|
||||
AppTrigger,
|
||||
TriggerSubscription,
|
||||
WorkflowPluginTrigger,
|
||||
WorkflowSchedulePlan,
|
||||
WorkflowTriggerLog,
|
||||
WorkflowWebhookTrigger,
|
||||
)
|
||||
from models.workflow import Workflow
|
||||
from schedule import workflow_schedule_task
|
||||
from schedule.workflow_schedule_task import poll_workflow_schedules
|
||||
from services import feature_service as feature_service_module
|
||||
from services.trigger import webhook_service
|
||||
from services.trigger.schedule_service import ScheduleService
|
||||
from services.workflow_service import WorkflowService
|
||||
from tasks import trigger_processing_tasks
|
||||
|
||||
from .conftest import MockCeleryGroup, MockCelerySignature, MockPluginSubscription
|
||||
|
||||
# Test constants
|
||||
WEBHOOK_ID_PRODUCTION = "wh1234567890123456789012"
|
||||
WEBHOOK_ID_DEBUG = "whdebug1234567890123456"
|
||||
TEST_TRIGGER_URL = "https://trigger.example.com/base"
|
||||
|
||||
|
||||
def _build_workflow_graph(root_node_id: str, trigger_type: NodeType) -> str:
|
||||
"""Build a minimal workflow graph JSON for testing."""
|
||||
node_data: dict[str, Any] = {"type": trigger_type.value, "title": "trigger"}
|
||||
if trigger_type == NodeType.TRIGGER_WEBHOOK:
|
||||
node_data.update(
|
||||
{
|
||||
"method": "POST",
|
||||
"content_type": "application/json",
|
||||
"headers": [],
|
||||
"params": [],
|
||||
"body": [],
|
||||
}
|
||||
)
|
||||
graph = {
|
||||
"nodes": [
|
||||
{"id": root_node_id, "data": node_data},
|
||||
{"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}},
|
||||
],
|
||||
"edges": [{"source": root_node_id, "target": "answer-1", "sourceHandle": "success"}],
|
||||
}
|
||||
return json.dumps(graph)
|
||||
|
||||
|
||||
def test_publish_blocks_start_and_trigger_coexistence(
|
||||
db_session_with_containers: Session,
|
||||
tenant_and_account: tuple[Tenant, Account],
|
||||
app_model: App,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Publishing should fail when both start and trigger nodes coexist."""
|
||||
tenant, account = tenant_and_account
|
||||
|
||||
graph = {
|
||||
"nodes": [
|
||||
{"id": "start", "data": {"type": NodeType.START.value}},
|
||||
{"id": "trig", "data": {"type": NodeType.TRIGGER_WEBHOOK.value}},
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
draft_workflow = Workflow.new(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app_model.id,
|
||||
type="workflow",
|
||||
version=Workflow.VERSION_DRAFT,
|
||||
graph=json.dumps(graph),
|
||||
features=json.dumps({}),
|
||||
created_by=account.id,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
rag_pipeline_variables=[],
|
||||
)
|
||||
db_session_with_containers.add(draft_workflow)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
monkeypatch.setattr(
|
||||
feature_service_module.FeatureService,
|
||||
"get_system_features",
|
||||
classmethod(lambda _cls: SimpleNamespace(plugin_manager=SimpleNamespace(enabled=False))),
|
||||
)
|
||||
monkeypatch.setattr("services.workflow_service.dify_config", SimpleNamespace(BILLING_ENABLED=False))
|
||||
|
||||
with pytest.raises(ValueError, match="Start node and trigger nodes cannot coexist"):
|
||||
workflow_service.publish_workflow(session=db_session_with_containers, app_model=app_model, account=account)
|
||||
|
||||
|
||||
def test_trigger_url_uses_config_base(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""TRIGGER_URL config should be reflected in generated webhook and plugin endpoints."""
|
||||
original_url = getattr(dify_config, "TRIGGER_URL", None)
|
||||
|
||||
try:
|
||||
monkeypatch.setattr(dify_config, "TRIGGER_URL", TEST_TRIGGER_URL)
|
||||
endpoint_module = importlib.reload(importlib.import_module("core.trigger.utils.endpoint"))
|
||||
|
||||
assert (
|
||||
endpoint_module.generate_webhook_trigger_endpoint(WEBHOOK_ID_PRODUCTION)
|
||||
== f"{TEST_TRIGGER_URL}/triggers/webhook/{WEBHOOK_ID_PRODUCTION}"
|
||||
)
|
||||
assert (
|
||||
endpoint_module.generate_webhook_trigger_endpoint(WEBHOOK_ID_PRODUCTION, True)
|
||||
== f"{TEST_TRIGGER_URL}/triggers/webhook-debug/{WEBHOOK_ID_PRODUCTION}"
|
||||
)
|
||||
assert (
|
||||
endpoint_module.generate_plugin_trigger_endpoint_url("end-1") == f"{TEST_TRIGGER_URL}/triggers/plugin/end-1"
|
||||
)
|
||||
finally:
|
||||
# Restore original config and reload module
|
||||
if original_url is not None:
|
||||
monkeypatch.setattr(dify_config, "TRIGGER_URL", original_url)
|
||||
importlib.reload(importlib.import_module("core.trigger.utils.endpoint"))
|
||||
|
||||
|
||||
def test_webhook_trigger_creates_trigger_log(
|
||||
test_client_with_containers: FlaskClient,
|
||||
db_session_with_containers: Session,
|
||||
tenant_and_account: tuple[Tenant, Account],
|
||||
app_model: App,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Production webhook trigger should create a trigger log in the database."""
|
||||
tenant, account = tenant_and_account
|
||||
|
||||
webhook_node_id = "webhook-node"
|
||||
graph_json = _build_workflow_graph(webhook_node_id, NodeType.TRIGGER_WEBHOOK)
|
||||
published_workflow = Workflow.new(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app_model.id,
|
||||
type="workflow",
|
||||
version=Workflow.version_from_datetime(naive_utc_now()),
|
||||
graph=graph_json,
|
||||
features=json.dumps({}),
|
||||
created_by=account.id,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
rag_pipeline_variables=[],
|
||||
)
|
||||
db_session_with_containers.add(published_workflow)
|
||||
app_model.workflow_id = published_workflow.id
|
||||
db_session_with_containers.commit()
|
||||
|
||||
webhook_trigger = WorkflowWebhookTrigger(
|
||||
app_id=app_model.id,
|
||||
node_id=webhook_node_id,
|
||||
tenant_id=tenant.id,
|
||||
webhook_id=WEBHOOK_ID_PRODUCTION,
|
||||
created_by=account.id,
|
||||
)
|
||||
app_trigger = AppTrigger(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app_model.id,
|
||||
node_id=webhook_node_id,
|
||||
trigger_type=AppTriggerType.TRIGGER_WEBHOOK,
|
||||
status=AppTriggerStatus.ENABLED,
|
||||
title="webhook",
|
||||
)
|
||||
|
||||
db_session_with_containers.add_all([webhook_trigger, app_trigger])
|
||||
db_session_with_containers.commit()
|
||||
|
||||
def _fake_trigger_workflow_async(session: Session, user: Any, trigger_data: Any) -> SimpleNamespace:
|
||||
log = WorkflowTriggerLog(
|
||||
tenant_id=trigger_data.tenant_id,
|
||||
app_id=trigger_data.app_id,
|
||||
workflow_id=trigger_data.workflow_id,
|
||||
root_node_id=trigger_data.root_node_id,
|
||||
trigger_metadata=trigger_data.trigger_metadata.model_dump_json() if trigger_data.trigger_metadata else "{}",
|
||||
trigger_type=trigger_data.trigger_type,
|
||||
workflow_run_id=None,
|
||||
outputs=None,
|
||||
trigger_data=trigger_data.model_dump_json(),
|
||||
inputs=json.dumps(dict(trigger_data.inputs)),
|
||||
status=WorkflowTriggerStatus.SUCCEEDED,
|
||||
error="",
|
||||
queue_name="triggered_workflow_dispatcher",
|
||||
celery_task_id="celery-test",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
session.add(log)
|
||||
session.commit()
|
||||
return SimpleNamespace(workflow_trigger_log_id=log.id, task_id=None, status="queued", queue="test")
|
||||
|
||||
monkeypatch.setattr(
|
||||
webhook_service.AsyncWorkflowService,
|
||||
"trigger_workflow_async",
|
||||
_fake_trigger_workflow_async,
|
||||
)
|
||||
|
||||
response = test_client_with_containers.post(f"/triggers/webhook/{webhook_trigger.webhook_id}", json={"foo": "bar"})
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all()
|
||||
assert logs, "Webhook trigger should create trigger log"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("schedule_type", ["visual", "cron"])
|
||||
def test_schedule_poll_dispatches_due_plan(
|
||||
db_session_with_containers: Session,
|
||||
tenant_and_account: tuple[Tenant, Account],
|
||||
app_model: App,
|
||||
mock_celery_group: MockCeleryGroup,
|
||||
mock_celery_signature: MockCelerySignature,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
schedule_type: str,
|
||||
) -> None:
|
||||
"""Schedule plans (both visual and cron) should be polled and dispatched when due."""
|
||||
tenant, _ = tenant_and_account
|
||||
|
||||
app_trigger = AppTrigger(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app_model.id,
|
||||
node_id=f"schedule-{schedule_type}",
|
||||
trigger_type=AppTriggerType.TRIGGER_SCHEDULE,
|
||||
status=AppTriggerStatus.ENABLED,
|
||||
title=f"schedule-{schedule_type}",
|
||||
)
|
||||
plan = WorkflowSchedulePlan(
|
||||
app_id=app_model.id,
|
||||
node_id=f"schedule-{schedule_type}",
|
||||
tenant_id=tenant.id,
|
||||
cron_expression="* * * * *",
|
||||
timezone="UTC",
|
||||
next_run_at=naive_utc_now() - timedelta(minutes=1),
|
||||
)
|
||||
db_session_with_containers.add_all([app_trigger, plan])
|
||||
db_session_with_containers.commit()
|
||||
|
||||
next_time = naive_utc_now() + timedelta(hours=1)
|
||||
monkeypatch.setattr(workflow_schedule_task, "calculate_next_run_at", lambda *_args, **_kwargs: next_time)
|
||||
monkeypatch.setattr(workflow_schedule_task, "group", mock_celery_group)
|
||||
monkeypatch.setattr(workflow_schedule_task, "run_schedule_trigger", mock_celery_signature)
|
||||
|
||||
poll_workflow_schedules()
|
||||
|
||||
assert mock_celery_group.collected, f"Should dispatch signatures for due {schedule_type} schedules"
|
||||
scheduled_ids = {sig["schedule_id"] for sig in mock_celery_group.collected}
|
||||
assert plan.id in scheduled_ids
|
||||
|
||||
|
||||
def test_schedule_visual_debug_poll_generates_event(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Visual mode schedule node should generate event in single-step debug."""
|
||||
base_now = naive_utc_now()
|
||||
monkeypatch.setattr(event_selectors, "naive_utc_now", lambda: base_now)
|
||||
monkeypatch.setattr(
|
||||
event_selectors,
|
||||
"calculate_next_run_at",
|
||||
lambda *_args, **_kwargs: base_now - timedelta(minutes=1),
|
||||
)
|
||||
node_config = {
|
||||
"id": "schedule-visual",
|
||||
"data": {
|
||||
"type": NodeType.TRIGGER_SCHEDULE.value,
|
||||
"mode": "visual",
|
||||
"frequency": "daily",
|
||||
"visual_config": {"time": "3:00 PM"},
|
||||
"timezone": "UTC",
|
||||
},
|
||||
}
|
||||
poller = event_selectors.ScheduleTriggerDebugEventPoller(
|
||||
tenant_id="tenant",
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
node_config=node_config,
|
||||
node_id="schedule-visual",
|
||||
)
|
||||
event = poller.poll()
|
||||
assert event is not None
|
||||
assert event.workflow_args["inputs"] == {}
|
||||
|
||||
|
||||
def test_plugin_trigger_dispatches_and_debug_events(
|
||||
test_client_with_containers: FlaskClient,
|
||||
mock_plugin_subscription: MockPluginSubscription,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Plugin trigger endpoint should dispatch events and generate debug events."""
|
||||
endpoint_id = "1cc7fa12-3f7b-4f6a-9c8d-1234567890ab"
|
||||
|
||||
debug_events: list[dict[str, Any]] = []
|
||||
dispatched_payloads: list[dict[str, Any]] = []
|
||||
|
||||
def _fake_process_endpoint(_endpoint_id: str, _request: Any) -> Response:
|
||||
dispatch_data = {
|
||||
"user_id": "end-user",
|
||||
"tenant_id": mock_plugin_subscription.tenant_id,
|
||||
"endpoint_id": _endpoint_id,
|
||||
"provider_id": mock_plugin_subscription.provider_id,
|
||||
"subscription_id": mock_plugin_subscription.id,
|
||||
"timestamp": int(time.time()),
|
||||
"events": ["created", "updated"],
|
||||
"request_id": f"req-{_endpoint_id}",
|
||||
}
|
||||
trigger_processing_tasks.dispatch_triggered_workflows_async.delay(dispatch_data)
|
||||
return Response("ok", status=202)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"services.trigger.trigger_service.TriggerService.process_endpoint",
|
||||
staticmethod(_fake_process_endpoint),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
trigger_processing_tasks.TriggerDebugEventBus,
|
||||
"dispatch",
|
||||
staticmethod(lambda **kwargs: debug_events.append(kwargs) or 1),
|
||||
)
|
||||
|
||||
def _fake_delay(dispatch_data: dict[str, Any]) -> None:
|
||||
dispatched_payloads.append(dispatch_data)
|
||||
trigger_processing_tasks.dispatch_trigger_debug_event(
|
||||
events=dispatch_data["events"],
|
||||
user_id=dispatch_data["user_id"],
|
||||
timestamp=dispatch_data["timestamp"],
|
||||
request_id=dispatch_data["request_id"],
|
||||
subscription=mock_plugin_subscription,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
trigger_processing_tasks.dispatch_triggered_workflows_async,
|
||||
"delay",
|
||||
staticmethod(_fake_delay),
|
||||
)
|
||||
|
||||
response = test_client_with_containers.post(f"/triggers/plugin/{endpoint_id}", json={"hello": "world"})
|
||||
|
||||
assert response.status_code == 202
|
||||
assert dispatched_payloads, "Plugin trigger should enqueue workflow dispatch payload"
|
||||
assert debug_events, "Plugin trigger should dispatch debug events"
|
||||
dispatched_event_names = {event["event"].name for event in debug_events}
|
||||
assert dispatched_event_names == {"created", "updated"}
|
||||
|
||||
|
||||
def test_webhook_debug_dispatches_event(
|
||||
test_client_with_containers: FlaskClient,
|
||||
db_session_with_containers: Session,
|
||||
tenant_and_account: tuple[Tenant, Account],
|
||||
app_model: App,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Webhook single-step debug should dispatch debug event and be pollable."""
|
||||
tenant, account = tenant_and_account
|
||||
webhook_node_id = "webhook-debug-node"
|
||||
graph_json = _build_workflow_graph(webhook_node_id, NodeType.TRIGGER_WEBHOOK)
|
||||
draft_workflow = Workflow.new(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app_model.id,
|
||||
type="workflow",
|
||||
version=Workflow.VERSION_DRAFT,
|
||||
graph=graph_json,
|
||||
features=json.dumps({}),
|
||||
created_by=account.id,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
rag_pipeline_variables=[],
|
||||
)
|
||||
db_session_with_containers.add(draft_workflow)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
webhook_trigger = WorkflowWebhookTrigger(
|
||||
app_id=app_model.id,
|
||||
node_id=webhook_node_id,
|
||||
tenant_id=tenant.id,
|
||||
webhook_id=WEBHOOK_ID_DEBUG,
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(webhook_trigger)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
debug_events: list[dict[str, Any]] = []
|
||||
original_dispatch = TriggerDebugEventBus.dispatch
|
||||
monkeypatch.setattr(
|
||||
"controllers.trigger.webhook.TriggerDebugEventBus.dispatch",
|
||||
lambda **kwargs: (debug_events.append(kwargs), original_dispatch(**kwargs))[1],
|
||||
)
|
||||
|
||||
# Listener polls first to enter waiting pool
|
||||
poller = WebhookTriggerDebugEventPoller(
|
||||
tenant_id=tenant.id,
|
||||
user_id=account.id,
|
||||
app_id=app_model.id,
|
||||
node_config=draft_workflow.get_node_config_by_id(webhook_node_id),
|
||||
node_id=webhook_node_id,
|
||||
)
|
||||
assert poller.poll() is None
|
||||
|
||||
response = test_client_with_containers.post(
|
||||
f"/triggers/webhook-debug/{webhook_trigger.webhook_id}",
|
||||
json={"foo": "bar"},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert debug_events, "Debug event should be sent to event bus"
|
||||
# Second poll should get the event
|
||||
event = poller.poll()
|
||||
assert event is not None
|
||||
assert event.workflow_args["inputs"]["webhook_body"]["foo"] == "bar"
|
||||
assert debug_events[0]["pool_key"].endswith(f":{app_model.id}:{webhook_node_id}")
|
||||
|
||||
|
||||
def test_plugin_single_step_debug_flow(
|
||||
flask_app_with_containers: Flask,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Plugin single-step debug: listen -> dispatch event -> poller receives and returns variables."""
|
||||
tenant_id = "tenant-1"
|
||||
app_id = "app-1"
|
||||
user_id = "user-1"
|
||||
node_id = "plugin-node"
|
||||
provider_id = "langgenius/provider-1/provider-1"
|
||||
node_config = {
|
||||
"id": node_id,
|
||||
"data": {
|
||||
"type": NodeType.TRIGGER_PLUGIN.value,
|
||||
"title": "plugin",
|
||||
"plugin_id": "plugin-1",
|
||||
"plugin_unique_identifier": "plugin-1",
|
||||
"provider_id": provider_id,
|
||||
"event_name": "created",
|
||||
"subscription_id": "sub-1",
|
||||
"parameters": {},
|
||||
},
|
||||
}
|
||||
# Start listening
|
||||
poller = PluginTriggerDebugEventPoller(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
node_config=node_config,
|
||||
node_id=node_id,
|
||||
)
|
||||
assert poller.poll() is None
|
||||
|
||||
from core.trigger.debug.events import build_plugin_pool_key
|
||||
|
||||
pool_key = build_plugin_pool_key(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
subscription_id="sub-1",
|
||||
name="created",
|
||||
)
|
||||
TriggerDebugEventBus.dispatch(
|
||||
tenant_id=tenant_id,
|
||||
event=PluginTriggerDebugEvent(
|
||||
timestamp=int(time.time()),
|
||||
user_id=user_id,
|
||||
name="created",
|
||||
request_id="req-1",
|
||||
subscription_id="sub-1",
|
||||
provider_id="provider-1",
|
||||
),
|
||||
pool_key=pool_key,
|
||||
)
|
||||
|
||||
from core.plugin.entities.request import TriggerInvokeEventResponse
|
||||
|
||||
monkeypatch.setattr(
|
||||
"services.trigger.trigger_service.TriggerService.invoke_trigger_event",
|
||||
staticmethod(
|
||||
lambda **_kwargs: TriggerInvokeEventResponse(
|
||||
variables={"echo": "pong"},
|
||||
cancelled=False,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
event = poller.poll()
|
||||
assert event is not None
|
||||
assert event.workflow_args["inputs"]["echo"] == "pong"
|
||||
|
||||
|
||||
def test_schedule_trigger_creates_trigger_log(
|
||||
db_session_with_containers: Session,
|
||||
tenant_and_account: tuple[Tenant, Account],
|
||||
app_model: App,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Schedule trigger execution should create WorkflowTriggerLog in database."""
|
||||
from tasks import workflow_schedule_tasks
|
||||
|
||||
tenant, account = tenant_and_account
|
||||
|
||||
# Create published workflow with schedule trigger node
|
||||
schedule_node_id = "schedule-node"
|
||||
graph = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": schedule_node_id,
|
||||
"data": {
|
||||
"type": NodeType.TRIGGER_SCHEDULE.value,
|
||||
"title": "schedule",
|
||||
"mode": "cron",
|
||||
"cron_expression": "0 9 * * *",
|
||||
"timezone": "UTC",
|
||||
},
|
||||
},
|
||||
{"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}},
|
||||
],
|
||||
"edges": [{"source": schedule_node_id, "target": "answer-1", "sourceHandle": "success"}],
|
||||
}
|
||||
published_workflow = Workflow.new(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app_model.id,
|
||||
type="workflow",
|
||||
version=Workflow.version_from_datetime(naive_utc_now()),
|
||||
graph=json.dumps(graph),
|
||||
features=json.dumps({}),
|
||||
created_by=account.id,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
rag_pipeline_variables=[],
|
||||
)
|
||||
db_session_with_containers.add(published_workflow)
|
||||
app_model.workflow_id = published_workflow.id
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create schedule plan
|
||||
plan = WorkflowSchedulePlan(
|
||||
app_id=app_model.id,
|
||||
node_id=schedule_node_id,
|
||||
tenant_id=tenant.id,
|
||||
cron_expression="0 9 * * *",
|
||||
timezone="UTC",
|
||||
next_run_at=naive_utc_now() - timedelta(minutes=1),
|
||||
)
|
||||
app_trigger = AppTrigger(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app_model.id,
|
||||
node_id=schedule_node_id,
|
||||
trigger_type=AppTriggerType.TRIGGER_SCHEDULE,
|
||||
status=AppTriggerStatus.ENABLED,
|
||||
title="schedule",
|
||||
)
|
||||
db_session_with_containers.add_all([plan, app_trigger])
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock AsyncWorkflowService to create WorkflowTriggerLog
|
||||
def _fake_trigger_workflow_async(session: Session, user: Any, trigger_data: Any) -> SimpleNamespace:
|
||||
log = WorkflowTriggerLog(
|
||||
tenant_id=trigger_data.tenant_id,
|
||||
app_id=trigger_data.app_id,
|
||||
workflow_id=published_workflow.id,
|
||||
root_node_id=trigger_data.root_node_id,
|
||||
trigger_metadata="{}",
|
||||
trigger_type=AppTriggerType.TRIGGER_SCHEDULE,
|
||||
workflow_run_id=None,
|
||||
outputs=None,
|
||||
trigger_data=trigger_data.model_dump_json(),
|
||||
inputs=json.dumps(dict(trigger_data.inputs)),
|
||||
status=WorkflowTriggerStatus.SUCCEEDED,
|
||||
error="",
|
||||
queue_name="schedule_executor",
|
||||
celery_task_id="celery-schedule-test",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
session.add(log)
|
||||
session.commit()
|
||||
return SimpleNamespace(workflow_trigger_log_id=log.id, task_id=None, status="queued", queue="test")
|
||||
|
||||
monkeypatch.setattr(
|
||||
workflow_schedule_tasks.AsyncWorkflowService,
|
||||
"trigger_workflow_async",
|
||||
_fake_trigger_workflow_async,
|
||||
)
|
||||
|
||||
# Mock quota to avoid rate limiting
|
||||
from enums import quota_type
|
||||
|
||||
monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited())
|
||||
|
||||
# Execute schedule trigger
|
||||
workflow_schedule_tasks.run_schedule_trigger(plan.id)
|
||||
|
||||
# Verify WorkflowTriggerLog was created
|
||||
db_session_with_containers.expire_all()
|
||||
logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all()
|
||||
assert logs, "Schedule trigger should create WorkflowTriggerLog"
|
||||
assert logs[0].trigger_type == AppTriggerType.TRIGGER_SCHEDULE
|
||||
assert logs[0].root_node_id == schedule_node_id
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mode", "frequency", "visual_config", "cron_expression", "expected_cron"),
|
||||
[
|
||||
# Visual mode: hourly
|
||||
("visual", "hourly", {"on_minute": 30}, None, "30 * * * *"),
|
||||
# Visual mode: daily
|
||||
("visual", "daily", {"time": "3:00 PM"}, None, "0 15 * * *"),
|
||||
# Visual mode: weekly
|
||||
("visual", "weekly", {"time": "9:00 AM", "weekdays": ["mon", "wed", "fri"]}, None, "0 9 * * 1,3,5"),
|
||||
# Visual mode: monthly
|
||||
("visual", "monthly", {"time": "10:30 AM", "monthly_days": [1, 15]}, None, "30 10 1,15 * *"),
|
||||
# Cron mode: direct expression
|
||||
("cron", None, None, "*/5 * * * *", "*/5 * * * *"),
|
||||
],
|
||||
)
|
||||
def test_schedule_visual_cron_conversion(
|
||||
mode: str,
|
||||
frequency: str | None,
|
||||
visual_config: dict[str, Any] | None,
|
||||
cron_expression: str | None,
|
||||
expected_cron: str,
|
||||
) -> None:
|
||||
"""Schedule visual config should correctly convert to cron expression."""
|
||||
|
||||
node_config: dict[str, Any] = {
|
||||
"id": "schedule-node",
|
||||
"data": {
|
||||
"type": NodeType.TRIGGER_SCHEDULE.value,
|
||||
"mode": mode,
|
||||
"timezone": "UTC",
|
||||
},
|
||||
}
|
||||
|
||||
if mode == "visual":
|
||||
node_config["data"]["frequency"] = frequency
|
||||
node_config["data"]["visual_config"] = visual_config
|
||||
else:
|
||||
node_config["data"]["cron_expression"] = cron_expression
|
||||
|
||||
config = ScheduleService.to_schedule_config(node_config)
|
||||
|
||||
assert config.cron_expression == expected_cron, f"Expected {expected_cron}, got {config.cron_expression}"
|
||||
assert config.timezone == "UTC"
|
||||
assert config.node_id == "schedule-node"
|
||||
|
||||
|
||||
def test_plugin_trigger_full_chain_with_db_verification(
    test_client_with_containers: FlaskClient,
    db_session_with_containers: Session,
    tenant_and_account: tuple[Tenant, Account],
    app_model: App,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Plugin trigger should create WorkflowTriggerLog and WorkflowPluginTrigger records."""

    tenant, account = tenant_and_account

    # Create published workflow with plugin trigger node
    plugin_node_id = "plugin-trigger-node"
    provider_id = "langgenius/test-provider/test-provider"
    subscription_id = "sub-plugin-test"
    endpoint_id = "2cc7fa12-3f7b-4f6a-9c8d-1234567890ab"

    graph = {
        "nodes": [
            {
                "id": plugin_node_id,
                "data": {
                    "type": NodeType.TRIGGER_PLUGIN.value,
                    "title": "plugin",
                    "plugin_id": "test-plugin",
                    "plugin_unique_identifier": "test-plugin",
                    "provider_id": provider_id,
                    "event_name": "test_event",
                    "subscription_id": subscription_id,
                    "parameters": {},
                },
            },
            {"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}},
        ],
        "edges": [{"source": plugin_node_id, "target": "answer-1", "sourceHandle": "success"}],
    }
    published_workflow = Workflow.new(
        tenant_id=tenant.id,
        app_id=app_model.id,
        type="workflow",
        version=Workflow.version_from_datetime(naive_utc_now()),
        graph=json.dumps(graph),
        features=json.dumps({}),
        created_by=account.id,
        environment_variables=[],
        conversation_variables=[],
        rag_pipeline_variables=[],
    )
    db_session_with_containers.add(published_workflow)
    app_model.workflow_id = published_workflow.id
    db_session_with_containers.commit()

    # Create trigger subscription
    subscription = TriggerSubscription(
        name="test-subscription",
        tenant_id=tenant.id,
        user_id=account.id,
        provider_id=provider_id,
        endpoint_id=endpoint_id,
        parameters={},
        properties={},
        credentials={"token": "test-secret"},
        credential_type="api-key",
    )
    db_session_with_containers.add(subscription)
    db_session_with_containers.commit()

    # Update subscription_id to match the created subscription
    graph["nodes"][0]["data"]["subscription_id"] = subscription.id
    published_workflow.graph = json.dumps(graph)
    db_session_with_containers.commit()

    # Create WorkflowPluginTrigger
    plugin_trigger = WorkflowPluginTrigger(
        app_id=app_model.id,
        tenant_id=tenant.id,
        node_id=plugin_node_id,
        provider_id=provider_id,
        event_name="test_event",
        subscription_id=subscription.id,
    )
    app_trigger = AppTrigger(
        tenant_id=tenant.id,
        app_id=app_model.id,
        node_id=plugin_node_id,
        trigger_type=AppTriggerType.TRIGGER_PLUGIN,
        status=AppTriggerStatus.ENABLED,
        title="plugin",
    )
    db_session_with_containers.add_all([plugin_trigger, app_trigger])
    db_session_with_containers.commit()

    # Track dispatched data
    dispatched_data: list[dict[str, Any]] = []

    def _fake_process_endpoint(_endpoint_id: str, _request: Any) -> Response:
        dispatch_data = {
            "user_id": "end-user",
            "tenant_id": tenant.id,
            "endpoint_id": _endpoint_id,
            "provider_id": provider_id,
            "subscription_id": subscription.id,
            "timestamp": int(time.time()),
            "events": ["test_event"],
            "request_id": f"req-{_endpoint_id}",
        }
        dispatched_data.append(dispatch_data)
        return Response("ok", status=202)

    monkeypatch.setattr(
        "services.trigger.trigger_service.TriggerService.process_endpoint",
        staticmethod(_fake_process_endpoint),
    )

    response = test_client_with_containers.post(f"/triggers/plugin/{endpoint_id}", json={"test": "data"})

    assert response.status_code == 202
    assert dispatched_data, "Plugin trigger should dispatch event data"
    assert dispatched_data[0]["subscription_id"] == subscription.id
    assert dispatched_data[0]["events"] == ["test_event"]

    # Verify database records exist
    db_session_with_containers.expire_all()
    plugin_triggers = (
        db_session_with_containers.query(WorkflowPluginTrigger)
        .filter_by(app_id=app_model.id, node_id=plugin_node_id)
        .all()
    )
    assert plugin_triggers, "WorkflowPluginTrigger record should exist"
    assert plugin_triggers[0].provider_id == provider_id
    assert plugin_triggers[0].event_name == "test_event"

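Both tests in this file patch TriggerService.process_endpoint by dotted path and wrap the stub in staticmethod(...) so it is stored as a plain function rather than re-bound as an instance method. A self-contained illustration of that pytest pattern (the Service class here is hypothetical):

import pytest


class Service:
    @staticmethod
    def process(payload: str) -> str:
        return f"real:{payload}"


def test_patch_staticmethod(monkeypatch: pytest.MonkeyPatch) -> None:
    calls: list[str] = []

    def _fake_process(payload: str) -> str:
        calls.append(payload)  # record the call instead of doing real work
        return "stubbed"

    # staticmethod(...) keeps the replacement callable unbound on the class
    monkeypatch.setattr(Service, "process", staticmethod(_fake_process))
    assert Service.process("x") == "stubbed"
    assert calls == ["x"]
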
def test_plugin_debug_via_http_endpoint(
    test_client_with_containers: FlaskClient,
    db_session_with_containers: Session,
    tenant_and_account: tuple[Tenant, Account],
    app_model: App,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Plugin single-step debug via HTTP endpoint should dispatch debug event and be pollable."""

    tenant, account = tenant_and_account

    provider_id = "langgenius/debug-provider/debug-provider"
    endpoint_id = "3cc7fa12-3f7b-4f6a-9c8d-1234567890ab"
    event_name = "debug_event"

    # Create subscription
    subscription = TriggerSubscription(
        name="debug-subscription",
        tenant_id=tenant.id,
        user_id=account.id,
        provider_id=provider_id,
        endpoint_id=endpoint_id,
        parameters={},
        properties={},
        credentials={"token": "debug-secret"},
        credential_type="api-key",
    )
    db_session_with_containers.add(subscription)
    db_session_with_containers.commit()

    # Create plugin trigger node config
    node_id = "plugin-debug-node"
    node_config = {
        "id": node_id,
        "data": {
            "type": NodeType.TRIGGER_PLUGIN.value,
            "title": "plugin-debug",
            "plugin_id": "debug-plugin",
            "plugin_unique_identifier": "debug-plugin",
            "provider_id": provider_id,
            "event_name": event_name,
            "subscription_id": subscription.id,
            "parameters": {},
        },
    }

    # Start listening with poller
    poller = PluginTriggerDebugEventPoller(
        tenant_id=tenant.id,
        user_id=account.id,
        app_id=app_model.id,
        node_config=node_config,
        node_id=node_id,
    )
    assert poller.poll() is None, "First poll should return None (waiting)"

    # Track debug events dispatched
    debug_events: list[dict[str, Any]] = []
    original_dispatch = TriggerDebugEventBus.dispatch

    def _tracking_dispatch(**kwargs: Any) -> int:
        debug_events.append(kwargs)
        return original_dispatch(**kwargs)

    monkeypatch.setattr(TriggerDebugEventBus, "dispatch", staticmethod(_tracking_dispatch))

    # Mock process_endpoint to trigger debug event dispatch
    def _fake_process_endpoint(_endpoint_id: str, _request: Any) -> Response:
        # Simulate what happens inside process_endpoint + dispatch_triggered_workflows_async
        pool_key = build_plugin_pool_key(
            tenant_id=tenant.id,
            provider_id=provider_id,
            subscription_id=subscription.id,
            name=event_name,
        )
        TriggerDebugEventBus.dispatch(
            tenant_id=tenant.id,
            event=PluginTriggerDebugEvent(
                timestamp=int(time.time()),
                user_id="end-user",
                name=event_name,
                request_id=f"req-{_endpoint_id}",
                subscription_id=subscription.id,
                provider_id=provider_id,
            ),
            pool_key=pool_key,
        )
        return Response("ok", status=202)

    monkeypatch.setattr(
        "services.trigger.trigger_service.TriggerService.process_endpoint",
        staticmethod(_fake_process_endpoint),
    )

    # Call HTTP endpoint
    response = test_client_with_containers.post(f"/triggers/plugin/{endpoint_id}", json={"debug": "payload"})

    assert response.status_code == 202
    assert debug_events, "Debug event should be dispatched via HTTP endpoint"
    assert debug_events[0]["event"].name == event_name

    # Mock invoke_trigger_event for poller
    monkeypatch.setattr(
        "services.trigger.trigger_service.TriggerService.invoke_trigger_event",
        staticmethod(
            lambda **_kwargs: TriggerInvokeEventResponse(
                variables={"http_debug": "success"},
                cancelled=False,
            )
        ),
    )

    # Second poll should receive the event
    event = poller.poll()
    assert event is not None, "Poller should receive debug event after HTTP trigger"
    assert event.workflow_args["inputs"]["http_debug"] == "success"

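The dispatch-then-poll flow exercised above is a bounded wait on an event bus: the HTTP handler publishes a debug event and the poller returns None until one arrives. A minimal queue-based sketch of that contract (illustrative only, not the TriggerDebugEventBus API):

import queue


class DebugBusSketch:
    def __init__(self) -> None:
        self._events: "queue.Queue[dict]" = queue.Queue()

    def dispatch(self, event: dict) -> None:
        self._events.put(event)

    def poll(self) -> dict | None:
        # Non-blocking: None while no event has been dispatched yet
        try:
            return self._events.get_nowait()
        except queue.Empty:
            return None


bus = DebugBusSketch()
assert bus.poll() is None
bus.dispatch({"name": "debug_event"})
assert bus.poll() == {"name": "debug_event"}
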
@ -0,0 +1,25 @@
import uuid

import pytest
from pydantic import ValidationError

from controllers.service_api.app.completion import ChatRequestPayload


def test_chat_request_payload_accepts_blank_conversation_id():
    payload = ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "conversation_id": ""})

    assert payload.conversation_id is None


def test_chat_request_payload_validates_uuid():
    conversation_id = str(uuid.uuid4())

    payload = ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "conversation_id": conversation_id})

    assert payload.conversation_id == conversation_id


def test_chat_request_payload_rejects_invalid_uuid():
    with pytest.raises(ValidationError):
        ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "conversation_id": "invalid"})

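These three tests pin down one normalization rule: a blank conversation_id means "new conversation" and becomes None, while any non-blank value must parse as a UUID. One way to get that behavior with a pydantic v2 before-validator (a sketch, not necessarily the validator Dify ships):

import uuid

from pydantic import BaseModel, field_validator


class ChatPayloadSketch(BaseModel):
    conversation_id: str | None = None

    @field_validator("conversation_id", mode="before")
    @classmethod
    def _normalize_conversation_id(cls, value: str | None) -> str | None:
        if not value:  # None and "" both mean "start a new conversation"
            return None
        uuid.UUID(value)  # ValueError here surfaces as a ValidationError
        return value


assert ChatPayloadSketch.model_validate({"conversation_id": ""}).conversation_id is None
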
@ -0,0 +1,20 @@
import pytest
from pydantic import ValidationError

from controllers.console.explore.conversation import ConversationRenamePayload as ConsolePayload
from controllers.service_api.app.conversation import ConversationRenamePayload as ServicePayload


@pytest.mark.parametrize("payload_cls", [ConsolePayload, ServicePayload])
def test_payload_allows_auto_generate_without_name(payload_cls):
    payload = payload_cls.model_validate({"auto_generate": True})

    assert payload.auto_generate is True
    assert payload.name is None


@pytest.mark.parametrize("payload_cls", [ConsolePayload, ServicePayload])
@pytest.mark.parametrize("value", [None, "", " "])
def test_payload_requires_name_when_not_auto_generate(payload_cls, value):
    with pytest.raises(ValidationError):
        payload_cls.model_validate({"name": value, "auto_generate": False})

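Both payload classes enforce the same cross-field rule: a name is mandatory unless auto_generate is set, and whitespace-only names do not count. A pydantic v2 model-validator sketch of that rule (hypothetical class name):

from pydantic import BaseModel, model_validator


class RenameSketch(BaseModel):
    name: str | None = None
    auto_generate: bool = False

    @model_validator(mode="after")
    def _require_name_unless_auto(self) -> "RenameSketch":
        if not self.auto_generate and not (self.name and self.name.strip()):
            raise ValueError("name is required when auto_generate is false")
        return self


assert RenameSketch.model_validate({"auto_generate": True}).name is None
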
@ -0,0 +1,60 @@
"""
Test case for end node without value_type field (backward compatibility).

This test validates that end nodes work correctly even when the value_type
field is missing from the output configuration, ensuring backward compatibility
with older workflow definitions.
"""

from core.workflow.graph_events import (
    GraphRunStartedEvent,
    GraphRunSucceededEvent,
    NodeRunStartedEvent,
    NodeRunStreamChunkEvent,
    NodeRunSucceededEvent,
)

from .test_table_runner import TableTestRunner, WorkflowTestCase


def test_end_node_without_value_type_field():
    """
    Test that end node works without explicit value_type field.

    The fixture implements a simple workflow that:
    1. Takes a query input from start node
    2. Passes it directly to end node
    3. End node outputs the value without specifying value_type
    4. Should correctly infer the type and output the value

    This ensures backward compatibility with workflow definitions
    created before value_type became a required field.
    """
    fixture_name = "end_node_without_value_type_field_workflow"

    case = WorkflowTestCase(
        fixture_path=fixture_name,
        inputs={"query": "test query"},
        expected_outputs={"query": "test query"},
        expected_event_sequence=[
            # Graph start
            GraphRunStartedEvent,
            # Start node
            NodeRunStartedEvent,
            NodeRunStreamChunkEvent,  # Start node streams the input value
            NodeRunSucceededEvent,
            # End node
            NodeRunStartedEvent,
            NodeRunSucceededEvent,
            # Graph end
            GraphRunSucceededEvent,
        ],
        description="End node without value_type field should work correctly",
    )

    runner = TableTestRunner()
    result = runner.run_test_case(case)
    assert result.success, f"Test failed: {result.error}"
    assert result.actual_outputs == {"query": "test query"}, (
        f"Expected output to be {{'query': 'test query'}}, got {result.actual_outputs}"
    )

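The backward-compatibility behavior under test amounts to inferring an output type from the runtime value when value_type is absent. The real inference lives in Dify's variable/segment layer; this standalone sketch only illustrates the idea:

def infer_value_type(value: object) -> str:
    """Fallback type inference for outputs that omit value_type."""
    if isinstance(value, str):
        return "string"
    if isinstance(value, bool):  # bool before number: bool subclasses int
        return "boolean"
    if isinstance(value, (int, float)):
        return "number"
    if isinstance(value, list):
        return "array"
    return "object"


assert infer_value_type("test query") == "string"
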
@ -1149,3 +1149,258 @@ class TestModelIntegration:
        # Assert
        assert site.app_id == app.id
        assert app.enable_site is True


class TestConversationStatusCount:
    """Test suite for Conversation.status_count property N+1 query fix."""

    def test_status_count_no_messages(self):
        """Test status_count returns None when conversation has no messages."""
        # Arrange
        conversation = Conversation(
            app_id=str(uuid4()),
            mode=AppMode.CHAT,
            name="Test Conversation",
            status="normal",
            from_source="api",
        )
        conversation.id = str(uuid4())

        # Mock the database query to return no messages
        with patch("models.model.db.session.scalars") as mock_scalars:
            mock_scalars.return_value.all.return_value = []

            # Act
            result = conversation.status_count

            # Assert
            assert result is None

    def test_status_count_messages_without_workflow_runs(self):
        """Test status_count when messages have no workflow_run_id."""
        # Arrange
        app_id = str(uuid4())
        conversation_id = str(uuid4())

        conversation = Conversation(
            app_id=app_id,
            mode=AppMode.CHAT,
            name="Test Conversation",
            status="normal",
            from_source="api",
        )
        conversation.id = conversation_id

        # Mock the database query to return no messages with workflow_run_id
        with patch("models.model.db.session.scalars") as mock_scalars:
            mock_scalars.return_value.all.return_value = []

            # Act
            result = conversation.status_count

            # Assert
            assert result is None

    def test_status_count_batch_loading_implementation(self):
        """Test that status_count uses batch loading instead of N+1 queries."""
        # Arrange
        from core.workflow.enums import WorkflowExecutionStatus

        app_id = str(uuid4())
        conversation_id = str(uuid4())

        # Create workflow run IDs
        workflow_run_id_1 = str(uuid4())
        workflow_run_id_2 = str(uuid4())
        workflow_run_id_3 = str(uuid4())

        conversation = Conversation(
            app_id=app_id,
            mode=AppMode.CHAT,
            name="Test Conversation",
            status="normal",
            from_source="api",
        )
        conversation.id = conversation_id

        # Mock messages with workflow_run_id
        mock_messages = [
            MagicMock(
                conversation_id=conversation_id,
                workflow_run_id=workflow_run_id_1,
            ),
            MagicMock(
                conversation_id=conversation_id,
                workflow_run_id=workflow_run_id_2,
            ),
            MagicMock(
                conversation_id=conversation_id,
                workflow_run_id=workflow_run_id_3,
            ),
        ]

        # Mock workflow runs with different statuses
        mock_workflow_runs = [
            MagicMock(
                id=workflow_run_id_1,
                status=WorkflowExecutionStatus.SUCCEEDED.value,
                app_id=app_id,
            ),
            MagicMock(
                id=workflow_run_id_2,
                status=WorkflowExecutionStatus.FAILED.value,
                app_id=app_id,
            ),
            MagicMock(
                id=workflow_run_id_3,
                status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value,
                app_id=app_id,
            ),
        ]

        # Track database calls
        calls_made = []

        def mock_scalars(query):
            calls_made.append(str(query))
            mock_result = MagicMock()

            # Return messages for the first query (messages with workflow_run_id)
            if "messages" in str(query) and "conversation_id" in str(query):
                mock_result.all.return_value = mock_messages
            # Return workflow runs for the batch query
            elif "workflow_runs" in str(query):
                mock_result.all.return_value = mock_workflow_runs
            else:
                mock_result.all.return_value = []

            return mock_result

        # Act & Assert
        with patch("models.model.db.session.scalars", side_effect=mock_scalars):
            result = conversation.status_count

            # Verify only 2 database queries were made (not N+1)
            assert len(calls_made) == 2, f"Expected 2 queries, got {len(calls_made)}: {calls_made}"

            # Verify the first query gets messages
            assert "messages" in calls_made[0]
            assert "conversation_id" in calls_made[0]

            # Verify the second query batch loads workflow runs with proper filtering
            assert "workflow_runs" in calls_made[1]
            assert "app_id" in calls_made[1]  # Security filter applied
            assert "IN" in calls_made[1]  # Batch loading with IN clause

            # Verify correct status counts
            assert result["success"] == 1  # One SUCCEEDED
            assert result["failed"] == 1  # One FAILED
            assert result["partial_success"] == 1  # One PARTIAL_SUCCEEDED

    def test_status_count_app_id_filtering(self):
        """Test that status_count filters workflow runs by app_id for security."""
        # Arrange
        app_id = str(uuid4())
        other_app_id = str(uuid4())
        conversation_id = str(uuid4())
        workflow_run_id = str(uuid4())

        conversation = Conversation(
            app_id=app_id,
            mode=AppMode.CHAT,
            name="Test Conversation",
            status="normal",
            from_source="api",
        )
        conversation.id = conversation_id

        # Mock message with workflow_run_id
        mock_messages = [
            MagicMock(
                conversation_id=conversation_id,
                workflow_run_id=workflow_run_id,
            ),
        ]

        calls_made = []

        def mock_scalars(query):
            calls_made.append(str(query))
            mock_result = MagicMock()

            if "messages" in str(query):
                mock_result.all.return_value = mock_messages
            elif "workflow_runs" in str(query):
                # Return empty list because no workflow run matches the correct app_id
                mock_result.all.return_value = []  # Workflow run filtered out by app_id
            else:
                mock_result.all.return_value = []

            return mock_result

        # Act
        with patch("models.model.db.session.scalars", side_effect=mock_scalars):
            result = conversation.status_count

        # Assert - query should include app_id filter
        workflow_query = calls_made[1]
        assert "app_id" in workflow_query

        # Since workflow run has wrong app_id, it shouldn't be included in counts
        assert result["success"] == 0
        assert result["failed"] == 0
        assert result["partial_success"] == 0

    def test_status_count_handles_invalid_workflow_status(self):
        """Test that status_count gracefully handles invalid workflow status values."""
        # Arrange
        app_id = str(uuid4())
        conversation_id = str(uuid4())
        workflow_run_id = str(uuid4())

        conversation = Conversation(
            app_id=app_id,
            mode=AppMode.CHAT,
            name="Test Conversation",
            status="normal",
            from_source="api",
        )
        conversation.id = conversation_id

        mock_messages = [
            MagicMock(
                conversation_id=conversation_id,
                workflow_run_id=workflow_run_id,
            ),
        ]

        # Mock workflow run with invalid status
        mock_workflow_runs = [
            MagicMock(
                id=workflow_run_id,
                status="invalid_status",  # Invalid status that should raise ValueError
                app_id=app_id,
            ),
        ]

        with patch("models.model.db.session.scalars") as mock_scalars:
            # Mock the messages query
            def mock_scalars_side_effect(query):
                mock_result = MagicMock()
                if "messages" in str(query):
                    mock_result.all.return_value = mock_messages
                elif "workflow_runs" in str(query):
                    mock_result.all.return_value = mock_workflow_runs
                else:
                    mock_result.all.return_value = []
                return mock_result

            mock_scalars.side_effect = mock_scalars_side_effect

            # Act - should not raise exception
            result = conversation.status_count

            # Assert - should handle invalid status gracefully
            assert result["success"] == 0
            assert result["failed"] == 0
            assert result["partial_success"] == 0

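The class docstring above says the point of these tests is replacing per-message status lookups with one batched query. A compact standalone sketch of that shape (the status strings and the loader callable are assumptions for illustration; the real property builds a single SQLAlchemy IN query scoped by app_id):

from collections import Counter


def status_count_sketch(messages, load_runs_by_ids):
    # Collect all workflow run ids first, then resolve them in ONE batched lookup
    run_ids = [m.workflow_run_id for m in messages if m.workflow_run_id]
    if not run_ids:
        return None
    runs = load_runs_by_ids(run_ids)  # stands in for a single `WHERE id IN (...)` query
    counts = Counter(run.status for run in runs)
    return {
        "success": counts.get("succeeded", 0),
        "failed": counts.get("failed", 0),
        "partial_success": counts.get("partial-succeeded", 0),
    }
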
@ -0,0 +1,88 @@
import types

import pytest

from core.entities.provider_entities import CredentialConfiguration, CustomModelConfiguration
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ConfigurateMethod
from models.provider import ProviderType
from services.model_provider_service import ModelProviderService


class _FakeConfigurations:
    def __init__(self, provider_configuration: types.SimpleNamespace) -> None:
        self._provider_configuration = provider_configuration

    def values(self) -> list[types.SimpleNamespace]:
        return [self._provider_configuration]


@pytest.fixture
def service_with_fake_configurations():
    # Build a fake provider schema with minimal fields used by ProviderResponse
    fake_provider = types.SimpleNamespace(
        provider="langgenius/openai_api_compatible/openai_api_compatible",
        label=I18nObject(en_US="OpenAI API Compatible", zh_Hans="OpenAI API Compatible"),
        description=None,
        icon_small=None,
        icon_small_dark=None,
        icon_large=None,
        background=None,
        help=None,
        supported_model_types=[ModelType.LLM],
        configurate_methods=[ConfigurateMethod.CUSTOMIZABLE_MODEL],
        provider_credential_schema=None,
        model_credential_schema=None,
    )

    # Include decrypted credentials to simulate the leak source
    custom_model = CustomModelConfiguration(
        model="gpt-4o-mini",
        model_type=ModelType.LLM,
        credentials={"api_key": "sk-plain-text", "endpoint": "https://example.com"},
        current_credential_id="cred-1",
        current_credential_name="API KEY 1",
        available_model_credentials=[],
        unadded_to_model_list=False,
    )

    fake_custom_provider = types.SimpleNamespace(
        current_credential_id="cred-1",
        current_credential_name="API KEY 1",
        available_credentials=[CredentialConfiguration(credential_id="cred-1", credential_name="API KEY 1")],
    )

    fake_custom_configuration = types.SimpleNamespace(
        provider=fake_custom_provider, models=[custom_model], can_added_models=[]
    )

    fake_system_configuration = types.SimpleNamespace(enabled=False, current_quota_type=None, quota_configurations=[])

    fake_provider_configuration = types.SimpleNamespace(
        provider=fake_provider,
        preferred_provider_type=ProviderType.CUSTOM,
        custom_configuration=fake_custom_configuration,
        system_configuration=fake_system_configuration,
        is_custom_configuration_available=lambda: True,
    )

    class _FakeProviderManager:
        def get_configurations(self, tenant_id: str) -> _FakeConfigurations:
            return _FakeConfigurations(fake_provider_configuration)

    svc = ModelProviderService()
    svc.provider_manager = _FakeProviderManager()
    return svc


def test_get_provider_list_strips_credentials(service_with_fake_configurations: ModelProviderService):
    providers = service_with_fake_configurations.get_provider_list(tenant_id="tenant-1", model_type=None)

    assert len(providers) == 1
    custom_models = providers[0].custom_configuration.custom_models

    assert custom_models is not None
    assert len(custom_models) == 1
    # The sanitizer should drop credentials in list response
    assert custom_models[0].credentials is None

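The leak the last assertion guards against is decrypted credentials surviving into the list response. The fix the test implies is a sanitization pass over the response DTOs; a minimal sketch with hypothetical types:

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ModelEntrySketch:
    model: str
    credentials: dict | None


def strip_credentials(models: list[ModelEntrySketch]) -> list[ModelEntrySketch]:
    # List endpoints must never echo decrypted secrets back to the client
    return [replace(m, credentials=None) for m in models]


entries = [ModelEntrySketch(model="gpt-4o-mini", credentials={"api_key": "sk-plain-text"})]
assert strip_credentials(entries)[0].credentials is None
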
@ -14,6 +14,7 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols
        ("Hello, World!", "Hello, World!"),
        ("", ""),
        (" ", " "),
        ("【测试】", "【测试】"),
    ],
)
def test_remove_leading_symbols(input_text, expected_output):
api/uv.lock (generated, 10 changes)
@ -1337,7 +1337,7 @@ wheels = [

[[package]]
name = "dify-api"
version = "1.10.1"
version = "1.11.1"
source = { virtual = "." }
dependencies = [
    { name = "apscheduler" },
@ -1681,7 +1681,7 @@ dev = [
    { name = "types-redis", specifier = ">=4.6.0.20241004" },
    { name = "types-regex", specifier = "~=2024.11.6" },
    { name = "types-setuptools", specifier = ">=80.9.0" },
    { name = "types-shapely", specifier = "~=2.0.0" },
    { name = "types-shapely", specifier = "~=2.1.0" },
    { name = "types-simplejson", specifier = ">=3.20.0" },
    { name = "types-six", specifier = ">=1.17.0" },
    { name = "types-tensorflow", specifier = ">=2.18.0" },
@ -6557,14 +6557,14 @@ wheels = [

[[package]]
name = "types-shapely"
version = "2.0.0.20250404"
version = "2.1.0.20250917"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "numpy" },
]
sdist = { url = "https://files.pythonhosted.org/packages/4e/55/c71a25fd3fc9200df4d0b5fd2f6d74712a82f9a8bbdd90cefb9e6aee39dd/types_shapely-2.0.0.20250404.tar.gz", hash = "sha256:863f540b47fa626c33ae64eae06df171f9ab0347025d4458d2df496537296b4f", size = 25066, upload-time = "2025-04-04T02:54:30.592Z" }
sdist = { url = "https://files.pythonhosted.org/packages/fa/19/7f28b10994433d43b9caa66f3b9bd6a0a9192b7ce8b5a7fc41534e54b821/types_shapely-2.1.0.20250917.tar.gz", hash = "sha256:5c56670742105aebe40c16414390d35fcaa55d6f774d328c1a18273ab0e2134a", size = 26363, upload-time = "2025-09-17T02:47:44.604Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/ce/ff/7f4d414eb81534ba2476f3d54f06f1463c2ebf5d663fd10cff16ba607dd6/types_shapely-2.0.0.20250404-py3-none-any.whl", hash = "sha256:170fb92f5c168a120db39b3287697fdec5c93ef3e1ad15e52552c36b25318821", size = 36350, upload-time = "2025-04-04T02:54:29.506Z" },
    { url = "https://files.pythonhosted.org/packages/e5/a9/554ac40810e530263b6163b30a2b623bc16aae3fb64416f5d2b3657d0729/types_shapely-2.1.0.20250917-py3-none-any.whl", hash = "sha256:9334a79339504d39b040426be4938d422cec419168414dc74972aa746a8bf3a1", size = 37813, upload-time = "2025-09-17T02:47:43.788Z" },
]

[[package]]
@ -1129,6 +1129,9 @@ WEAVIATE_AUTHENTICATION_APIKEY_USERS=hello@dify.ai
WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED=true
WEAVIATE_AUTHORIZATION_ADMINLIST_USERS=hello@dify.ai
WEAVIATE_DISABLE_TELEMETRY=false
WEAVIATE_ENABLE_TOKENIZER_GSE=false
WEAVIATE_ENABLE_TOKENIZER_KAGOME_JA=false
WEAVIATE_ENABLE_TOKENIZER_KAGOME_KR=false

# ------------------------------
# Environment Variables for Chroma
@ -1429,3 +1432,6 @@ WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0

# Tenant isolated task queue configuration
TENANT_ISOLATED_TASK_CONCURRENCY=1

# The API key of amplitude
AMPLITUDE_API_KEY=
@ -21,7 +21,7 @@ services:

  # API service
  api:
    image: langgenius/dify-api:1.10.1-fix.1
    image: langgenius/dify-api:1.11.1
    restart: always
    environment:
      # Use the shared environment variables.
@ -62,7 +62,7 @@ services:
  # worker service
  # The Celery worker for processing all queues (dataset, workflow, mail, etc.)
  worker:
    image: langgenius/dify-api:1.10.1-fix.1
    image: langgenius/dify-api:1.11.1
    restart: always
    environment:
      # Use the shared environment variables.
@ -101,7 +101,7 @@ services:
  # worker_beat service
  # Celery beat for scheduling periodic tasks.
  worker_beat:
    image: langgenius/dify-api:1.10.1-fix.1
    image: langgenius/dify-api:1.11.1
    restart: always
    environment:
      # Use the shared environment variables.
@ -131,11 +131,12 @@ services:

  # Frontend web application.
  web:
    image: langgenius/dify-web:1.10.1-fix.1
    image: langgenius/dify-web:1.11.1
    restart: always
    environment:
      CONSOLE_API_URL: ${CONSOLE_API_URL:-}
      APP_API_URL: ${APP_API_URL:-}
      AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-}
      NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-}
      SENTRY_DSN: ${WEB_SENTRY_DSN:-}
      NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0}
@ -268,7 +269,7 @@ services:

  # plugin daemon
  plugin_daemon:
    image: langgenius/dify-plugin-daemon:0.4.1-local
    image: langgenius/dify-plugin-daemon:0.5.1-local
    restart: always
    environment:
      # Use the shared environment variables.
@ -451,6 +452,9 @@ services:
      AUTHORIZATION_ADMINLIST_ENABLED: ${WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED:-true}
      AUTHORIZATION_ADMINLIST_USERS: ${WEAVIATE_AUTHORIZATION_ADMINLIST_USERS:-hello@dify.ai}
      DISABLE_TELEMETRY: ${WEAVIATE_DISABLE_TELEMETRY:-false}
      ENABLE_TOKENIZER_GSE: ${WEAVIATE_ENABLE_TOKENIZER_GSE:-false}
      ENABLE_TOKENIZER_KAGOME_JA: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_JA:-false}
      ENABLE_TOKENIZER_KAGOME_KR: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_KR:-false}

  # OceanBase vector database
  oceanbase:
@ -123,7 +123,7 @@ services:

  # plugin daemon
  plugin_daemon:
    image: langgenius/dify-plugin-daemon:0.4.1-local
    image: langgenius/dify-plugin-daemon:0.5.1-local
    restart: always
    env_file:
      - ./middleware.env
@ -479,6 +479,9 @@ x-shared-env: &shared-api-worker-env
  WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED: ${WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED:-true}
  WEAVIATE_AUTHORIZATION_ADMINLIST_USERS: ${WEAVIATE_AUTHORIZATION_ADMINLIST_USERS:-hello@dify.ai}
  WEAVIATE_DISABLE_TELEMETRY: ${WEAVIATE_DISABLE_TELEMETRY:-false}
  WEAVIATE_ENABLE_TOKENIZER_GSE: ${WEAVIATE_ENABLE_TOKENIZER_GSE:-false}
  WEAVIATE_ENABLE_TOKENIZER_KAGOME_JA: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_JA:-false}
  WEAVIATE_ENABLE_TOKENIZER_KAGOME_KR: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_KR:-false}
  CHROMA_SERVER_AUTHN_CREDENTIALS: ${CHROMA_SERVER_AUTHN_CREDENTIALS:-difyai123456}
  CHROMA_SERVER_AUTHN_PROVIDER: ${CHROMA_SERVER_AUTHN_PROVIDER:-chromadb.auth.token_authn.TokenAuthenticationServerProvider}
  CHROMA_IS_PERSISTENT: ${CHROMA_IS_PERSISTENT:-TRUE}
@ -632,6 +635,7 @@ x-shared-env: &shared-api-worker-env
  WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: ${WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE:-100}
  WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: ${WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK:-0}
  TENANT_ISOLATED_TASK_CONCURRENCY: ${TENANT_ISOLATED_TASK_CONCURRENCY:-1}
  AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-}

services:
  # Init container to fix permissions
@ -655,7 +659,7 @@ services:

  # API service
  api:
    image: langgenius/dify-api:1.10.1-fix.1
    image: langgenius/dify-api:1.11.1
    restart: always
    environment:
      # Use the shared environment variables.
@ -696,7 +700,7 @@ services:
  # worker service
  # The Celery worker for processing all queues (dataset, workflow, mail, etc.)
  worker:
    image: langgenius/dify-api:1.10.1-fix.1
    image: langgenius/dify-api:1.11.1
    restart: always
    environment:
      # Use the shared environment variables.
@ -735,7 +739,7 @@ services:
  # worker_beat service
  # Celery beat for scheduling periodic tasks.
  worker_beat:
    image: langgenius/dify-api:1.10.1-fix.1
    image: langgenius/dify-api:1.11.1
    restart: always
    environment:
      # Use the shared environment variables.
@ -765,11 +769,12 @@ services:

  # Frontend web application.
  web:
    image: langgenius/dify-web:1.10.1-fix.1
    image: langgenius/dify-web:1.11.1
    restart: always
    environment:
      CONSOLE_API_URL: ${CONSOLE_API_URL:-}
      APP_API_URL: ${APP_API_URL:-}
      AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-}
      NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-}
      SENTRY_DSN: ${WEB_SENTRY_DSN:-}
      NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0}
@ -902,7 +907,7 @@ services:

  # plugin daemon
  plugin_daemon:
    image: langgenius/dify-plugin-daemon:0.4.1-local
    image: langgenius/dify-plugin-daemon:0.5.1-local
    restart: always
    environment:
      # Use the shared environment variables.
@ -1085,6 +1090,9 @@ services:
      AUTHORIZATION_ADMINLIST_ENABLED: ${WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED:-true}
      AUTHORIZATION_ADMINLIST_USERS: ${WEAVIATE_AUTHORIZATION_ADMINLIST_USERS:-hello@dify.ai}
      DISABLE_TELEMETRY: ${WEAVIATE_DISABLE_TELEMETRY:-false}
      ENABLE_TOKENIZER_GSE: ${WEAVIATE_ENABLE_TOKENIZER_GSE:-false}
      ENABLE_TOKENIZER_KAGOME_JA: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_JA:-false}
      ENABLE_TOKENIZER_KAGOME_KR: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_KR:-false}

  # OceanBase vector database
  oceanbase:
@ -70,3 +70,6 @@ NEXT_PUBLIC_ENABLE_SINGLE_DOLLAR_LATEX=false

# The maximum number of tree node depth for workflow
NEXT_PUBLIC_MAX_TREE_DEPTH=50

# The API key of amplitude
NEXT_PUBLIC_AMPLITUDE_API_KEY=
@ -42,6 +42,7 @@ import type { InputVar, Variable } from '@/app/components/workflow/types'
import { appDefaultIconBackground } from '@/config'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now'
import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'
import { AccessMode } from '@/models/access-control'
import { useAppWhiteListSubjects, useGetUserCanAccessApp } from '@/service/access-control'
import { fetchAppDetailDirect } from '@/service/apps'
@ -153,6 +154,7 @@ const AppPublisher = ({

  const { data: userCanAccessApp, isLoading: isGettingUserCanAccessApp, refetch } = useGetUserCanAccessApp({ appId: appDetail?.id, enabled: false })
  const { data: appAccessSubjects, isLoading: isGettingAppWhiteListSubjects } = useAppWhiteListSubjects(appDetail?.id, open && systemFeatures.webapp_auth.enabled && appDetail?.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS)
  const openAsyncWindow = useAsyncWindowOpen()

  const noAccessPermission = useMemo(() => systemFeatures.webapp_auth.enabled && appDetail && appDetail.access_mode !== AccessMode.EXTERNAL_MEMBERS && !userCanAccessApp?.result, [systemFeatures, appDetail, userCanAccessApp])
  const disabledFunctionButton = useMemo(() => (!publishedAt || missingStartNode || noAccessPermission), [publishedAt, missingStartNode, noAccessPermission])
@ -217,17 +219,19 @@ const AppPublisher = ({
  }, [disabled, onToggle, open])

  const handleOpenInExplore = useCallback(async () => {
    try {
    await openAsyncWindow(async () => {
      if (!appDetail?.id)
        throw new Error('App not found')
      const { installed_apps }: any = await fetchInstalledAppList(appDetail?.id) || {}
      if (installed_apps?.length > 0)
        window.open(`${basePath}/explore/installed/${installed_apps[0].id}`, '_blank')
      else
        throw new Error('No app found in Explore')
    }
    catch (e: any) {
      Toast.notify({ type: 'error', message: `${e.message || e}` })
    }
  }, [appDetail?.id])
        return `${basePath}/explore/installed/${installed_apps[0].id}`
      throw new Error('No app found in Explore')
    }, {
      onError: (err) => {
        Toast.notify({ type: 'error', message: `${err.message || err}` })
      },
    })
  }, [appDetail?.id, openAsyncWindow])

  const handleAccessControlUpdate = useCallback(async () => {
    if (!appDetail)
@ -0,0 +1,49 @@
import React from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import ConfirmAddVar from './index'

jest.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => key,
  }),
}))

jest.mock('../../base/var-highlight', () => ({
  __esModule: true,
  default: ({ name }: { name: string }) => <span data-testid="var-highlight">{name}</span>,
}))

describe('ConfirmAddVar', () => {
  beforeEach(() => {
    jest.clearAllMocks()
  })

  it('should render variable names', () => {
    render(<ConfirmAddVar varNameArr={['foo', 'bar']} onConfirm={jest.fn()} onCancel={jest.fn()} onHide={jest.fn()} />)

    const highlights = screen.getAllByTestId('var-highlight')
    expect(highlights).toHaveLength(2)
    expect(highlights[0]).toHaveTextContent('foo')
    expect(highlights[1]).toHaveTextContent('bar')
  })

  it('should trigger cancel actions', () => {
    const onConfirm = jest.fn()
    const onCancel = jest.fn()
    render(<ConfirmAddVar varNameArr={['foo']} onConfirm={onConfirm} onCancel={onCancel} onHide={jest.fn()} />)

    fireEvent.click(screen.getByText('common.operation.cancel'))

    expect(onCancel).toHaveBeenCalledTimes(1)
  })

  it('should trigger confirm actions', () => {
    const onConfirm = jest.fn()
    const onCancel = jest.fn()
    render(<ConfirmAddVar varNameArr={['foo']} onConfirm={onConfirm} onCancel={onCancel} onHide={jest.fn()} />)

    fireEvent.click(screen.getByText('common.operation.add'))

    expect(onConfirm).toHaveBeenCalledTimes(1)
  })
})
@ -0,0 +1,56 @@
import React from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import EditModal from './edit-modal'
import type { ConversationHistoriesRole } from '@/models/debug'

jest.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => key,
  }),
}))

jest.mock('@/app/components/base/modal', () => ({
  __esModule: true,
  default: ({ children }: { children: React.ReactNode }) => <div>{children}</div>,
}))

describe('Conversation history edit modal', () => {
  const data: ConversationHistoriesRole = {
    user_prefix: 'user',
    assistant_prefix: 'assistant',
  }

  beforeEach(() => {
    jest.clearAllMocks()
  })

  it('should render provided prefixes', () => {
    render(<EditModal isShow saveLoading={false} data={data} onClose={jest.fn()} onSave={jest.fn()} />)

    expect(screen.getByDisplayValue('user')).toBeInTheDocument()
    expect(screen.getByDisplayValue('assistant')).toBeInTheDocument()
  })

  it('should update prefixes and save changes', () => {
    const onSave = jest.fn()
    render(<EditModal isShow saveLoading={false} data={data} onClose={jest.fn()} onSave={onSave} />)

    fireEvent.change(screen.getByDisplayValue('user'), { target: { value: 'member' } })
    fireEvent.change(screen.getByDisplayValue('assistant'), { target: { value: 'helper' } })
    fireEvent.click(screen.getByText('common.operation.save'))

    expect(onSave).toHaveBeenCalledWith({
      user_prefix: 'member',
      assistant_prefix: 'helper',
    })
  })

  it('should call close handler', () => {
    const onClose = jest.fn()
    render(<EditModal isShow saveLoading={false} data={data} onClose={onClose} onSave={jest.fn()} />)

    fireEvent.click(screen.getByText('common.operation.cancel'))

    expect(onClose).toHaveBeenCalledTimes(1)
  })
})
@ -0,0 +1,48 @@
import React from 'react'
import { render, screen } from '@testing-library/react'
import HistoryPanel from './history-panel'

jest.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => key,
  }),
}))

const mockDocLink = jest.fn(() => 'doc-link')
jest.mock('@/context/i18n', () => ({
  useDocLink: () => mockDocLink,
}))

jest.mock('@/app/components/app/configuration/base/operation-btn', () => ({
  __esModule: true,
  default: ({ onClick }: { onClick: () => void }) => (
    <button type="button" data-testid="edit-button" onClick={onClick}>
      edit
    </button>
  ),
}))

jest.mock('@/app/components/app/configuration/base/feature-panel', () => ({
  __esModule: true,
  default: ({ children }: { children: React.ReactNode }) => <div>{children}</div>,
}))

describe('HistoryPanel', () => {
  beforeEach(() => {
    jest.clearAllMocks()
  })

  it('should render warning content and link when showWarning is true', () => {
    render(<HistoryPanel showWarning onShowEditModal={jest.fn()} />)

    expect(screen.getByText('appDebug.feature.conversationHistory.tip')).toBeInTheDocument()
    const link = screen.getByText('appDebug.feature.conversationHistory.learnMore')
    expect(link).toHaveAttribute('href', 'doc-link')
  })

  it('should hide warning when showWarning is false', () => {
    render(<HistoryPanel showWarning={false} onShowEditModal={jest.fn()} />)

    expect(screen.queryByText('appDebug.feature.conversationHistory.tip')).toBeNull()
  })
})
@ -0,0 +1,351 @@
import React from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import Prompt, { type IPromptProps } from './index'
import ConfigContext from '@/context/debug-configuration'
import { MAX_PROMPT_MESSAGE_LENGTH } from '@/config'
import { type PromptItem, PromptRole, type PromptVariable } from '@/models/debug'
import { AppModeEnum, ModelModeType } from '@/types/app'

jest.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => key,
  }),
}))

type DebugConfiguration = {
  isAdvancedMode: boolean
  currentAdvancedPrompt: PromptItem | PromptItem[]
  setCurrentAdvancedPrompt: (prompt: PromptItem | PromptItem[], isUserChanged?: boolean) => void
  modelModeType: ModelModeType
  dataSets: Array<{
    id: string
    name?: string
  }>
  hasSetBlockStatus: {
    context: boolean
    history: boolean
    query: boolean
  }
}

const defaultPromptVariables: PromptVariable[] = [
  { key: 'var', name: 'Variable', type: 'string', required: true },
]

let mockSimplePromptInputProps: IPromptProps | null = null

jest.mock('./simple-prompt-input', () => ({
  __esModule: true,
  default: (props: IPromptProps) => {
    mockSimplePromptInputProps = props
    return (
      <div
        data-testid="simple-prompt-input"
        data-mode={props.mode}
        data-template={props.promptTemplate}
        data-readonly={props.readonly ?? false}
        data-no-title={props.noTitle ?? false}
        data-gradient-border={props.gradientBorder ?? false}
        data-editor-height={props.editorHeight ?? ''}
        data-no-resize={props.noResize ?? false}
        onClick={() => props.onChange?.('mocked prompt', props.promptVariables)}
      >
        SimplePromptInput Mock
      </div>
    )
  },
}))

type AdvancedMessageInputProps = {
  isChatMode: boolean
  type: PromptRole
  value: string
  onTypeChange: (value: PromptRole) => void
  canDelete: boolean
  onDelete: () => void
  onChange: (value: string) => void
  promptVariables: PromptVariable[]
  isContextMissing: boolean
  onHideContextMissingTip: () => void
  noResize?: boolean
}

jest.mock('./advanced-prompt-input', () => ({
  __esModule: true,
  default: (props: AdvancedMessageInputProps) => {
    return (
      <div
        data-testid="advanced-message-input"
        data-type={props.type}
        data-value={props.value}
        data-chat-mode={props.isChatMode}
        data-can-delete={props.canDelete}
        data-context-missing={props.isContextMissing}
      >
        <button type="button" onClick={() => props.onChange('updated text')}>
          change
        </button>
        <button type="button" onClick={() => props.onTypeChange(PromptRole.assistant)}>
          type
        </button>
        <button type="button" onClick={props.onDelete}>
          delete
        </button>
        <button type="button" onClick={props.onHideContextMissingTip}>
          hide-context
        </button>
      </div>
    )
  },
}))

const getContextValue = (overrides: Partial<DebugConfiguration> = {}): DebugConfiguration => {
  return {
    setCurrentAdvancedPrompt: jest.fn(),
    isAdvancedMode: false,
    currentAdvancedPrompt: [],
    modelModeType: ModelModeType.chat,
    dataSets: [],
    hasSetBlockStatus: {
      context: false,
      history: false,
      query: false,
    },
    ...overrides,
  }
}

const renderComponent = (
  props: Partial<IPromptProps> = {},
  contextOverrides: Partial<DebugConfiguration> = {},
) => {
  const mergedProps: IPromptProps = {
    mode: AppModeEnum.CHAT,
    promptTemplate: 'initial template',
    promptVariables: defaultPromptVariables,
    onChange: jest.fn(),
    ...props,
  }
  const contextValue = getContextValue(contextOverrides)

  return {
    contextValue,
    ...render(
      <ConfigContext.Provider value={contextValue as any}>
        <Prompt {...mergedProps} />
      </ConfigContext.Provider>,
    ),
  }
}

describe('Prompt config component', () => {
  beforeEach(() => {
    jest.clearAllMocks()
    mockSimplePromptInputProps = null
  })

  // Rendering simple mode
  it('should render simple prompt when advanced mode is disabled', () => {
    const onChange = jest.fn()
    renderComponent({ onChange }, { isAdvancedMode: false })

    const simplePrompt = screen.getByTestId('simple-prompt-input')
    expect(simplePrompt).toBeInTheDocument()
    expect(simplePrompt).toHaveAttribute('data-mode', AppModeEnum.CHAT)
    expect(mockSimplePromptInputProps?.promptTemplate).toBe('initial template')
    fireEvent.click(simplePrompt)
    expect(onChange).toHaveBeenCalledWith('mocked prompt', defaultPromptVariables)
    expect(screen.queryByTestId('advanced-message-input')).toBeNull()
  })

  // Rendering advanced chat messages
  it('should render advanced chat prompts and show context missing tip when dataset context is not set', () => {
    const currentAdvancedPrompt: PromptItem[] = [
      { role: PromptRole.user, text: 'first' },
      { role: PromptRole.assistant, text: 'second' },
    ]
    renderComponent(
      {},
      {
        isAdvancedMode: true,
        currentAdvancedPrompt,
        modelModeType: ModelModeType.chat,
        dataSets: [{ id: 'ds' } as unknown as DebugConfiguration['dataSets'][number]],
        hasSetBlockStatus: { context: false, history: true, query: true },
      },
    )

    const renderedMessages = screen.getAllByTestId('advanced-message-input')
    expect(renderedMessages).toHaveLength(2)
    expect(renderedMessages[0]).toHaveAttribute('data-context-missing', 'true')
    fireEvent.click(screen.getAllByText('hide-context')[0])
    expect(screen.getAllByTestId('advanced-message-input')[0]).toHaveAttribute('data-context-missing', 'false')
  })

  // Chat message mutations
  it('should update chat prompt value and call setter with user change flag', () => {
    const currentAdvancedPrompt: PromptItem[] = [
      { role: PromptRole.user, text: 'first' },
      { role: PromptRole.assistant, text: 'second' },
    ]
    const setCurrentAdvancedPrompt = jest.fn()
    renderComponent(
      {},
      {
        isAdvancedMode: true,
        currentAdvancedPrompt,
        modelModeType: ModelModeType.chat,
        setCurrentAdvancedPrompt,
      },
    )

    fireEvent.click(screen.getAllByText('change')[0])
    expect(setCurrentAdvancedPrompt).toHaveBeenCalledWith(
      [
        { role: PromptRole.user, text: 'updated text' },
        { role: PromptRole.assistant, text: 'second' },
      ],
      true,
    )
  })

  it('should update chat prompt role when type changes', () => {
    const currentAdvancedPrompt: PromptItem[] = [
      { role: PromptRole.user, text: 'first' },
      { role: PromptRole.user, text: 'second' },
    ]
    const setCurrentAdvancedPrompt = jest.fn()
    renderComponent(
      {},
      {
        isAdvancedMode: true,
        currentAdvancedPrompt,
        modelModeType: ModelModeType.chat,
        setCurrentAdvancedPrompt,
      },
    )

    fireEvent.click(screen.getAllByText('type')[1])
    expect(setCurrentAdvancedPrompt).toHaveBeenCalledWith(
      [
        { role: PromptRole.user, text: 'first' },
        { role: PromptRole.assistant, text: 'second' },
      ],
    )
  })

  it('should delete chat prompt item', () => {
    const currentAdvancedPrompt: PromptItem[] = [
      { role: PromptRole.user, text: 'first' },
      { role: PromptRole.assistant, text: 'second' },
    ]
    const setCurrentAdvancedPrompt = jest.fn()
    renderComponent(
      {},
      {
        isAdvancedMode: true,
        currentAdvancedPrompt,
        modelModeType: ModelModeType.chat,
        setCurrentAdvancedPrompt,
      },
    )

    fireEvent.click(screen.getAllByText('delete')[0])
    expect(setCurrentAdvancedPrompt).toHaveBeenCalledWith([{ role: PromptRole.assistant, text: 'second' }])
  })

  // Add message behavior
  it('should append a mirrored role message when clicking add in chat mode', () => {
    const currentAdvancedPrompt: PromptItem[] = [
      { role: PromptRole.user, text: 'first' },
    ]
    const setCurrentAdvancedPrompt = jest.fn()
    renderComponent(
      {},
      {
        isAdvancedMode: true,
        currentAdvancedPrompt,
        modelModeType: ModelModeType.chat,
        setCurrentAdvancedPrompt,
      },
    )

    fireEvent.click(screen.getByText('appDebug.promptMode.operation.addMessage'))
    expect(setCurrentAdvancedPrompt).toHaveBeenCalledWith([
      { role: PromptRole.user, text: 'first' },
      { role: PromptRole.assistant, text: '' },
    ])
  })

  it('should append a user role when the last chat prompt is from assistant', () => {
    const currentAdvancedPrompt: PromptItem[] = [
      { role: PromptRole.assistant, text: 'reply' },
    ]
    const setCurrentAdvancedPrompt = jest.fn()
    renderComponent(
      {},
      {
        isAdvancedMode: true,
        currentAdvancedPrompt,
        modelModeType: ModelModeType.chat,
        setCurrentAdvancedPrompt,
      },
    )

    fireEvent.click(screen.getByText('appDebug.promptMode.operation.addMessage'))
    expect(setCurrentAdvancedPrompt).toHaveBeenCalledWith([
      { role: PromptRole.assistant, text: 'reply' },
      { role: PromptRole.user, text: '' },
    ])
  })

  it('should insert a system message when adding to an empty chat prompt list', () => {
    const setCurrentAdvancedPrompt = jest.fn()
    renderComponent(
      {},
      {
        isAdvancedMode: true,
        currentAdvancedPrompt: [],
        modelModeType: ModelModeType.chat,
        setCurrentAdvancedPrompt,
      },
    )

    fireEvent.click(screen.getByText('appDebug.promptMode.operation.addMessage'))
    expect(setCurrentAdvancedPrompt).toHaveBeenCalledWith([{ role: PromptRole.system, text: '' }])
  })

  it('should not show add button when reaching max prompt length', () => {
    const prompts: PromptItem[] = Array.from({ length: MAX_PROMPT_MESSAGE_LENGTH }, (_, index) => ({
      role: PromptRole.user,
      text: `item-${index}`,
    }))
    renderComponent(
      {},
      {
        isAdvancedMode: true,
        currentAdvancedPrompt: prompts,
        modelModeType: ModelModeType.chat,
      },
    )

    expect(screen.queryByText('appDebug.promptMode.operation.addMessage')).toBeNull()
  })

  // Completion mode
  it('should update completion prompt value and flag as user change', () => {
    const setCurrentAdvancedPrompt = jest.fn()
    renderComponent(
      {},
      {
        isAdvancedMode: true,
        currentAdvancedPrompt: { role: PromptRole.user, text: 'single' },
        modelModeType: ModelModeType.completion,
        setCurrentAdvancedPrompt,
      },
    )

    fireEvent.click(screen.getByText('change'))

    expect(setCurrentAdvancedPrompt).toHaveBeenCalledWith({ role: PromptRole.user, text: 'updated text' }, true)
  })
})
@ -0,0 +1,37 @@
import React from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import MessageTypeSelector from './message-type-selector'
import { PromptRole } from '@/models/debug'

describe('MessageTypeSelector', () => {
  beforeEach(() => {
    jest.clearAllMocks()
  })

  it('should render current value and keep options hidden by default', () => {
    render(<MessageTypeSelector value={PromptRole.user} onChange={jest.fn()} />)

    expect(screen.getByText(PromptRole.user)).toBeInTheDocument()
    expect(screen.queryByText(PromptRole.system)).toBeNull()
  })

  it('should toggle option list when clicking the selector', () => {
    render(<MessageTypeSelector value={PromptRole.system} onChange={jest.fn()} />)

    fireEvent.click(screen.getByText(PromptRole.system))

    expect(screen.getByText(PromptRole.user)).toBeInTheDocument()
    expect(screen.getByText(PromptRole.assistant)).toBeInTheDocument()
  })

  it('should call onChange with selected type and close the list', () => {
    const onChange = jest.fn()
    render(<MessageTypeSelector value={PromptRole.assistant} onChange={onChange} />)

    fireEvent.click(screen.getByText(PromptRole.assistant))
    fireEvent.click(screen.getByText(PromptRole.user))

    expect(onChange).toHaveBeenCalledWith(PromptRole.user)
    expect(screen.queryByText(PromptRole.system)).toBeNull()
  })
})
@ -0,0 +1,66 @@
import React from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import PromptEditorHeightResizeWrap from './prompt-editor-height-resize-wrap'

describe('PromptEditorHeightResizeWrap', () => {
  beforeEach(() => {
    jest.clearAllMocks()
    jest.useFakeTimers()
  })

  afterEach(() => {
    jest.runOnlyPendingTimers()
    jest.useRealTimers()
  })

  it('should render children, footer, and hide resize handler when requested', () => {
    const { container } = render(
      <PromptEditorHeightResizeWrap
        className="wrapper"
        height={150}
        minHeight={100}
        onHeightChange={jest.fn()}
        footer={<div>footer</div>}
        hideResize
      >
        <div>content</div>
      </PromptEditorHeightResizeWrap>,
    )

    expect(screen.getByText('content')).toBeInTheDocument()
    expect(screen.getByText('footer')).toBeInTheDocument()
    expect(container.querySelector('.cursor-row-resize')).toBeNull()
  })

  it('should resize height with mouse events and clamp to minHeight', () => {
    const onHeightChange = jest.fn()

    const { container } = render(
      <PromptEditorHeightResizeWrap
        height={150}
        minHeight={100}
        onHeightChange={onHeightChange}
      >
        <div>content</div>
      </PromptEditorHeightResizeWrap>,
    )

    const handle = container.querySelector('.cursor-row-resize')
    expect(handle).not.toBeNull()

    fireEvent.mouseDown(handle as Element, { clientY: 100 })
    expect(document.body.style.userSelect).toBe('none')

    fireEvent.mouseMove(document, { clientY: 130 })
    jest.runAllTimers()
    expect(onHeightChange).toHaveBeenLastCalledWith(180)

    onHeightChange.mockClear()
    fireEvent.mouseMove(document, { clientY: -100 })
    jest.runAllTimers()
    expect(onHeightChange).toHaveBeenLastCalledWith(100)

    fireEvent.mouseUp(document)
    expect(document.body.style.userSelect).toBe('')
  })
})
@ -32,6 +32,7 @@ import { canFindTool } from '@/utils'
import { useAllBuiltInTools, useAllCustomTools, useAllMCPTools, useAllWorkflowTools } from '@/service/use-tools'
import type { ToolWithProvider } from '@/app/components/workflow/types'
import { useMittContextSelector } from '@/context/mitt-context'
import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'

type AgentToolWithMoreInfo = AgentTool & { icon: any; collection?: Collection } | null
const AgentTools: FC = () => {
@ -93,13 +94,17 @@ const AgentTools: FC = () => {

  const [isDeleting, setIsDeleting] = useState<number>(-1)
  const getToolValue = (tool: ToolDefaultValue) => {
    const currToolInCollections = collectionList.find(c => c.id === tool.provider_id)
    const currToolWithConfigs = currToolInCollections?.tools.find(t => t.name === tool.tool_name)
    const formSchemas = currToolWithConfigs ? toolParametersToFormSchemas(currToolWithConfigs.parameters) : []
    const paramsWithDefaultValue = addDefaultValue(tool.params, formSchemas)
    return {
      provider_id: tool.provider_id,
      provider_type: tool.provider_type as CollectionType,
      provider_name: tool.provider_name,
      tool_name: tool.tool_name,
      tool_label: tool.tool_label,
      tool_parameters: tool.params,
      tool_parameters: paramsWithDefaultValue,
      notAuthor: !tool.is_team_authorization,
      enabled: true,
    }
@ -119,7 +124,7 @@ const AgentTools: FC = () => {
  }
  const getProviderShowName = (item: AgentTool) => {
    const type = item.provider_type
    if(type === CollectionType.builtIn)
    if (type === CollectionType.builtIn)
      return item.provider_name.split('/').pop()
    return item.provider_name
  }
@@ -16,7 +16,7 @@ import Description from '@/app/components/plugins/card/base/description'
import TabSlider from '@/app/components/base/tab-slider-plain'
import Button from '@/app/components/base/button'
import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form'
import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
import { toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
import type { Collection, Tool } from '@/app/components/tools/types'
import { CollectionType } from '@/app/components/tools/types'
import { fetchBuiltInToolList, fetchCustomToolList, fetchModelToolList, fetchWorkflowToolList } from '@/service/tools'
@@ -92,15 +92,11 @@ const SettingBuiltInTool: FC<Props> = ({
        }())
      })
      setTools(list)
      const currTool = list.find(tool => tool.name === toolName)
      if (currTool) {
        const formSchemas = toolParametersToFormSchemas(currTool.parameters)
        setTempSetting(addDefaultValue(setting, formSchemas))
      }
    }
    catch { }
    setIsLoading(false)
  })()
  // eslint-disable-next-line react-hooks/exhaustive-deps
}, [collection?.name, collection?.id, collection?.type])

useEffect(() => {
@@ -249,7 +245,7 @@ const SettingBuiltInTool: FC<Props> = ({
      {!readonly && !isInfoActive && (
        <div className='flex shrink-0 justify-end space-x-2 rounded-b-[10px] bg-components-panel-bg py-2'>
          <Button className='flex h-8 items-center !px-3 !text-[13px] font-medium ' onClick={onHide}>{t('common.operation.cancel')}</Button>
          <Button className='flex h-8 items-center !px-3 !text-[13px] font-medium' variant='primary' disabled={!isValid} onClick={() => onSave?.(addDefaultValue(tempSetting, formSchemas))}>{t('common.operation.save')}</Button>
          <Button className='flex h-8 items-center !px-3 !text-[13px] font-medium' variant='primary' disabled={!isValid} onClick={() => onSave?.(tempSetting)}>{t('common.operation.save')}</Button>
        </div>
      )}
    </div>

@@ -0,0 +1,480 @@
import '@testing-library/jest-dom'
import type { CSSProperties } from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import DebugWithMultipleModel from './index'
import type { DebugWithMultipleModelContextType } from './context'
import { APP_CHAT_WITH_MULTIPLE_MODEL } from '../types'
import type { ModelAndParameter } from '../types'
import type { Inputs, ModelConfig } from '@/models/debug'
import { DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/config'
import type { FeatureStoreState } from '@/app/components/base/features/store'
import type { FileEntity } from '@/app/components/base/file-uploader/types'
import type { InputForm } from '@/app/components/base/chat/chat/type'
import { AppModeEnum, ModelModeType, type PromptVariable, Resolution, TransferMethod } from '@/types/app'

type PromptVariableWithMeta = Omit<PromptVariable, 'type' | 'required'> & {
  type: PromptVariable['type'] | 'api'
  required?: boolean
  hide?: boolean
}

const mockUseDebugConfigurationContext = jest.fn()
const mockUseFeaturesSelector = jest.fn()
const mockUseEventEmitterContext = jest.fn()
const mockUseAppStoreSelector = jest.fn()
const mockEventEmitter = { emit: jest.fn() }
const mockSetShowAppConfigureFeaturesModal = jest.fn()
let capturedChatInputProps: MockChatInputAreaProps | null = null
let modelIdCounter = 0
let featureState: FeatureStoreState

type MockChatInputAreaProps = {
  onSend?: (message: string, files?: FileEntity[]) => void
  onFeatureBarClick?: (state: boolean) => void
  showFeatureBar?: boolean
  showFileUpload?: boolean
  inputs?: Record<string, any>
  inputsForm?: InputForm[]
  speechToTextConfig?: unknown
  visionConfig?: unknown
}

const mockFiles: FileEntity[] = [
  {
    id: 'file-1',
    name: 'file.txt',
    size: 10,
    type: 'text/plain',
    progress: 100,
    transferMethod: TransferMethod.remote_url,
    supportFileType: 'text',
  },
]

jest.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => key,
  }),
}))

jest.mock('@/context/debug-configuration', () => ({
  __esModule: true,
  useDebugConfigurationContext: () => mockUseDebugConfigurationContext(),
}))

jest.mock('@/app/components/base/features/hooks', () => ({
  __esModule: true,
  useFeatures: (selector: (state: FeatureStoreState) => unknown) => mockUseFeaturesSelector(selector),
}))

jest.mock('@/context/event-emitter', () => ({
  __esModule: true,
  useEventEmitterContextContext: () => mockUseEventEmitterContext(),
}))

jest.mock('@/app/components/app/store', () => ({
  __esModule: true,
  useStore: (selector: (state: { setShowAppConfigureFeaturesModal: typeof mockSetShowAppConfigureFeaturesModal }) => unknown) => mockUseAppStoreSelector(selector),
}))

jest.mock('./debug-item', () => ({
  __esModule: true,
  default: ({
    modelAndParameter,
    className,
    style,
  }: {
    modelAndParameter: ModelAndParameter
    className?: string
    style?: CSSProperties
  }) => (
    <div
      data-testid='debug-item'
      data-model-id={modelAndParameter.id}
      className={className}
      style={style}
    >
      DebugItem-{modelAndParameter.id}
    </div>
  ),
}))

jest.mock('@/app/components/base/chat/chat/chat-input-area', () => ({
  __esModule: true,
  default: (props: MockChatInputAreaProps) => {
    capturedChatInputProps = props
    return (
      <div data-testid='chat-input-area'>
        <button type='button' onClick={() => props.onSend?.('test message', mockFiles)}>send</button>
        <button type='button' onClick={() => props.onFeatureBarClick?.(true)}>feature</button>
      </div>
    )
  },
}))

const createFeatureState = (): FeatureStoreState => ({
  features: {
    speech2text: { enabled: true },
    file: {
      image: {
        enabled: true,
        detail: Resolution.high,
        number_limits: 2,
        transfer_methods: [TransferMethod.remote_url],
      },
    },
  },
  setFeatures: jest.fn(),
  showFeaturesModal: false,
  setShowFeaturesModal: jest.fn(),
})

const createModelConfig = (promptVariables: PromptVariableWithMeta[] = []): ModelConfig => ({
  provider: 'OPENAI',
  model_id: 'gpt-4',
  mode: ModelModeType.chat,
  configs: {
    prompt_template: '',
    prompt_variables: promptVariables as unknown as PromptVariable[],
  },
  chat_prompt_config: DEFAULT_CHAT_PROMPT_CONFIG,
  completion_prompt_config: DEFAULT_COMPLETION_PROMPT_CONFIG,
  opening_statement: '',
  more_like_this: null,
  suggested_questions: [],
  suggested_questions_after_answer: null,
  speech_to_text: null,
  text_to_speech: null,
  file_upload: null,
  retriever_resource: null,
  sensitive_word_avoidance: null,
  annotation_reply: null,
  external_data_tools: [],
  system_parameters: {
    audio_file_size_limit: 0,
    file_size_limit: 0,
    image_file_size_limit: 0,
    video_file_size_limit: 0,
    workflow_file_upload_limit: 0,
  },
  dataSets: [],
  agentConfig: DEFAULT_AGENT_SETTING,
})

type DebugConfiguration = {
  mode: AppModeEnum
  inputs: Inputs
  modelConfig: ModelConfig
}

const createDebugConfiguration = (overrides: Partial<DebugConfiguration> = {}): DebugConfiguration => ({
  mode: AppModeEnum.CHAT,
  inputs: {},
  modelConfig: createModelConfig(),
  ...overrides,
})

const createModelAndParameter = (overrides: Partial<ModelAndParameter> = {}): ModelAndParameter => ({
  id: `model-${++modelIdCounter}`,
  model: 'gpt-3.5-turbo',
  provider: 'openai',
  parameters: {},
  ...overrides,
})

const createProps = (overrides: Partial<DebugWithMultipleModelContextType> = {}): DebugWithMultipleModelContextType => ({
  multipleModelConfigs: [createModelAndParameter()],
  onMultipleModelConfigsChange: jest.fn(),
  onDebugWithMultipleModelChange: jest.fn(),
  ...overrides,
})

const renderComponent = (props?: Partial<DebugWithMultipleModelContextType>) => {
  const mergedProps = createProps(props)
  return render(<DebugWithMultipleModel {...mergedProps} />)
}

describe('DebugWithMultipleModel', () => {
  beforeEach(() => {
    jest.clearAllMocks()
    capturedChatInputProps = null
    modelIdCounter = 0
    featureState = createFeatureState()
    mockUseFeaturesSelector.mockImplementation(selector => selector(featureState))
    mockUseEventEmitterContext.mockReturnValue({ eventEmitter: mockEventEmitter })
    mockUseAppStoreSelector.mockImplementation(selector => selector({ setShowAppConfigureFeaturesModal: mockSetShowAppConfigureFeaturesModal }))
    mockUseDebugConfigurationContext.mockReturnValue(createDebugConfiguration())
  })

  describe('chat input rendering', () => {
    it('should render chat input in chat mode with transformed prompt variables and feature handler', () => {
      // Arrange
      const promptVariables: PromptVariableWithMeta[] = [
        { key: 'city', name: 'City', type: 'string', required: true },
        { key: 'audience', name: 'Audience', type: 'number' },
        { key: 'hidden', name: 'Hidden', type: 'select', hide: true },
        { key: 'api-only', name: 'API Only', type: 'api' },
      ]
      const debugConfiguration = createDebugConfiguration({
        inputs: { audience: 'engineers' },
        modelConfig: createModelConfig(promptVariables),
      })
      mockUseDebugConfigurationContext.mockReturnValue(debugConfiguration)

      // Act
      renderComponent()
      fireEvent.click(screen.getByRole('button', { name: /feature/i }))

      // Assert
      expect(screen.getByTestId('chat-input-area')).toBeInTheDocument()
      expect(capturedChatInputProps?.inputs).toEqual({ audience: 'engineers' })
      expect(capturedChatInputProps?.inputsForm).toEqual([
        expect.objectContaining({ label: 'City', variable: 'city', hide: false, required: true }),
        expect.objectContaining({ label: 'Audience', variable: 'audience', hide: false, required: false }),
        expect.objectContaining({ label: 'Hidden', variable: 'hidden', hide: true, required: false }),
      ])
      expect(capturedChatInputProps?.showFeatureBar).toBe(true)
      expect(capturedChatInputProps?.showFileUpload).toBe(false)
      expect(capturedChatInputProps?.speechToTextConfig).toEqual(featureState.features.speech2text)
      expect(capturedChatInputProps?.visionConfig).toEqual(featureState.features.file)
      expect(mockSetShowAppConfigureFeaturesModal).toHaveBeenCalledWith(true)
    })

    it('should render chat input in agent chat mode', () => {
      // Arrange
      mockUseDebugConfigurationContext.mockReturnValue(createDebugConfiguration({
        mode: AppModeEnum.AGENT_CHAT,
      }))

      // Act
      renderComponent()

      // Assert
      expect(screen.getByTestId('chat-input-area')).toBeInTheDocument()
    })

    it('should hide chat input when not in chat mode', () => {
      // Arrange
      mockUseDebugConfigurationContext.mockReturnValue(createDebugConfiguration({
        mode: AppModeEnum.COMPLETION,
      }))
      const multipleModelConfigs = [createModelAndParameter()]

      // Act
      renderComponent({ multipleModelConfigs })

      // Assert
      expect(screen.queryByTestId('chat-input-area')).not.toBeInTheDocument()
      expect(screen.getAllByTestId('debug-item')).toHaveLength(1)
    })
  })

  describe('sending flow', () => {
    it('should emit chat event when allowed to send', () => {
      // Arrange
      const checkCanSend = jest.fn(() => true)
      const multipleModelConfigs = [createModelAndParameter(), createModelAndParameter()]
      renderComponent({ multipleModelConfigs, checkCanSend })

      // Act
      fireEvent.click(screen.getByRole('button', { name: /send/i }))

      // Assert
      expect(checkCanSend).toHaveBeenCalled()
      expect(mockEventEmitter.emit).toHaveBeenCalledWith({
        type: APP_CHAT_WITH_MULTIPLE_MODEL,
        payload: {
          message: 'test message',
          files: mockFiles,
        },
      })
    })

    it('should emit when no checkCanSend is provided', () => {
      renderComponent()

      fireEvent.click(screen.getByRole('button', { name: /send/i }))

      expect(mockEventEmitter.emit).toHaveBeenCalledWith({
        type: APP_CHAT_WITH_MULTIPLE_MODEL,
        payload: {
          message: 'test message',
          files: mockFiles,
        },
      })
    })

    it('should block sending when checkCanSend returns false', () => {
      // Arrange
      const checkCanSend = jest.fn(() => false)
      renderComponent({ checkCanSend })

      // Act
      fireEvent.click(screen.getByRole('button', { name: /send/i }))

      // Assert
      expect(checkCanSend).toHaveBeenCalled()
      expect(mockEventEmitter.emit).not.toHaveBeenCalled()
    })

    it('should tolerate missing event emitter without throwing', () => {
      mockUseEventEmitterContext.mockReturnValue({ eventEmitter: null })
      renderComponent()

      expect(() => fireEvent.click(screen.getByRole('button', { name: /send/i }))).not.toThrow()
      expect(mockEventEmitter.emit).not.toHaveBeenCalled()
    })
  })

  describe('layout sizing and positioning', () => {
    const expectItemLayout = (
      element: HTMLElement,
      expectation: {
        width?: string
        height?: string
        transform: string
        classes?: string[]
      },
    ) => {
      if (expectation.width !== undefined)
        expect(element.style.width).toBe(expectation.width)
      else
        expect(element.style.width).toBe('')

      if (expectation.height !== undefined)
        expect(element.style.height).toBe(expectation.height)
      else
        expect(element.style.height).toBe('')

      expect(element.style.transform).toBe(expectation.transform)
      expectation.classes?.forEach(cls => expect(element).toHaveClass(cls))
    }

    it('should arrange items in two-column layout for two models', () => {
      // Arrange
      const multipleModelConfigs = [createModelAndParameter(), createModelAndParameter()]

      // Act
      renderComponent({ multipleModelConfigs })
      const items = screen.getAllByTestId('debug-item')

      // Assert
      expect(items).toHaveLength(2)
      expectItemLayout(items[0], {
        width: 'calc(50% - 4px - 24px)',
        height: '100%',
        transform: 'translateX(0) translateY(0)',
        classes: ['mr-2'],
      })
      expectItemLayout(items[1], {
        width: 'calc(50% - 4px - 24px)',
        height: '100%',
        transform: 'translateX(calc(100% + 8px)) translateY(0)',
        classes: [],
      })
    })

    it('should arrange items in thirds for three models', () => {
      // Arrange
      const multipleModelConfigs = [createModelAndParameter(), createModelAndParameter(), createModelAndParameter()]

      // Act
      renderComponent({ multipleModelConfigs })
      const items = screen.getAllByTestId('debug-item')

      // Assert
      expect(items).toHaveLength(3)
      expectItemLayout(items[0], {
        width: 'calc(33.3% - 5.33px - 16px)',
        height: '100%',
        transform: 'translateX(0) translateY(0)',
        classes: ['mr-2'],
      })
      expectItemLayout(items[1], {
        width: 'calc(33.3% - 5.33px - 16px)',
        height: '100%',
        transform: 'translateX(calc(100% + 8px)) translateY(0)',
        classes: ['mr-2'],
      })
      expectItemLayout(items[2], {
        width: 'calc(33.3% - 5.33px - 16px)',
        height: '100%',
        transform: 'translateX(calc(200% + 16px)) translateY(0)',
        classes: [],
      })
    })

    it('should position items on a grid for four models', () => {
      // Arrange
      const multipleModelConfigs = [
        createModelAndParameter(),
        createModelAndParameter(),
        createModelAndParameter(),
        createModelAndParameter(),
      ]

      // Act
      renderComponent({ multipleModelConfigs })
      const items = screen.getAllByTestId('debug-item')

      // Assert
      expect(items).toHaveLength(4)
      expectItemLayout(items[0], {
        width: 'calc(50% - 4px - 24px)',
        height: 'calc(50% - 4px)',
        transform: 'translateX(0) translateY(0)',
        classes: ['mr-2', 'mb-2'],
      })
      expectItemLayout(items[1], {
        width: 'calc(50% - 4px - 24px)',
        height: 'calc(50% - 4px)',
        transform: 'translateX(calc(100% + 8px)) translateY(0)',
        classes: ['mb-2'],
      })
      expectItemLayout(items[2], {
        width: 'calc(50% - 4px - 24px)',
        height: 'calc(50% - 4px)',
        transform: 'translateX(0) translateY(calc(100% + 8px))',
        classes: ['mr-2'],
      })
      expectItemLayout(items[3], {
        width: 'calc(50% - 4px - 24px)',
        height: 'calc(50% - 4px)',
        transform: 'translateX(calc(100% + 8px)) translateY(calc(100% + 8px))',
        classes: [],
      })
    })

    it('should fall back to single column layout when only one model is provided', () => {
      // Arrange
      const multipleModelConfigs = [createModelAndParameter()]

      // Act
      renderComponent({ multipleModelConfigs })
      const item = screen.getByTestId('debug-item')

      // Assert
      expectItemLayout(item, {
        transform: 'translateX(0) translateY(0)',
        classes: [],
      })
    })

    it('should set scroll area height for chat modes', () => {
      const { container } = renderComponent()
      const scrollArea = container.querySelector('.relative.mb-3.grow.overflow-auto.px-6') as HTMLElement
      expect(scrollArea).toBeInTheDocument()
      expect(scrollArea.style.height).toBe('calc(100% - 60px)')
    })

    it('should set full height when chat input is hidden', () => {
      mockUseDebugConfigurationContext.mockReturnValue(createDebugConfiguration({
        mode: AppModeEnum.COMPLETION,
      }))

      const { container } = renderComponent()
      const scrollArea = container.querySelector('.relative.mb-3.grow.overflow-auto.px-6') as HTMLElement
      expect(scrollArea.style.height).toBe('100%')
    })
  })
})
@@ -2,6 +2,7 @@
import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'
import useSWR from 'swr'
import dayjs from 'dayjs'
import { RiCalendarLine } from '@remixicon/react'
import quarterOfYear from 'dayjs/plugin/quarterOfYear'
@@ -9,7 +10,7 @@ import type { QueryParam } from './index'
import Chip from '@/app/components/base/chip'
import Input from '@/app/components/base/input'
import Sort from '@/app/components/base/sort'
import { useAnnotationsCount } from '@/service/use-log'
import { fetchAnnotationsCount } from '@/service/log'
dayjs.extend(quarterOfYear)

const today = dayjs()
@@ -34,7 +35,7 @@ type IFilterProps = {
}

const Filter: FC<IFilterProps> = ({ isChatMode, appId, queryParams, setQueryParams }: IFilterProps) => {
  const { data } = useAnnotationsCount(appId)
  const { data } = useSWR({ url: `/apps/${appId}/annotations/count` }, fetchAnnotationsCount)
  const { t } = useTranslation()
  if (!data)
    return null

@@ -1,6 +1,7 @@
'use client'
import type { FC } from 'react'
import React, { useCallback, useEffect, useMemo, useState } from 'react'
import React, { useCallback, useEffect, useState } from 'react'
import useSWR from 'swr'
import { useDebounce } from 'ahooks'
import { omit } from 'lodash-es'
import dayjs from 'dayjs'
@@ -11,11 +12,10 @@ import Filter, { TIME_PERIOD_MAPPING } from './filter'
import EmptyElement from './empty-element'
import Pagination from '@/app/components/base/pagination'
import Loading from '@/app/components/base/loading'
import { useChatConversations, useCompletionConversations } from '@/service/use-log'
import { fetchChatConversations, fetchCompletionConversations } from '@/service/log'
import { APP_PAGE_LIMIT } from '@/config'
import type { App } from '@/types/app'
import { AppModeEnum } from '@/types/app'
import type { ChatConversationsRequest, CompletionConversationsRequest } from '@/models/log'
export type ILogsProps = {
  appDetail: App
}
@@ -71,43 +71,37 @@ const Logs: FC<ILogsProps> = ({ appDetail }) => {

  // Get the app type first
  const isChatMode = appDetail.mode !== AppModeEnum.COMPLETION
  const { sort_by } = debouncedQueryParams

  const completionQuery = useMemo<CompletionConversationsRequest & { sort_by?: string }>(() => ({
  const query = {
    page: currPage + 1,
    limit,
    keyword: debouncedQueryParams.keyword ?? '',
    annotation_status: debouncedQueryParams.annotation_status ?? 'all',
    start: debouncedQueryParams.period !== '9'
      ? dayjs().subtract(TIME_PERIOD_MAPPING[debouncedQueryParams.period].value, 'day').startOf('day').format('YYYY-MM-DD HH:mm')
      : '',
    end: debouncedQueryParams.period !== '9'
      ? dayjs().endOf('day').format('YYYY-MM-DD HH:mm')
      : '',
    ...omit(debouncedQueryParams, ['period', 'sort_by', 'keyword', 'annotation_status']),
  }), [currPage, debouncedQueryParams, limit])

  const chatQuery = useMemo<ChatConversationsRequest & { sort_by?: string }>(() => ({
    ...completionQuery,
    sort_by,
    message_count: 0,
  }), [completionQuery, sort_by])
    ...((debouncedQueryParams.period !== '9')
      ? {
        start: dayjs().subtract(TIME_PERIOD_MAPPING[debouncedQueryParams.period].value, 'day').startOf('day').format('YYYY-MM-DD HH:mm'),
        end: dayjs().endOf('day').format('YYYY-MM-DD HH:mm'),
      }
      : {}),
    ...(isChatMode ? { sort_by: debouncedQueryParams.sort_by } : {}),
    ...omit(debouncedQueryParams, ['period']),
  }

  // When the details are obtained, proceed to the next request
  const { data: chatConversations, refetch: refetchChatList } = useChatConversations(appDetail.id, chatQuery, isChatMode)
  const { data: chatConversations, mutate: mutateChatList } = useSWR(() => isChatMode
    ? {
      url: `/apps/${appDetail.id}/chat-conversations`,
      params: query,
    }
    : null, fetchChatConversations)

  const { data: completionConversations, refetch: refetchCompletionList } = useCompletionConversations(appDetail.id, completionQuery, !isChatMode)
  const { data: completionConversations, mutate: mutateCompletionList } = useSWR(() => !isChatMode
    ? {
      url: `/apps/${appDetail.id}/completion-conversations`,
      params: query,
    }
    : null, fetchCompletionConversations)

  const total = isChatMode ? chatConversations?.total : completionConversations?.total

  const handleRefreshList = useCallback(() => {
    if (isChatMode) {
      void refetchChatList()
      return
    }
    void refetchCompletionList()
  }, [isChatMode, refetchChatList, refetchCompletionList])

  const handleQueryParamsChange = useCallback((next: QueryParam) => {
    setCurrPage(0)
    setQueryParams(next)
@@ -130,13 +124,12 @@ const Logs: FC<ILogsProps> = ({ appDetail }) => {
      <p className='system-sm-regular shrink-0 text-text-tertiary'>{t('appLog.description')}</p>
      <div className='flex max-h-[calc(100%-16px)] flex-1 grow flex-col py-4'>
        <Filter isChatMode={isChatMode} appId={appDetail.id} queryParams={queryParams} setQueryParams={handleQueryParamsChange} />
        {(() => {
          if (total === undefined)
            return <Loading type='app' />
          if (total > 0)
            return <List logs={isChatMode ? chatConversations : completionConversations} appDetail={appDetail} onRefresh={handleRefreshList} />
          return <EmptyElement appDetail={appDetail} />
        })()}
        {total === undefined
          ? <Loading type='app' />
          : total > 0
            ? <List logs={isChatMode ? chatConversations : completionConversations} appDetail={appDetail} onRefresh={isChatMode ? mutateChatList : mutateCompletionList} />
            : <EmptyElement appDetail={appDetail} />
        }
        {/* Show Pagination only if the total is more than the limit */}
        {(total && total > APP_PAGE_LIMIT)
          ? <Pagination

@@ -1,6 +1,7 @@
'use client'
import type { FC } from 'react'
import React, { useCallback, useEffect, useRef, useState } from 'react'
import useSWR from 'swr'
import {
  HandThumbDownIcon,
  HandThumbUpIcon,
@@ -17,7 +18,7 @@ import { usePathname, useRouter, useSearchParams } from 'next/navigation'
import type { ChatItemInTree } from '../../base/chat/types'
import Indicator from '../../header/indicator'
import VarPanel from './var-panel'
import type { FeedbackFunc, IChatItem, SubmitAnnotationFunc } from '@/app/components/base/chat/chat/type'
import type { FeedbackFunc, FeedbackType, IChatItem, SubmitAnnotationFunc } from '@/app/components/base/chat/chat/type'
import type { Annotation, ChatConversationGeneralDetail, ChatConversationsResponse, ChatMessage, ChatMessagesRequest, CompletionConversationGeneralDetail, CompletionConversationsResponse, LogAnnotation } from '@/models/log'
import { type App, AppModeEnum } from '@/types/app'
import ActionButton from '@/app/components/base/action-button'
@@ -25,8 +26,7 @@ import Loading from '@/app/components/base/loading'
import Drawer from '@/app/components/base/drawer'
import Chat from '@/app/components/base/chat/chat'
import { ToastContext } from '@/app/components/base/toast'
import { fetchChatMessages } from '@/service/log'
import { useChatConversationDetail, useCompletionConversationDetail, useUpdateLogMessageAnnotation, useUpdateLogMessageFeedback } from '@/service/use-log'
import { fetchChatConversationDetail, fetchChatMessages, fetchCompletionConversationDetail, updateLogMessageAnnotations, updateLogMessageFeedbacks } from '@/service/log'
import ModelInfo from '@/app/components/app/log/model-info'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import TextGeneration from '@/app/components/app/text-generate/item'
@@ -199,39 +199,6 @@ type IDetailPanel = {
  onSubmitAnnotation: SubmitAnnotationFunc
}

const useConversationDetailActions = (appId: string | undefined, detailQueryKey: any) => {
  const { notify } = useContext(ToastContext)
  const { t } = useTranslation()
  const { mutateAsync: submitFeedback } = useUpdateLogMessageFeedback(appId, detailQueryKey)
  const { mutateAsync: submitAnnotation } = useUpdateLogMessageAnnotation(appId, detailQueryKey)

  const handleFeedback = useCallback<FeedbackFunc>(async (mid, { rating, content }) => {
    try {
      await submitFeedback({ mid, rating, content })
      notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
      return true
    }
    catch {
      notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') })
      return false
    }
  }, [notify, submitFeedback, t])

  const handleAnnotation = useCallback<SubmitAnnotationFunc>(async (mid, value) => {
    try {
      await submitAnnotation({ mid, value })
      notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
      return true
    }
    catch {
      notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') })
      return false
    }
  }, [notify, submitAnnotation, t])

  return { handleFeedback, handleAnnotation }
}

function DetailPanel({ detail, onFeedback }: IDetailPanel) {
  const MIN_ITEMS_FOR_SCROLL_LOADING = 8
  const SCROLL_THRESHOLD_PX = 50
@@ -402,7 +369,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
      notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') })
      return false
    }
  }, [allChatItems, appDetail?.id, notify, t])
  }, [allChatItems, appDetail?.id, t])

  const fetchInitiated = useRef(false)

@@ -549,7 +516,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
    finally {
      setIsLoading(false)
    }
  }, [allChatItems, detail?.model_config?.configs?.introduction, detail.id, hasMore, isLoading, timezone, t, appDetail])
  }, [allChatItems, detail.id, hasMore, isLoading, timezone, t, appDetail])

  useEffect(() => {
    const scrollableDiv = document.getElementById('scrollableDiv')
@@ -844,8 +811,39 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
 */
const CompletionConversationDetailComp: FC<{ appId?: string; conversationId?: string }> = ({ appId, conversationId }) => {
  // Text Generator App Session Details Including Message List
  const { data: conversationDetail, queryKey: detailQueryKey } = useCompletionConversationDetail(appId, conversationId)
  const { handleFeedback, handleAnnotation } = useConversationDetailActions(appId, detailQueryKey)
  const detailParams = ({ url: `/apps/${appId}/completion-conversations/${conversationId}` })
  const { data: conversationDetail, mutate: conversationDetailMutate } = useSWR(() => (appId && conversationId) ? detailParams : null, fetchCompletionConversationDetail)
  const { notify } = useContext(ToastContext)
  const { t } = useTranslation()

  const handleFeedback = async (mid: string, { rating, content }: FeedbackType): Promise<boolean> => {
    try {
      await updateLogMessageFeedbacks({
        url: `/apps/${appId}/feedbacks`,
        body: { message_id: mid, rating, content: content ?? undefined },
      })
      conversationDetailMutate()
      notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
      return true
    }
    catch {
      notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') })
      return false
    }
  }

  const handleAnnotation = async (mid: string, value: string): Promise<boolean> => {
    try {
      await updateLogMessageAnnotations({ url: `/apps/${appId}/annotations`, body: { message_id: mid, content: value } })
      conversationDetailMutate()
      notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
      return true
    }
    catch {
      notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') })
      return false
    }
  }

  if (!conversationDetail)
    return null
@@ -861,8 +859,37 @@ const CompletionConversationDetailComp: FC<{ appId?: string; conversationId?: st
 * Chat App Conversation Detail Component
 */
const ChatConversationDetailComp: FC<{ appId?: string; conversationId?: string }> = ({ appId, conversationId }) => {
  const { data: conversationDetail, queryKey: detailQueryKey } = useChatConversationDetail(appId, conversationId)
  const { handleFeedback, handleAnnotation } = useConversationDetailActions(appId, detailQueryKey)
  const detailParams = { url: `/apps/${appId}/chat-conversations/${conversationId}` }
  const { data: conversationDetail } = useSWR(() => (appId && conversationId) ? detailParams : null, fetchChatConversationDetail)
  const { notify } = useContext(ToastContext)
  const { t } = useTranslation()

  const handleFeedback = async (mid: string, { rating, content }: FeedbackType): Promise<boolean> => {
    try {
      await updateLogMessageFeedbacks({
        url: `/apps/${appId}/feedbacks`,
        body: { message_id: mid, rating, content: content ?? undefined },
      })
      notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
      return true
    }
    catch {
      notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') })
      return false
    }
  }

  const handleAnnotation = async (mid: string, value: string): Promise<boolean> => {
    try {
      await updateLogMessageAnnotations({ url: `/apps/${appId}/annotations`, body: { message_id: mid, content: value } })
      notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
      return true
    }
    catch {
      notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') })
      return false
    }
  }

  if (!conversationDetail)
    return null
@@ -935,12 +962,14 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
    router.push(buildUrlWithConversation(log.id), { scroll: false })
  }, [buildUrlWithConversation, conversationIdInUrl, currentConversation, router, showDrawer])

  const currentConversationId = currentConversation?.id

  useEffect(() => {
    if (!conversationIdInUrl) {
      if (pendingConversationIdRef.current)
        return

      if (showDrawer || currentConversation?.id) {
      if (showDrawer || currentConversationId) {
        setShowDrawer(false)
        setCurrentConversation(undefined)
      }

@@ -27,6 +27,7 @@ import { fetchWorkflowDraft } from '@/service/workflow'
import { fetchInstalledAppList } from '@/service/explore'
import { AppTypeIcon } from '@/app/components/app/type-selector'
import Tooltip from '@/app/components/base/tooltip'
import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'
import { AccessMode } from '@/models/access-control'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { formatTime } from '@/utils/time'
@@ -64,6 +65,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
  const { isCurrentWorkspaceEditor } = useAppContext()
  const { onPlanInfoChanged } = useProviderContext()
  const { push } = useRouter()
  const openAsyncWindow = useAsyncWindowOpen()

  const [showEditModal, setShowEditModal] = useState(false)
  const [showDuplicateModal, setShowDuplicateModal] = useState(false)
@@ -247,11 +249,16 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
        props.onClick?.()
        e.preventDefault()
        try {
          const { installed_apps }: any = await fetchInstalledAppList(app.id) || {}
          if (installed_apps?.length > 0)
            window.open(`${basePath}/explore/installed/${installed_apps[0].id}`, '_blank')
          else
          await openAsyncWindow(async () => {
            const { installed_apps }: any = await fetchInstalledAppList(app.id) || {}
            if (installed_apps?.length > 0)
              return `${basePath}/explore/installed/${installed_apps[0].id}`
            throw new Error('No app found in Explore')
          }, {
            onError: (err) => {
              Toast.notify({ type: 'error', message: `${err.message || err}` })
            },
          })
        }
        catch (e: any) {
          Toast.notify({ type: 'error', message: `${e.message || e}` })

@@ -4,24 +4,27 @@ import type { FC } from 'react'
import React, { useEffect } from 'react'
import * as amplitude from '@amplitude/analytics-browser'
import { sessionReplayPlugin } from '@amplitude/plugin-session-replay-browser'
import { IS_CLOUD_EDITION } from '@/config'
import { AMPLITUDE_API_KEY, IS_CLOUD_EDITION } from '@/config'

export type IAmplitudeProps = {
  apiKey?: string
  sessionReplaySampleRate?: number
}

// Check if Amplitude should be enabled
export const isAmplitudeEnabled = () => {
  return IS_CLOUD_EDITION && !!AMPLITUDE_API_KEY
}

const AmplitudeProvider: FC<IAmplitudeProps> = ({
  apiKey = process.env.NEXT_PUBLIC_AMPLITUDE_API_KEY ?? '',
  sessionReplaySampleRate = 1,
}) => {
  useEffect(() => {
    // Only enable in Saas edition
    if (!IS_CLOUD_EDITION)
    // Only enable in Saas edition with valid API key
    if (!isAmplitudeEnabled())
      return

    // Initialize Amplitude
    amplitude.init(apiKey, {
    amplitude.init(AMPLITUDE_API_KEY, {
      defaultTracking: {
        sessions: true,
        pageViews: true,

@@ -1,2 +1,2 @@
export { default } from './AmplitudeProvider'
export { default, isAmplitudeEnabled } from './AmplitudeProvider'
export { resetUser, setUserId, setUserProperties, trackEvent } from './utils'

@@ -1,4 +1,5 @@
import * as amplitude from '@amplitude/analytics-browser'
import { isAmplitudeEnabled } from './AmplitudeProvider'

/**
 * Track custom event
@@ -6,6 +7,8 @@ import * as amplitude from '@amplitude/analytics-browser'
 * @param eventProperties Event properties (optional)
 */
export const trackEvent = (eventName: string, eventProperties?: Record<string, any>) => {
  if (!isAmplitudeEnabled())
    return
  amplitude.track(eventName, eventProperties)
}

@@ -14,6 +17,8 @@ export const trackEvent = (eventName: string, eventProperties?: Record<string, a
 * @param userId User ID
 */
export const setUserId = (userId: string) => {
  if (!isAmplitudeEnabled())
    return
  amplitude.setUserId(userId)
}

@@ -22,6 +27,8 @@ export const setUserId = (userId: string) => {
 * @param properties User properties
 */
export const setUserProperties = (properties: Record<string, any>) => {
  if (!isAmplitudeEnabled())
    return
  const identifyEvent = new amplitude.Identify()
  Object.entries(properties).forEach(([key, value]) => {
    identifyEvent.set(key, value)
@@ -33,5 +40,7 @@ export const setUserProperties = (properties: Record<string, any>) => {
 * Reset user (e.g., when user logs out)
 */
export const resetUser = () => {
  if (!isAmplitudeEnabled())
    return
  amplitude.reset()
}

@@ -0,0 +1,3 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M19 6C19 5.44771 18.5523 5 18 5H6C5.44771 5 5 5.44771 5 6V18C5 18.5523 5.44771 19 6 19H18C18.5523 19 19 18.5523 19 18V6ZM9.73926 13.1533C10.0706 12.7115 10.6978 12.6218 11.1396 12.9531C11.5815 13.2845 11.6712 13.9117 11.3398 14.3535L9.46777 16.8486C9.14935 17.2732 8.55487 17.3754 8.11328 17.0811L6.98828 16.3311C6.52878 16.0247 6.40465 15.4039 6.71094 14.9443C7.01729 14.4848 7.63813 14.3606 8.09766 14.667L8.43457 14.8916L9.73926 13.1533ZM16 14C16.5523 14 17 14.4477 17 15C17 15.5523 16.5523 16 16 16H14C13.4477 16 13 15.5523 13 15C13 14.4477 13.4477 14 14 14H16ZM9.73926 7.15234C10.0706 6.71052 10.6978 6.62079 11.1396 6.95215C11.5815 7.28352 11.6712 7.91071 11.3398 8.35254L9.46777 10.8477C9.14936 11.2722 8.55487 11.3744 8.11328 11.0801L6.98828 10.3301C6.52884 10.0238 6.40476 9.40286 6.71094 8.94336C7.0173 8.48384 7.63814 8.35965 8.09766 8.66602L8.43457 8.89062L9.73926 7.15234ZM16.0576 8C16.6099 8 17.0576 8.44772 17.0576 9C17.0576 9.55228 16.6099 10 16.0576 10H14.0576C13.5055 9.99985 13.0576 9.55219 13.0576 9C13.0576 8.44781 13.5055 8.00015 14.0576 8H16.0576ZM21 18C21 19.6569 19.6569 21 18 21H6C4.34315 21 3 19.6569 3 18V6C3 4.34315 4.34315 3 6 3H18C19.6569 3 21 4.34315 21 6V18Z" fill="white"/>
</svg>
@@ -0,0 +1,26 @@
{
  "icon": {
    "type": "element",
    "isRootNode": true,
    "name": "svg",
    "attributes": {
      "width": "24",
      "height": "24",
      "viewBox": "0 0 24 24",
      "fill": "none",
      "xmlns": "http://www.w3.org/2000/svg"
    },
    "children": [
      {
        "type": "element",
        "name": "path",
        "attributes": {
          "d": "M19 6C19 5.44771 18.5523 5 18 5H6C5.44771 5 5 5.44771 5 6V18C5 18.5523 5.44771 19 6 19H18C18.5523 19 19 18.5523 19 18V6ZM9.73926 13.1533C10.0706 12.7115 10.6978 12.6218 11.1396 12.9531C11.5815 13.2845 11.6712 13.9117 11.3398 14.3535L9.46777 16.8486C9.14935 17.2732 8.55487 17.3754 8.11328 17.0811L6.98828 16.3311C6.52878 16.0247 6.40465 15.4039 6.71094 14.9443C7.01729 14.4848 7.63813 14.3606 8.09766 14.667L8.43457 14.8916L9.73926 13.1533ZM16 14C16.5523 14 17 14.4477 17 15C17 15.5523 16.5523 16 16 16H14C13.4477 16 13 15.5523 13 15C13 14.4477 13.4477 14 14 14H16ZM9.73926 7.15234C10.0706 6.71052 10.6978 6.62079 11.1396 6.95215C11.5815 7.28352 11.6712 7.91071 11.3398 8.35254L9.46777 10.8477C9.14936 11.2722 8.55487 11.3744 8.11328 11.0801L6.98828 10.3301C6.52884 10.0238 6.40476 9.40286 6.71094 8.94336C7.0173 8.48384 7.63814 8.35965 8.09766 8.66602L8.43457 8.89062L9.73926 7.15234ZM16.0576 8C16.6099 8 17.0576 8.44772 17.0576 9C17.0576 9.55228 16.6099 10 16.0576 10H14.0576C13.5055 9.99985 13.0576 9.55219 13.0576 9C13.0576 8.44781 13.5055 8.00015 14.0576 8H16.0576ZM21 18C21 19.6569 19.6569 21 18 21H6C4.34315 21 3 19.6569 3 18V6C3 4.34315 4.34315 3 6 3H18C19.6569 3 21 4.34315 21 6V18Z",
          "fill": "currentColor"
        },
        "children": []
      }
    ]
  },
  "name": "SquareChecklist"
}
@@ -0,0 +1,20 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY

import * as React from 'react'
import data from './SquareChecklist.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconData } from '@/app/components/base/icons/IconBase'

const Icon = (
  {
    ref,
    ...props
  }: React.SVGProps<SVGSVGElement> & {
    ref?: React.RefObject<React.RefObject<HTMLOrSVGElement>>;
  },
) => <IconBase {...props} ref={ref} data={data as IconData} />

Icon.displayName = 'SquareChecklist'

export default Icon
@@ -6,3 +6,4 @@ export { default as Mcp } from './Mcp'
export { default as NoToolPlaceholder } from './NoToolPlaceholder'
export { default as Openai } from './Openai'
export { default as ReplayLine } from './ReplayLine'
export { default as SquareChecklist } from './SquareChecklist'

@@ -110,7 +110,7 @@ const NotionPageSelector = ({
    setCurrentCredential(credential)
    onSelect([]) // Clear selected pages when changing credential
    onSelectCredential?.(credential.credentialId)
  }, [invalidPreImportNotionPages, onSelect, onSelectCredential])
  }, [datasetId, invalidPreImportNotionPages, notionCredentials, onSelect, onSelectCredential])

  const handleSelectPages = useCallback((newSelectedPagesId: Set<string>) => {
    const selectedPages = Array.from(newSelectedPagesId).map(pageId => pagesMapAndSelectedPagesId[0][pageId])

@@ -1,9 +1,8 @@
'use client'
import { useTranslation } from 'react-i18next'
import React, { Fragment, useMemo } from 'react'
import { Menu, MenuButton, MenuItem, MenuItems, Transition } from '@headlessui/react'
import { RiArrowDownSLine } from '@remixicon/react'
import NotionIcon from '../../notion-icon'
import { CredentialIcon } from '@/app/components/datasets/common/credential-icon'

export type NotionCredential = {
  credentialId: string
@@ -23,14 +22,10 @@ const CredentialSelector = ({
  items,
  onSelect,
}: CredentialSelectorProps) => {
  const { t } = useTranslation()
  const currentCredential = items.find(item => item.credentialId === value)!

  const getDisplayName = (item: NotionCredential) => {
    return item.workspaceName || t('datasetPipeline.credentialSelector.name', {
      credentialName: item.credentialName,
      pluginName: 'Notion',
    })
    return item.workspaceName || item.credentialName
  }

  const currentDisplayName = useMemo(() => {
@@ -43,10 +38,11 @@ const CredentialSelector = ({
      ({ open }) => (
        <>
          <MenuButton className={`flex h-7 items-center justify-center rounded-md p-1 pr-2 hover:bg-state-base-hover ${open && 'bg-state-base-hover'} cursor-pointer`}>
            <NotionIcon
            <CredentialIcon
              className='mr-2'
              src={currentCredential?.workspaceIcon}
              avatarUrl={currentCredential?.workspaceIcon}
              name={currentDisplayName}
              size={20}
            />
            <div
              className='mr-1 w-[90px] truncate text-left text-sm font-medium text-text-secondary'
@@ -80,10 +76,11 @@ const CredentialSelector = ({
              className='flex h-9 cursor-pointer items-center rounded-lg px-3 hover:bg-state-base-hover'
              onClick={() => onSelect(item.credentialId)}
            >
              <NotionIcon
              <CredentialIcon
                className='mr-2 shrink-0'
                src={item.workspaceIcon}
                avatarUrl={item.workspaceIcon}
                name={displayName}
                size={20}
              />
              <div
                className='system-sm-medium mr-2 grow truncate text-text-secondary'

@@ -18,6 +18,7 @@ type PageSelectorProps = {
  canPreview?: boolean
  previewPageId?: string
  onPreview?: (selectedPageId: string) => void
  isMultipleChoice?: boolean
}
type NotionPageTreeItem = {
  children: Set<string>
@@ -139,8 +140,6 @@ const ItemComponent = ({ index, style, data }: ListChildComponentProps<{
        checked={checkedIds.has(current.page_id)}
        disabled={disabled}
        onCheck={() => {
          if (disabled)
            return
          handleCheck(index)
        }}
      />

@@ -12,6 +12,7 @@ const PremiumBadgeVariants = cva(
      size: {
        s: 'premium-badge-s',
        m: 'premium-badge-m',
        custom: '',
      },
      color: {
        blue: 'premium-badge-blue',
@@ -33,7 +34,7 @@ const PremiumBadgeVariants = cva(
)

type PremiumBadgeProps = {
  size?: 's' | 'm'
  size?: 's' | 'm' | 'custom'
  color?: 'blue' | 'indigo' | 'gray' | 'orange'
  allowHover?: boolean
  styleCss?: CSSProperties

@@ -9,33 +9,28 @@ import PlanComp from '../plan'
import { useAppContext } from '@/context/app-context'
import { useProviderContext } from '@/context/provider-context'
import { useBillingUrl } from '@/service/use-billing'
import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'

const Billing: FC = () => {
  const { t } = useTranslation()
  const { isCurrentWorkspaceManager } = useAppContext()
  const { enableBilling } = useProviderContext()
  const { data: billingUrl, isFetching, refetch } = useBillingUrl(enableBilling && isCurrentWorkspaceManager)
  const openAsyncWindow = useAsyncWindowOpen()

  const handleOpenBilling = async () => {
    // Open synchronously to preserve user gesture for popup blockers
    if (billingUrl) {
      window.open(billingUrl, '_blank', 'noopener,noreferrer')
      return
    }

    const newWindow = window.open('', '_blank', 'noopener,noreferrer')
    try {
    await openAsyncWindow(async () => {
      const url = (await refetch()).data
      if (url && newWindow) {
        newWindow.location.href = url
        return
      }
    }
    catch (err) {
      console.error('Failed to fetch billing url', err)
    }
    // Close the placeholder window if we failed to fetch the URL
    newWindow?.close()
      if (url)
        return url
      return null
    }, {
      immediateUrl: billingUrl,
      features: 'noopener,noreferrer',
      onError: (err) => {
        console.error('Failed to fetch billing url', err)
      },
    })
  }

  return (

web/app/components/billing/plan-upgrade-modal/index.spec.tsx (new file, 118 lines)
@@ -0,0 +1,118 @@
import React from 'react'
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import PlanUpgradeModal from './index'

const mockSetShowPricingModal = jest.fn()

jest.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => key,
  }),
}))

jest.mock('@/app/components/base/modal', () => {
  const MockModal = ({ isShow, children }: { isShow: boolean; children: React.ReactNode }) => (
    isShow ? <div data-testid="plan-upgrade-modal">{children}</div> : null
  )
  return {
    __esModule: true,
    default: MockModal,
  }
})

jest.mock('@/context/modal-context', () => ({
  useModalContext: () => ({
    setShowPricingModal: mockSetShowPricingModal,
  }),
}))

const baseProps = {
  title: 'Upgrade Required',
  description: 'You need to upgrade your plan.',
  show: true,
  onClose: jest.fn(),
}

const renderComponent = (props: Partial<React.ComponentProps<typeof PlanUpgradeModal>> = {}) => {
  const mergedProps = { ...baseProps, ...props }
  return render(<PlanUpgradeModal {...mergedProps} />)
}

describe('PlanUpgradeModal', () => {
  beforeEach(() => {
    jest.clearAllMocks()
  })

  // Rendering and props-driven content
  it('should render modal with provided content when visible', () => {
    // Arrange
    const extraInfoText = 'Additional upgrade details'
    renderComponent({
      extraInfo: <div>{extraInfoText}</div>,
    })

    // Assert
    expect(screen.getByText(baseProps.title)).toBeInTheDocument()
    expect(screen.getByText(baseProps.description)).toBeInTheDocument()
    expect(screen.getByText(extraInfoText)).toBeInTheDocument()
    expect(screen.getByText('billing.triggerLimitModal.dismiss')).toBeInTheDocument()
    expect(screen.getByText('billing.triggerLimitModal.upgrade')).toBeInTheDocument()
  })

  // Guard against rendering when modal is hidden
  it('should not render content when show is false', () => {
    // Act
    renderComponent({ show: false })

    // Assert
    expect(screen.queryByText(baseProps.title)).not.toBeInTheDocument()
    expect(screen.queryByText(baseProps.description)).not.toBeInTheDocument()
  })

  // User closes the modal from dismiss button
  it('should call onClose when dismiss button is clicked', async () => {
    // Arrange
    const user = userEvent.setup()
    const onClose = jest.fn()
    renderComponent({ onClose })

    // Act
    await user.click(screen.getByText('billing.triggerLimitModal.dismiss'))

    // Assert
    expect(onClose).toHaveBeenCalledTimes(1)
  })

  // Upgrade path uses provided callback over pricing modal
  it('should call onUpgrade and onClose when upgrade button is clicked with onUpgrade provided', async () => {
    // Arrange
    const user = userEvent.setup()
    const onClose = jest.fn()
    const onUpgrade = jest.fn()
    renderComponent({ onClose, onUpgrade })

    // Act
    await user.click(screen.getByText('billing.triggerLimitModal.upgrade'))

    // Assert
    expect(onClose).toHaveBeenCalledTimes(1)
    expect(onUpgrade).toHaveBeenCalledTimes(1)
    expect(mockSetShowPricingModal).not.toHaveBeenCalled()
  })

  // Fallback upgrade path opens pricing modal when no onUpgrade is supplied
  it('should open pricing modal when upgrade button is clicked without onUpgrade', async () => {
    // Arrange
    const user = userEvent.setup()
    const onClose = jest.fn()
    renderComponent({ onClose, onUpgrade: undefined })

    // Act
    await user.click(screen.getByText('billing.triggerLimitModal.upgrade'))

    // Assert
    expect(onClose).toHaveBeenCalledTimes(1)
    expect(mockSetShowPricingModal).toHaveBeenCalledTimes(1)
  })
})
web/app/components/billing/plan-upgrade-modal/index.tsx  (new file, 87 lines)
@@ -0,0 +1,87 @@
'use client'
import type { FC } from 'react'
import React, { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import Modal from '@/app/components/base/modal'
import Button from '@/app/components/base/button'
import UpgradeBtn from '@/app/components/billing/upgrade-btn'
import styles from './style.module.css'
import { SquareChecklist } from '../../base/icons/src/vender/other'
import { useModalContext } from '@/context/modal-context'

type Props = {
  Icon?: React.ComponentType<React.SVGProps<SVGSVGElement>>
  title: string
  description: string
  extraInfo?: React.ReactNode
  show: boolean
  onClose: () => void
  onUpgrade?: () => void
}

const PlanUpgradeModal: FC<Props> = ({
  Icon = SquareChecklist,
  title,
  description,
  extraInfo,
  show,
  onClose,
  onUpgrade,
}) => {
  const { t } = useTranslation()
  const { setShowPricingModal } = useModalContext()

  const handleUpgrade = useCallback(() => {
    onClose()
    onUpgrade ? onUpgrade() : setShowPricingModal()
  }, [onClose, onUpgrade, setShowPricingModal])

  return (
    <Modal
      isShow={show}
      onClose={onClose}
      closable={false}
      clickOutsideNotClose
      className={`${styles.surface} w-[580px] rounded-2xl !p-0`}
    >
      <div className='relative'>
        <div
          aria-hidden
          className={`${styles.heroOverlay} pointer-events-none absolute inset-0`}
        />
        <div className='px-8 pt-8'>
          <div className={`${styles.icon} flex size-12 items-center justify-center rounded-xl shadow-lg backdrop-blur-[5px]`}>
            <Icon className='size-6 text-text-primary-on-surface' />
          </div>
          <div className='mt-6 space-y-2'>
            <div className={`${styles.highlight} title-3xl-semi-bold`}>
              {title}
            </div>
            <div className='system-md-regular text-text-tertiary'>
              {description}
            </div>
          </div>
          {extraInfo}
        </div>
      </div>

      <div className='mb-8 mt-10 flex justify-end space-x-2 px-8'>
        <Button
          onClick={onClose}
        >
          {t('billing.triggerLimitModal.dismiss')}
        </Button>
        <UpgradeBtn
          size='custom'
          isShort
          onClick={handleUpgrade}
          className='!h-8 !rounded-lg px-2'
          labelKey='billing.triggerLimitModal.upgrade'
          loc='trigger-events-limit-modal'
        />
      </div>
    </Modal>
  )
}

export default React.memo(PlanUpgradeModal)
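For orientation, here is a minimal sketch of how this new shared modal might be consumed. Only PlanUpgradeModal and its Props come from the file above; the host component and the copy strings are assumptions for illustration.

// Hypothetical consumer of the new modal (ExamplePage and all copy are assumed).
import React, { useState } from 'react'
import PlanUpgradeModal from '@/app/components/billing/plan-upgrade-modal'

const ExamplePage: React.FC = () => {
  const [show, setShow] = useState(true)
  // With onUpgrade omitted, handleUpgrade falls back to setShowPricingModal().
  return (
    <PlanUpgradeModal
      show={show}
      onClose={() => setShow(false)}
      title='Plan limit reached'
      description='Upgrade to continue using this feature.'
    />
  )
}

export default ExamplePage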
@@ -19,7 +19,6 @@
  background:
    linear-gradient(180deg, var(--color-components-avatar-bg-mask-stop-0, rgba(255, 255, 255, 0.12)) 0%, var(--color-components-avatar-bg-mask-stop-100, rgba(255, 255, 255, 0.08)) 100%),
    var(--color-util-colors-blue-brand-blue-brand-500, #296dff);
  box-shadow: 0 10px 20px color-mix(in srgb, var(--color-util-colors-blue-brand-blue-brand-500, #296dff) 35%, transparent);
}

.highlight {
@@ -9,6 +9,7 @@ import Toast from '../../../../base/toast'
import { PlanRange } from '../../plan-switcher/plan-range-switcher'
import { useAppContext } from '@/context/app-context'
import { fetchBillingUrl, fetchSubscriptionUrls } from '@/service/billing'
import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'
import List from './list'
import Button from './button'
import { Professional, Sandbox, Team } from '../../assets'
@@ -42,6 +43,7 @@ const CloudPlanItem: FC<CloudPlanItemProps> = ({
  const isCurrentPaidPlan = isCurrent && !isFreePlan
  const isPlanDisabled = isCurrentPaidPlan ? false : planInfo.level <= ALL_PLANS[currentPlan].level
  const { isCurrentWorkspaceManager } = useAppContext()
  const openAsyncWindow = useAsyncWindowOpen()

  const btnText = useMemo(() => {
    if (isCurrent)
@@ -72,8 +74,16 @@ const CloudPlanItem: FC<CloudPlanItemProps> = ({
    setLoading(true)
    try {
      if (isCurrentPaidPlan) {
        const res = await fetchBillingUrl()
        window.open(res.url, '_blank')
        await openAsyncWindow(async () => {
          const res = await fetchBillingUrl()
          if (res.url)
            return res.url
          throw new Error('Failed to open billing page')
        }, {
          onError: (err) => {
            Toast.notify({ type: 'error', message: err.message || String(err) })
          },
        })
        return
      }
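The useAsyncWindowOpen hook itself is not shown in this diff, so the following is only a plausible sketch inferred from the call site above: the usual pattern is to open the window synchronously, inside the click gesture, and navigate it once the async URL resolves, so popup blockers do not swallow it. Treat every detail here as an assumption rather than the real hook.

// Assumed shape of useAsyncWindowOpen — inferred from the call site, not the actual implementation.
import { useCallback } from 'react'

type AsyncWindowOpenOptions = {
  onError?: (err: Error) => void
}

export const useAsyncWindowOpen = () => {
  return useCallback(async (getUrl: () => Promise<string>, options?: AsyncWindowOpenOptions) => {
    // Open a blank window immediately, while still inside the user gesture,
    // so the browser does not block the later navigation as a popup.
    const win = window.open('about:blank', '_blank')
    try {
      const url = await getUrl()
      if (win)
        win.location.href = url
    }
    catch (err) {
      win?.close() // do not leave an orphaned blank tab behind
      options?.onError?.(err instanceof Error ? err : new Error(String(err)))
    }
  }, [])
}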
@@ -2,27 +2,22 @@
import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'
import Modal from '@/app/components/base/modal'
import Button from '@/app/components/base/button'
import { TriggerAll } from '@/app/components/base/icons/src/vender/workflow'
import UsageInfo from '@/app/components/billing/usage-info'
import UpgradeBtn from '@/app/components/billing/upgrade-btn'
import type { Plan } from '@/app/components/billing/type'
import styles from './index.module.css'
import PlanUpgradeModal from '@/app/components/billing/plan-upgrade-modal'

type Props = {
  show: boolean
  onDismiss: () => void
  onClose: () => void
  onUpgrade: () => void
  usage: number
  total: number
  resetInDays?: number
  planType: Plan
}

const TriggerEventsLimitModal: FC<Props> = ({
  show,
  onDismiss,
  onClose,
  onUpgrade,
  usage,
  total,
@@ -31,59 +26,25 @@ const TriggerEventsLimitModal: FC<Props> = ({
  const { t } = useTranslation()

  return (
    <Modal
      isShow={show}
      onClose={onDismiss}
      closable={false}
      clickOutsideNotClose
      className={`${styles.surface} flex h-[360px] w-[580px] flex-col overflow-hidden rounded-2xl !p-0 shadow-xl`}
    >
      <div className='relative flex w-full flex-1 items-stretch justify-center'>
        <div
          aria-hidden
          className={`${styles.heroOverlay} pointer-events-none absolute inset-0`}
    <PlanUpgradeModal
      show={show}
      onClose={onClose}
      onUpgrade={onUpgrade}
      Icon={TriggerAll as React.ComponentType<React.SVGProps<SVGSVGElement>>}
      title={t('billing.triggerLimitModal.title')}
      description={t('billing.triggerLimitModal.description')}
      extraInfo={(
        <UsageInfo
          className='mt-4 w-full rounded-[12px] bg-components-panel-on-panel-item-bg'
          Icon={TriggerAll}
          name={t('billing.triggerLimitModal.usageTitle')}
          usage={usage}
          total={total}
          resetInDays={resetInDays}
          hideIcon
        />
        <div className='relative z-10 flex w-full flex-col items-start gap-4 px-8 pt-8'>
          <div className={`${styles.icon} flex h-12 w-12 items-center justify-center rounded-[12px]`}>
            <TriggerAll className='h-5 w-5 text-text-primary-on-surface' />
          </div>
          <div className='flex flex-col items-start gap-2'>
            <div className={`${styles.highlight} title-lg-semi-bold`}>
              {t('billing.triggerLimitModal.title')}
            </div>
            <div className='body-md-regular text-text-secondary'>
              {t('billing.triggerLimitModal.description')}
            </div>
          </div>
          <UsageInfo
            className='mb-5 w-full rounded-[12px] bg-components-panel-on-panel-item-bg'
            Icon={TriggerAll}
            name={t('billing.triggerLimitModal.usageTitle')}
            usage={usage}
            total={total}
            resetInDays={resetInDays}
            hideIcon
          />
        </div>
      </div>

      <div className='flex h-[76px] w-full items-center justify-end gap-2 px-8 pb-8 pt-5'>
        <Button
          className='h-8 w-[77px] min-w-[72px] !rounded-lg !border-[0.5px] px-3 py-2'
          onClick={onDismiss}
        >
          {t('billing.triggerLimitModal.dismiss')}
        </Button>
        <UpgradeBtn
          isShort
          onClick={onUpgrade}
          className='flex w-[93px] items-center justify-center !rounded-lg !px-2'
          style={{ height: 32 }}
          labelKey='billing.triggerLimitModal.upgrade'
          loc='trigger-events-limit-modal'
        />
      </div>
    </Modal>
      )}
    />
  )
}
@@ -11,7 +11,7 @@ type Props = {
  className?: string
  style?: CSSProperties
  isFull?: boolean
  size?: 'md' | 'lg'
  size?: 's' | 'm' | 'custom'
  isPlain?: boolean
  isShort?: boolean
  onClick?: () => void
@@ -21,6 +21,7 @@ type Props = {

const UpgradeBtn: FC<Props> = ({
  className,
  size = 'm',
  style,
  isPlain = false,
  isShort = false,
@@ -62,7 +63,7 @@ const UpgradeBtn: FC<Props> = ({

  return (
    <PremiumBadge
      size='m'
      size={size}
      color='blue'
      allowHover={true}
      onClick={onClick}
@@ -2,7 +2,7 @@ import cn from '@/utils/classnames'
import React, { useCallback, useMemo, useState } from 'react'

type CredentialIconProps = {
  avatar_url?: string
  avatarUrl?: string
  name: string
  size?: number
  className?: string
@@ -16,12 +16,12 @@ const ICON_BG_COLORS = [
]

export const CredentialIcon: React.FC<CredentialIconProps> = ({
  avatar_url,
  avatarUrl,
  name,
  size = 20,
  className = '',
}) => {
  const [showAvatar, setShowAvatar] = useState(!!avatar_url && avatar_url !== 'default')
  const [showAvatar, setShowAvatar] = useState(!!avatarUrl && avatarUrl !== 'default')
  const firstLetter = useMemo(() => name.charAt(0).toUpperCase(), [name])
  const bgColor = useMemo(() => ICON_BG_COLORS[firstLetter.charCodeAt(0) % ICON_BG_COLORS.length], [firstLetter])

@@ -29,17 +29,20 @@ export const CredentialIcon: React.FC<CredentialIconProps> = ({
    setShowAvatar(false)
  }, [])

  if (avatar_url && avatar_url !== 'default' && showAvatar) {
  if (avatarUrl && avatarUrl !== 'default' && showAvatar) {
    return (
      <div
        className='flex shrink-0 items-center justify-center overflow-hidden rounded-md border border-divider-regular'
        className={cn(
          'flex shrink-0 items-center justify-center overflow-hidden rounded-md border border-divider-regular',
          className,
        )}
        style={{ width: `${size}px`, height: `${size}px` }}
      >
        <img
          src={avatar_url}
          src={avatarUrl}
          width={size}
          height={size}
          className={cn('shrink-0 object-contain', className)}
          className='shrink-0 object-contain'
          onError={onImgLoadError}
        />
      </div>
@@ -25,7 +25,7 @@ type IFileUploaderProps = {
  onFileUpdate: (fileItem: FileItem, progress: number, list: FileItem[]) => void
  onFileListUpdate?: (files: FileItem[]) => void
  onPreview: (file: File) => void
  notSupportBatchUpload?: boolean
  supportBatchUpload?: boolean
}

const FileUploader = ({
@@ -35,7 +35,7 @@ const FileUploader = ({
  onFileUpdate,
  onFileListUpdate,
  onPreview,
  notSupportBatchUpload,
  supportBatchUpload = false,
}: IFileUploaderProps) => {
  const { t } = useTranslation()
  const { notify } = useContext(ToastContext)
@@ -44,7 +44,7 @@ const FileUploader = ({
  const dropRef = useRef<HTMLDivElement>(null)
  const dragRef = useRef<HTMLDivElement>(null)
  const fileUploader = useRef<HTMLInputElement>(null)
  const hideUpload = notSupportBatchUpload && fileList.length > 0
  const hideUpload = !supportBatchUpload && fileList.length > 0

  const { data: fileUploadConfigResponse } = useFileUploadConfig()
  const { data: supportFileTypesResponse } = useFileSupportTypes()
@@ -68,9 +68,9 @@ const FileUploader = ({
  const ACCEPTS = supportTypes.map((ext: string) => `.${ext}`)
  const fileUploadConfig = useMemo(() => ({
    file_size_limit: fileUploadConfigResponse?.file_size_limit ?? 15,
    batch_count_limit: fileUploadConfigResponse?.batch_count_limit ?? 5,
    file_upload_limit: fileUploadConfigResponse?.file_upload_limit ?? 5,
  }), [fileUploadConfigResponse])
    batch_count_limit: supportBatchUpload ? (fileUploadConfigResponse?.batch_count_limit ?? 5) : 1,
    file_upload_limit: supportBatchUpload ? (fileUploadConfigResponse?.file_upload_limit ?? 5) : 1,
  }), [fileUploadConfigResponse, supportBatchUpload])

  const fileListRef = useRef<FileItem[]>([])

@@ -254,12 +254,12 @@ const FileUploader = ({
        }),
      )
      let files = nested.flat()
      if (notSupportBatchUpload) files = files.slice(0, 1)
      if (!supportBatchUpload) files = files.slice(0, 1)
      files = files.slice(0, fileUploadConfig.batch_count_limit)
      const valid = files.filter(isValid)
      initialUpload(valid)
    },
    [initialUpload, isValid, notSupportBatchUpload, traverseFileEntry, fileUploadConfig],
    [initialUpload, isValid, supportBatchUpload, traverseFileEntry, fileUploadConfig],
  )
  const selectHandle = () => {
    if (fileUploader.current)
@@ -303,7 +303,7 @@ const FileUploader = ({
        id="fileUploader"
        className="hidden"
        type="file"
        multiple={!notSupportBatchUpload}
        multiple={supportBatchUpload}
        accept={ACCEPTS.join(',')}
        onChange={fileChangeHandle}
      />
@@ -317,7 +317,7 @@ const FileUploader = ({
      <RiUploadCloud2Line className='mr-2 size-5' />

      <span>
        {notSupportBatchUpload ? t('datasetCreation.stepOne.uploader.buttonSingleFile') : t('datasetCreation.stepOne.uploader.button')}
        {supportBatchUpload ? t('datasetCreation.stepOne.uploader.button') : t('datasetCreation.stepOne.uploader.buttonSingleFile')}
        {supportTypes.length > 0 && (
          <label className="ml-1 cursor-pointer text-text-accent" onClick={selectHandle}>{t('datasetCreation.stepOne.uploader.browse')}</label>
        )}
@@ -326,7 +326,7 @@ const FileUploader = ({
      <div>{t('datasetCreation.stepOne.uploader.tip', {
        size: fileUploadConfig.file_size_limit,
        supportTypes: supportTypesShowNames,
        batchCount: notSupportBatchUpload ? 1 : fileUploadConfig.batch_count_limit,
        batchCount: fileUploadConfig.batch_count_limit,
        totalCount: fileUploadConfig.file_upload_limit,
      })}</div>
      {dragging && <div ref={dragRef} className='absolute left-0 top-0 h-full w-full' />}
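Since the prop's polarity flips here, the single-file clamp is worth restating; a short sketch of the equivalence, using the fallback default of 5 from the hunk above (no new logic):

// Restated flag semantics from this diff; 5 is the config fallback shown above.
const supportBatchUpload = false // the new prop, now defaulting to false
const notSupportBatchUpload = !supportBatchUpload // what the old prop expressed
// With batch upload off, both limits collapse to single-file mode, so the tip
// text no longer needs its own ternary:
const batchCountLimit = supportBatchUpload ? 5 : 1
const fileUploadLimit = supportBatchUpload ? 5 : 1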
@@ -22,6 +22,10 @@ import classNames from '@/utils/classnames'
import { ENABLE_WEBSITE_FIRECRAWL, ENABLE_WEBSITE_JINAREADER, ENABLE_WEBSITE_WATERCRAWL } from '@/config'
import NotionConnector from '@/app/components/base/notion-connector'
import type { DataSourceAuth } from '@/app/components/header/account-setting/data-source-page-new/types'
import PlanUpgradeModal from '@/app/components/billing/plan-upgrade-modal'
import { useBoolean } from 'ahooks'
import { Plan } from '@/app/components/billing/type'
import UpgradeCard from './upgrade-card'

type IStepOneProps = {
  datasetId?: string
@@ -52,7 +56,7 @@ const StepOne = ({
  dataSourceTypeDisable,
  changeType,
  onSetting,
  onStepChange,
  onStepChange: doOnStepChange,
  files,
  updateFileList,
  updateFile,
@@ -110,7 +114,33 @@ const StepOne = ({
  const hasNotin = notionPages.length > 0
  const isVectorSpaceFull = plan.usage.vectorSpace >= plan.total.vectorSpace
  const isShowVectorSpaceFull = (allFileLoaded || hasNotin) && isVectorSpaceFull && enableBilling
  const notSupportBatchUpload = enableBilling && plan.type === 'sandbox'
  const supportBatchUpload = !enableBilling || plan.type !== Plan.sandbox
  const notSupportBatchUpload = !supportBatchUpload

  const [isShowPlanUpgradeModal, {
    setTrue: showPlanUpgradeModal,
    setFalse: hidePlanUpgradeModal,
  }] = useBoolean(false)
  const onStepChange = useCallback(() => {
    if (notSupportBatchUpload) {
      let isMultiple = false
      if (dataSourceType === DataSourceType.FILE && files.length > 1)
        isMultiple = true

      if (dataSourceType === DataSourceType.NOTION && notionPages.length > 1)
        isMultiple = true

      if (dataSourceType === DataSourceType.WEB && websitePages.length > 1)
        isMultiple = true

      if (isMultiple) {
        showPlanUpgradeModal()
        return
      }
    }
    doOnStepChange()
  }, [dataSourceType, doOnStepChange, files.length, notSupportBatchUpload, notionPages.length, showPlanUpgradeModal, websitePages.length])

  const nextDisabled = useMemo(() => {
    if (!files.length)
      return true
@@ -229,7 +259,7 @@ const StepOne = ({
              onFileListUpdate={updateFileList}
              onFileUpdate={updateFile}
              onPreview={updateCurrentFile}
              notSupportBatchUpload={notSupportBatchUpload}
              supportBatchUpload={supportBatchUpload}
            />
            {isShowVectorSpaceFull && (
              <div className='mb-4 max-w-[640px]'>
@@ -244,6 +274,14 @@ const StepOne = ({
                </span>
              </Button>
            </div>
            {
              enableBilling && plan.type === Plan.sandbox && files.length > 0 && (
                <div className='mt-5'>
                  <div className='mb-4 h-px bg-divider-subtle'></div>
                  <UpgradeCard />
                </div>
              )
            }
          </>
        )}
        {dataSourceType === DataSourceType.NOTION && (
@@ -330,6 +368,14 @@ const StepOne = ({
          />
        )}
        {currentWebsite && <WebsitePreview payload={currentWebsite} hidePreview={hideWebsitePreview} />}
        {isShowPlanUpgradeModal && (
          <PlanUpgradeModal
            show
            onClose={hidePlanUpgradeModal}
            title={t('billing.upgrade.uploadMultiplePages.title')!}
            description={t('billing.upgrade.uploadMultiplePages.description')!}
          />
        )}
      </div>
    </div>
  </div>
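The new onStepChange guard condenses to a single predicate; a self-contained restatement of it follows (the Selection type and its field names are stand-ins for the component state used in the hunk above, not new behavior):

// Restatement of the guard as a pure predicate.
type Selection = {
  dataSourceType: 'FILE' | 'NOTION' | 'WEB'
  fileCount: number
  notionPageCount: number
  websitePageCount: number
}

const hasMultipleSelected = (s: Selection): boolean =>
  (s.dataSourceType === 'FILE' && s.fileCount > 1)
  || (s.dataSourceType === 'NOTION' && s.notionPageCount > 1)
  || (s.dataSourceType === 'WEB' && s.websitePageCount > 1)

// Sandbox workspaces with billing enabled see the upgrade modal instead of
// advancing: if (notSupportBatchUpload && hasMultipleSelected(sel)) the modal
// is shown; otherwise control falls through to doOnStepChange().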
web/app/components/datasets/create/step-one/upgrade-card.tsx  (new file, 33 lines)
@@ -0,0 +1,33 @@
'use client'
import UpgradeBtn from '@/app/components/billing/upgrade-btn'
import { useModalContext } from '@/context/modal-context'
import type { FC } from 'react'
import React, { useCallback } from 'react'
import { useTranslation } from 'react-i18next'

const UpgradeCard: FC = () => {
  const { t } = useTranslation()
  const { setShowPricingModal } = useModalContext()

  const handleUpgrade = useCallback(() => {
    setShowPricingModal()
  }, [setShowPricingModal])

  return (
    <div className='flex items-center justify-between rounded-xl border-[0.5px] border-components-panel-border-subtle bg-components-panel-on-panel-item-bg py-3 pl-4 pr-3.5 shadow-xs backdrop-blur-[5px] '>
      <div>
        <div className='title-md-semi-bold bg-[linear-gradient(92deg,_var(--components-input-border-active-prompt-1,_#0BA5EC)_0%,_var(--components-input-border-active-prompt-2,_#155AEF)_99.21%)] bg-clip-text text-transparent'>{t('billing.upgrade.uploadMultipleFiles.title')}</div>
        <div className='system-xs-regular text-text-tertiary'>{t('billing.upgrade.uploadMultipleFiles.description')}</div>
      </div>
      <UpgradeBtn
        size='custom'
        isShort
        className='ml-3 !h-8 !rounded-lg px-2'
        labelKey='billing.triggerLimitModal.upgrade'
        loc='upload-multiple-files'
        onClick={handleUpgrade}
      />
    </div>
  )
}
export default React.memo(UpgradeCard)
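UpgradeCard is prop-less and opens the pricing modal itself via useModalContext, so integration is a single element; a minimal usage sketch (the wrapper component and its markup are assumptions):

// Only UpgradeCard and its import path are from the diff above.
import React from 'react'
import UpgradeCard from '@/app/components/datasets/create/step-one/upgrade-card'

const SandboxUploadFooter: React.FC = () => (
  <div className='mt-5'>
    <UpgradeCard />
  </div>
)

export default SandboxUploadFooter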
@@ -166,10 +166,6 @@ const FireCrawl: FC<Props> = ({
      setCrawlErrorMessage(errorMessage || t(`${I18N_PREFIX}.unknownError`))
    }
    else {
      data.data = data.data.map((item: any) => ({
        ...item,
        content: item.markdown,
      }))
      setCrawlResult(data)
      onCheckedCrawlResultChange(data.data || []) // default select the crawl result
      setCrawlErrorMessage('')
@@ -157,7 +157,7 @@ const JinaReader: FC<Props> = ({
        total: 1,
        data: [{
          title,
          content,
          markdown: content,
          description,
          source_url: url,
        }],
@@ -32,7 +32,7 @@ const WebsitePreview = ({
        <div className='system-xs-medium truncate text-text-tertiary' title={payload.source_url}>{payload.source_url}</div>
      </div>
      <div className={cn(s.previewContent, 'body-md-regular')}>
        <div className={cn(s.fileContent)}>{payload.content}</div>
        <div className={cn(s.fileContent)}>{payload.markdown}</div>
      </div>
    </div>
  )
@@ -132,7 +132,7 @@ const WaterCrawl: FC<Props> = ({
        },
      }
    }
  }, [crawlOptions.limit])
  }, [crawlOptions.limit, onCheckedCrawlResultChange])

  const handleRun = useCallback(async (url: string) => {
    const { isValid, errorMsg } = checkValid(url)
@@ -174,7 +174,7 @@ const WaterCrawl: FC<Props> = ({
    finally {
      setStep(Step.finished)
    }
  }, [checkValid, crawlOptions, onJobIdChange, t, waitForCrawlFinished])
  }, [checkValid, crawlOptions, onCheckedCrawlResultChange, onJobIdChange, t, waitForCrawlFinished])

  return (
    <div>
|
||||
import List from './list'
|
||||
|
||||
export type CredentialSelectorProps = {
|
||||
pluginName: string
|
||||
currentCredentialId: string
|
||||
onCredentialChange: (credentialId: string) => void
|
||||
credentials: Array<DataSourceCredential>
|
||||
}
|
||||
|
||||
const CredentialSelector = ({
|
||||
pluginName,
|
||||
currentCredentialId,
|
||||
onCredentialChange,
|
||||
credentials,
|
||||
@ -50,7 +48,6 @@ const CredentialSelector = ({
|
||||
<PortalToFollowElemTrigger onClick={toggle} className='grow overflow-hidden'>
|
||||
<Trigger
|
||||
currentCredential={currentCredential}
|
||||
pluginName={pluginName}
|
||||
isOpen={open}
|
||||
/>
|
||||
</PortalToFollowElemTrigger>
|
||||
@ -58,7 +55,6 @@ const CredentialSelector = ({
|
||||
<List
|
||||
currentCredentialId={currentCredentialId}
|
||||
credentials={credentials}
|
||||
pluginName={pluginName}
|
||||
onCredentialChange={handleCredentialChange}
|
||||
/>
|
||||
</PortalToFollowElemContent>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.