Compare commits


2 Commits

SHA1        Message                                               Date
225238b4b2  Update dev/ast-grep/rules/remove-nullable-arg.yaml    2025-11-05 03:09:34 +08:00
            Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
c4ea3e47fd  refactor: enforce typed String mapped columns         2025-11-05 03:09:34 +08:00
204 changed files with 3527 additions and 9883 deletions

View File

@ -53,6 +53,8 @@ jobs:
# Fix forward references that were incorrectly converted (Python doesn't support "Type" | None syntax)
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
find . -name "*.py.bak" -type f -delete
# Rewrite SQLAlchemy with Type Annotations
uvx --from ast-grep-cli sg scan -r dev/ast-grep/rules/remove-nullable-arg.yaml api/models -U
- name: mdformat
run: |

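For readability, here is a rough Python equivalent of the two sed substitutions above (illustration only; the sample input line is hypothetical):

import re

# Hypothetical input line of the kind the workflow step repairs.
src = 'user: "Account" | None = None'

# Restore Optional[...] for double- and single-quoted forward references.
fixed = re.sub(r'"([^"]+)" \| None', r'Optional["\1"]', src)
fixed = re.sub(r"'([^']+)' \| None", r"Optional['\1']", fixed)

print(fixed)  # user: Optional["Account"] = None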
View File

@ -117,7 +117,7 @@ All of Dify's offerings come with corresponding APIs, so you could effortlessly
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
- **Dify for enterprise / organizations<br/>**
We provide additional enterprise-centric features. [Send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss your enterprise needs. <br/>
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss enterprise needs. <br/>
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding.

View File

@ -615,8 +615,5 @@ SWAGGER_UI_PATH=/swagger-ui.html
# Set to false to export dataset IDs as plain text for easier cross-environment import
DSL_EXPORT_ENCRYPT_DATASET_ID=true
# Tenant isolated task queue configuration
TENANT_ISOLATED_TASK_CONCURRENCY=1
# Maximum number of segments for dataset segments API (0 for unlimited)
DATASET_MAX_SEGMENTS_PER_REQUEST=0

View File

@ -1422,10 +1422,7 @@ def setup_datasource_oauth_client(provider, client_params):
@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
@click.option(
"--environment", prompt=True, help="the environment to transform datasource credentials", default="online"
)
def transform_datasource_credentials(environment: str):
def transform_datasource_credentials():
"""
Transform datasource credentials
"""
@ -1436,14 +1433,9 @@ def transform_datasource_credentials(environment: str):
notion_plugin_id = "langgenius/notion_datasource"
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
jina_plugin_id = "langgenius/jina_datasource"
if environment == "online":
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
else:
notion_plugin_unique_identifier = None
firecrawl_plugin_unique_identifier = None
jina_plugin_unique_identifier = None
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
oauth_credential_type = CredentialType.OAUTH2
api_key_credential_type = CredentialType.API_KEY
@ -1609,7 +1601,7 @@ def transform_datasource_credentials(environment: str):
"integration_secret": api_key,
}
datasource_provider = DatasourceProvider(
provider="jinareader",
provider="jina",
tenant_id=tenant_id,
plugin_id=jina_plugin_id,
auth_type=api_key_credential_type.value,

View File

@ -1142,13 +1142,6 @@ class SwaggerUIConfig(BaseSettings):
)
class TenantIsolatedTaskQueueConfig(BaseSettings):
TENANT_ISOLATED_TASK_CONCURRENCY: int = Field(
description="Number of tasks allowed to be delivered concurrently from isolated queue per tenant",
default=1,
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@ -1173,7 +1166,6 @@ class FeatureConfig(
RagEtlConfig,
RepositoryConfig,
SecurityConfig,
TenantIsolatedTaskQueueConfig,
ToolConfig,
UpdateConfig,
WorkflowConfig,

View File

@ -22,11 +22,6 @@ class WeaviateConfig(BaseSettings):
default=True,
)
WEAVIATE_GRPC_ENDPOINT: str | None = Field(
description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
default=None,
)
WEAVIATE_BATCH_SIZE: PositiveInt = Field(
description="Number of objects to be processed in a single batch operation (default is 100)",
default=100,

View File

@ -102,18 +102,7 @@ class DraftWorkflowApi(Resource):
},
)
)
@api.response(
200,
"Draft workflow synced successfully",
api.model(
"SyncDraftWorkflowResponse",
{
"result": fields.String,
"hash": fields.String,
"updated_at": fields.String,
},
),
)
@api.response(200, "Draft workflow synced successfully", workflow_fields)
@api.response(400, "Invalid workflow configuration")
@api.response(403, "Permission denied")
@edit_permission_required

View File

@ -67,7 +67,6 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
kwargs["app_model"] = app_model
# If caller needs end-user context, attach EndUser to current_user
if fetch_user_arg:
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
user_id = request.args.get("user")
@ -76,6 +75,7 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
user_id = request.form.get("user")
else:
# use default-user
user_id = None
if not user_id and fetch_user_arg.required:
@ -90,28 +90,6 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
# Set EndUser as current logged-in user for flask_login.current_user
current_app.login_manager._update_request_context_with_user(end_user) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=end_user) # type: ignore
else:
# For service API without end-user context, ensure an Account is logged in
# so services relying on current_account_with_tenant() work correctly.
tenant_owner_info = (
db.session.query(Tenant, Account)
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
.join(Account, TenantAccountJoin.account_id == Account.id)
.where(
Tenant.id == app_model.tenant_id,
TenantAccountJoin.role == "owner",
Tenant.status == TenantStatus.NORMAL,
)
.one_or_none()
)
if tenant_owner_info:
tenant_model, account = tenant_owner_info
account.current_tenant = tenant_model
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
else:
raise Unauthorized("Tenant owner account not found or tenant is not active.")
return view_func(*args, **kwargs)

View File

@ -40,15 +40,20 @@ from core.workflow.repositories.draft_variable_repository import DraftVariableSa
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.flask_utils import preserve_flask_contexts
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.datasource_provider_service import DatasourceProviderService
from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
from services.feature_service import FeatureService
from services.file_service import FileService
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
logger = logging.getLogger(__name__)
@ -244,7 +249,34 @@ class PipelineGenerator(BaseAppGenerator):
)
if rag_pipeline_invoke_entities:
RagPipelineTaskProxy(dataset.tenant_id, user.id, rag_pipeline_invoke_entities).delay()
# store the rag_pipeline_invoke_entities to object storage
text = [item.model_dump() for item in rag_pipeline_invoke_entities]
name = "rag_pipeline_invoke_entities.json"
# Convert list to proper JSON string
json_text = json.dumps(text)
upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
features = FeatureService.get_features(dataset.tenant_id)
if features.billing.enabled and features.billing.subscription.plan == CloudPlan.SANDBOX:
tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"
if redis_client.get(tenant_pipeline_task_key):
# Add to waiting queue using List operations (lpush)
redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
else:
# Set flag and execute task
redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60)
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file.id,
tenant_id=dataset.tenant_id,
)
else:
priority_rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file.id,
tenant_id=dataset.tenant_id,
)
# return batch, dataset, documents
return {
"batch": batch,

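A condensed sketch of the sandbox gating logic in this hunk, assuming the same redis_client and Celery tasks imported above; the helper function and its is_sandbox flag are hypothetical:

from extensions.ext_redis import redis_client
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task


def dispatch_pipeline_run(tenant_id: str, upload_file_id: str, is_sandbox: bool) -> None:
    """Hypothetical helper mirroring the gating above: sandbox tenants get one in-flight run."""
    if not is_sandbox:
        priority_rag_pipeline_run_task.delay(
            rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id=tenant_id
        )
        return

    task_key = f"tenant_pipeline_task:{tenant_id}"
    waiting_queue = f"tenant_self_pipeline_task_queue:{tenant_id}"
    if redis_client.get(task_key):
        # A run is already in flight: park this one in the per-tenant waiting list.
        redis_client.lpush(waiting_queue, upload_file_id)
    else:
        # Mark a run as in flight for up to an hour, then dispatch it.
        redis_client.set(task_key, 1, ex=60 * 60)
        rag_pipeline_run_task.delay(
            rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id=tenant_id
        )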
View File

@ -104,11 +104,6 @@ class AppGenerateEntity(BaseModel):
inputs: Mapping[str, Any]
files: Sequence[File]
# Unique identifier of the user initiating the execution.
# This corresponds to `Account.id` for platform users or `EndUser.id` for end users.
#
# Note: The `user_id` field does not indicate whether the user is a platform user or an end user.
user_id: str
# extras

View File

@ -1,64 +1,15 @@
from typing import Annotated, Literal, Self, TypeAlias
from pydantic import BaseModel, Field
from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunPausedEvent
from models.model import AppMode
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory
# Wrapper types for `WorkflowAppGenerateEntity` and
# `AdvancedChatAppGenerateEntity`. These wrappers enable type discrimination
# and correct reconstruction of the entity field during (de)serialization.
class _WorkflowGenerateEntityWrapper(BaseModel):
type: Literal[AppMode.WORKFLOW] = AppMode.WORKFLOW
entity: WorkflowAppGenerateEntity
class _AdvancedChatAppGenerateEntityWrapper(BaseModel):
type: Literal[AppMode.ADVANCED_CHAT] = AppMode.ADVANCED_CHAT
entity: AdvancedChatAppGenerateEntity
_GenerateEntityUnion: TypeAlias = Annotated[
_WorkflowGenerateEntityWrapper | _AdvancedChatAppGenerateEntityWrapper,
Field(discriminator="type"),
]
class WorkflowResumptionContext(BaseModel):
"""WorkflowResumptionContext captures all state necessary for resumption."""
version: Literal["1"] = "1"
# Only workflow / chatflow could be paused.
generate_entity: _GenerateEntityUnion
serialized_graph_runtime_state: str
def dumps(self) -> str:
return self.model_dump_json()
@classmethod
def loads(cls, value: str) -> Self:
return cls.model_validate_json(value)
def get_generate_entity(self) -> WorkflowAppGenerateEntity | AdvancedChatAppGenerateEntity:
return self.generate_entity.entity
class PauseStatePersistenceLayer(GraphEngineLayer):
def __init__(
self,
session_factory: Engine | sessionmaker,
generate_entity: WorkflowAppGenerateEntity | AdvancedChatAppGenerateEntity,
state_owner_user_id: str,
):
def __init__(self, session_factory: Engine | sessionmaker, state_owner_user_id: str):
"""Create a PauseStatePersistenceLayer.
The `state_owner_user_id` is used when creating the state file for a pause.
@ -68,7 +19,6 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
session_factory = sessionmaker(session_factory)
self._session_maker = session_factory
self._state_owner_user_id = state_owner_user_id
self._generate_entity = generate_entity
def _get_repo(self) -> APIWorkflowRunRepository:
return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)
@ -99,27 +49,13 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
return
assert self.graph_runtime_state is not None
entity_wrapper: _GenerateEntityUnion
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)
elif isinstance(self._generate_entity, AdvancedChatAppGenerateEntity):
entity_wrapper = _AdvancedChatAppGenerateEntityWrapper(entity=self._generate_entity)
else:
raise AssertionError(f"unknown entity type: type={type(self._generate_entity)}")
state = WorkflowResumptionContext(
serialized_graph_runtime_state=self.graph_runtime_state.dumps(),
generate_entity=entity_wrapper,
)
workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id
assert workflow_run_id is not None
repo = self._get_repo()
repo.create_workflow_pause(
workflow_run_id=workflow_run_id,
state_owner_user_id=self._state_owner_user_id,
state=state.dumps(),
state=self.graph_runtime_state.dumps(),
)
def on_graph_end(self, error: Exception | None) -> None:

View File

@ -1,15 +0,0 @@
from collections.abc import Sequence
from dataclasses import dataclass
@dataclass
class DocumentTask:
"""Document task entity for document indexing operations.
This class represents a document indexing task that can be queued
and processed by the document indexing system.
"""
tenant_id: str
dataset_id: str
document_ids: Sequence[str]

View File

@ -1533,9 +1533,6 @@ class ProviderConfiguration(BaseModel):
# Return composite sort key: (model_type value, model position index)
return (model.model_type.value, position_index)
# Deduplicate
provider_models = list({(m.model, m.model_type, m.fetch_from): m for m in provider_models}.values())
# Sort using the composite sort key
return sorted(provider_models, key=get_sort_key)

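The deduplication above leans on a dict comprehension keeping the last value per key while preserving insertion order; a minimal standalone illustration with made-up model tuples:

records = [
    ("gpt-4", "llm", "predefined-model"),
    ("gpt-4", "llm", "predefined-model"),   # duplicate key, overwrites the entry above
    ("ada-002", "text-embedding", "predefined-model"),
]
deduped = list({r: r for r in records}.values())
assert deduped == [
    ("gpt-4", "llm", "predefined-model"),
    ("ada-002", "text-embedding", "predefined-model"),
]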
View File

@ -6,7 +6,10 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
class NodeJsTemplateTransformer(TemplateTransformer):
@classmethod
def get_runner_script(cls) -> str:
runner_script = dedent(f""" {cls._code_placeholder}
runner_script = dedent(
f"""
// declare main function
{cls._code_placeholder}
// decode and prepare input object
var inputs_obj = JSON.parse(Buffer.from('{cls._inputs_placeholder}', 'base64').toString('utf-8'))
@ -18,5 +21,6 @@ class NodeJsTemplateTransformer(TemplateTransformer):
var output_json = JSON.stringify(output_obj)
var result = `<<RESULT>>${{output_json}}<<RESULT>>`
console.log(result)
""")
"""
)
return runner_script

View File

@ -6,7 +6,9 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
class Python3TemplateTransformer(TemplateTransformer):
@classmethod
def get_runner_script(cls) -> str:
runner_script = dedent(f""" {cls._code_placeholder}
runner_script = dedent(f"""
# declare main function
{cls._code_placeholder}
import json
from base64 import b64decode

View File

@ -1,22 +1,21 @@
import hashlib
import json
import logging
import os
import traceback
from datetime import datetime, timedelta
from typing import Any, Union, cast
from urllib.parse import urlparse
from openinference.semconv.trace import OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, SpanAttributes
from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from opentelemetry.util.types import AttributeValue
from sqlalchemy.orm import sessionmaker
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
from opentelemetry.trace import SpanContext, TraceFlags, TraceState
from sqlalchemy import select
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
@ -31,10 +30,9 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecutionTriggeredFrom
from models.workflow import WorkflowNodeExecutionModel
logger = logging.getLogger(__name__)
@ -101,45 +99,22 @@ def datetime_to_nanos(dt: datetime | None) -> int:
return int(dt.timestamp() * 1_000_000_000)
def error_to_string(error: Exception | str | None) -> str:
"""Convert an error to a string with traceback information."""
error_message = "Empty Stack Trace"
if error:
if isinstance(error, Exception):
string_stacktrace = "".join(traceback.format_exception(error))
error_message = f"{error.__class__.__name__}: {error}\n\n{string_stacktrace}"
else:
error_message = str(error)
return error_message
def string_to_trace_id128(string: str | None) -> int:
"""
Convert any input string into a stable 128-bit integer trace ID.
This uses SHA-256 hashing and takes the first 16 bytes (128 bits) of the digest.
It's suitable for generating consistent, unique identifiers from strings.
"""
if string is None:
string = ""
hash_object = hashlib.sha256(string.encode())
def set_span_status(current_span: Span, error: Exception | str | None = None):
"""Set the status of the current span based on the presence of an error."""
if error:
error_string = error_to_string(error)
current_span.set_status(Status(StatusCode.ERROR, error_string))
# Take the first 16 bytes (128 bits) of the hash digest
digest = hash_object.digest()[:16]
if isinstance(error, Exception):
current_span.record_exception(error)
else:
exception_type = error.__class__.__name__
exception_message = str(error)
if not exception_message:
exception_message = repr(error)
attributes: dict[str, AttributeValue] = {
OTELSpanAttributes.EXCEPTION_TYPE: exception_type,
OTELSpanAttributes.EXCEPTION_MESSAGE: exception_message,
OTELSpanAttributes.EXCEPTION_ESCAPED: False,
OTELSpanAttributes.EXCEPTION_STACKTRACE: error_string,
}
current_span.add_event(name="exception", attributes=attributes)
else:
current_span.set_status(Status(StatusCode.OK))
def safe_json_dumps(obj: Any) -> str:
"""A convenience wrapper around `json.dumps` that ensures that any object can be safely encoded."""
return json.dumps(obj, default=str, ensure_ascii=False)
# Convert to a 128-bit integer
return int.from_bytes(digest, byteorder="big")
class ArizePhoenixDataTrace(BaseTraceInstance):
@ -156,12 +131,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
self.tracer, self.processor = setup_tracer(arize_phoenix_config)
self.project = arize_phoenix_config.project
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
self.propagator = TraceContextTextMapPropagator()
self.dify_trace_ids: set[str] = set()
def trace(self, trace_info: BaseTraceInfo):
logger.info("[Arize/Phoenix] Trace Entity Info: %s", trace_info)
logger.info("[Arize/Phoenix] Trace Entity Type: %s", type(trace_info))
logger.info("[Arize/Phoenix] Trace: %s", trace_info)
try:
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
@ -179,7 +151,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)
except Exception as e:
logger.error("[Arize/Phoenix] Trace Entity Error: %s", str(e), exc_info=True)
logger.error("[Arize/Phoenix] Error in the trace: %s", str(e), exc_info=True)
raise
def workflow_trace(self, trace_info: WorkflowTraceInfo):
@ -194,9 +166,15 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
}
workflow_metadata.update(trace_info.metadata)
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.workflow_run_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
workflow_span = self.tracer.start_span(
name=TraceTaskName.WORKFLOW_TRACE.value,
@ -208,58 +186,31 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
start_time=datetime_to_nanos(trace_info.start_time),
context=root_span_context,
)
# Through workflow_run_id, get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
# Find the app's creator account
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
service_account = self.get_service_account_with_tenant(app_id)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Get all executions for this workflow run
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
workflow_run_id=trace_info.workflow_run_id
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
for node_execution in workflow_node_executions:
tenant_id = trace_info.tenant_id # Use from trace_info instead
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
inputs_value = node_execution.inputs or {}
outputs_value = node_execution.outputs or {}
# Process workflow nodes
for node_execution in self._get_workflow_nodes(trace_info.workflow_run_id):
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
process_data = node_execution.process_data or {}
execution_metadata = node_execution.metadata or {}
node_metadata = {str(k): v for k, v in execution_metadata.items()}
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
node_metadata.update(
{
"node_id": node_execution.id,
"node_type": node_execution.node_type,
"node_status": node_execution.status,
"tenant_id": tenant_id,
"app_id": app_id,
"app_name": node_execution.title,
"status": node_execution.status,
"level": "ERROR" if node_execution.status == "failed" else "DEFAULT",
}
)
node_metadata = {
"node_id": node_execution.id,
"node_type": node_execution.node_type,
"node_status": node_execution.status,
"tenant_id": node_execution.tenant_id,
"app_id": node_execution.app_id,
"app_name": node_execution.title,
"status": node_execution.status,
"level": "ERROR" if node_execution.status != "succeeded" else "DEFAULT",
}
if node_execution.execution_metadata:
node_metadata.update(json.loads(node_execution.execution_metadata))
# Determine the correct span kind based on node type
span_kind = OpenInferenceSpanKindValues.CHAIN
@ -272,9 +223,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
if model:
node_metadata["ls_model_name"] = model
usage_data = (
process_data.get("usage", {}) if "usage" in process_data else outputs_value.get("usage", {})
)
outputs = json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
if usage_data:
node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
@ -286,20 +236,17 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
else:
span_kind = OpenInferenceSpanKindValues.CHAIN
workflow_span_context = set_span_in_context(workflow_span)
node_span = self.tracer.start_span(
name=node_execution.node_type,
attributes={
SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}",
SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}",
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
SpanAttributes.METADATA: safe_json_dumps(node_metadata),
SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
start_time=datetime_to_nanos(created_at),
context=workflow_span_context,
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
@ -313,8 +260,11 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
llm_attributes[SpanAttributes.LLM_PROVIDER] = provider
if model:
llm_attributes[SpanAttributes.LLM_MODEL_NAME] = model
outputs = (
json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
)
usage_data = (
process_data.get("usage", {}) if "usage" in process_data else outputs_value.get("usage", {})
process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
)
if usage_data:
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = usage_data.get("total_tokens", 0)
@ -325,16 +275,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
node_span.set_attributes(llm_attributes)
finally:
if node_execution.status == "failed":
set_span_status(node_span, node_execution.error)
else:
set_span_status(node_span)
node_span.end(end_time=datetime_to_nanos(finished_at))
finally:
if trace_info.error:
set_span_status(workflow_span, trace_info.error)
else:
set_span_status(workflow_span)
workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))
def message_trace(self, trace_info: MessageTraceInfo):
@ -380,18 +322,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
}
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.message_id)
message_span_id = RandomIdGenerator().generate_span_id()
span_context = SpanContext(
trace_id=trace_id,
span_id=message_span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
message_span = self.tracer.start_span(
name=TraceTaskName.MESSAGE_TRACE.value,
attributes=attributes,
start_time=datetime_to_nanos(trace_info.start_time),
context=root_span_context,
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
)
try:
if trace_info.error:
message_span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
# Convert outputs to string based on type
if isinstance(trace_info.outputs, dict | list):
outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False)
@ -425,26 +383,26 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
if model_params := metadata_dict.get("model_parameters"):
llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)
message_span_context = set_span_in_context(message_span)
llm_span = self.tracer.start_span(
name="llm",
attributes=llm_attributes,
start_time=datetime_to_nanos(trace_info.start_time),
context=message_span_context,
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
)
try:
if trace_info.message_data.error:
set_span_status(llm_span, trace_info.message_data.error)
else:
set_span_status(llm_span)
if trace_info.error:
llm_span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
finally:
llm_span.end(end_time=datetime_to_nanos(trace_info.end_time))
finally:
if trace_info.error:
set_span_status(message_span, trace_info.error)
else:
set_span_status(message_span)
message_span.end(end_time=datetime_to_nanos(trace_info.end_time))
def moderation_trace(self, trace_info: ModerationTraceInfo):
@ -460,9 +418,15 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
}
metadata.update(trace_info.metadata)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
trace_id = string_to_trace_id128(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
span = self.tracer.start_span(
name=TraceTaskName.MODERATION_TRACE.value,
@ -481,14 +445,19 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
},
start_time=datetime_to_nanos(trace_info.start_time),
context=root_span_context,
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
if trace_info.message_data.error:
set_span_status(span, trace_info.message_data.error)
else:
set_span_status(span)
span.add_event(
"exception",
attributes={
"exception.message": trace_info.message_data.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.message_data.error,
},
)
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
@ -511,9 +480,15 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
}
metadata.update(trace_info.metadata)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
trace_id = string_to_trace_id128(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
span = self.tracer.start_span(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
@ -524,14 +499,19 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
},
start_time=datetime_to_nanos(start_time),
context=root_span_context,
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
if trace_info.error:
set_span_status(span, trace_info.error)
else:
set_span_status(span)
span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
finally:
span.end(end_time=datetime_to_nanos(end_time))
@ -553,9 +533,15 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
}
metadata.update(trace_info.metadata)
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
trace_id = string_to_trace_id128(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
span = self.tracer.start_span(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
@ -568,14 +554,19 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
"end_time": end_time.isoformat() if end_time else "",
},
start_time=datetime_to_nanos(start_time),
context=root_span_context,
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
if trace_info.message_data.error:
set_span_status(span, trace_info.message_data.error)
else:
set_span_status(span)
span.add_event(
"exception",
attributes={
"exception.message": trace_info.message_data.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.message_data.error,
},
)
finally:
span.end(end_time=datetime_to_nanos(end_time))
@ -589,9 +580,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
"tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
}
dify_trace_id = trace_info.trace_id or trace_info.message_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
trace_id = string_to_trace_id128(trace_info.message_id)
tool_span_id = RandomIdGenerator().generate_span_id()
logger.info("[Arize/Phoenix] Creating tool trace with trace_id: %s, span_id: %s", trace_id, tool_span_id)
# Create span context with the same trace_id as the parent
# todo: Create with the appropriate parent span context, so that the tool span is
# a child of the appropriate span (e.g. message span)
span_context = SpanContext(
trace_id=trace_id,
span_id=tool_span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
tool_params_str = (
json.dumps(trace_info.tool_parameters, ensure_ascii=False)
@ -610,14 +612,19 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.TOOL_PARAMETERS: tool_params_str,
},
start_time=datetime_to_nanos(trace_info.start_time),
context=root_span_context,
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
)
try:
if trace_info.error:
set_span_status(span, trace_info.error)
else:
set_span_status(span)
span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
@ -634,9 +641,15 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
}
metadata.update(trace_info.metadata)
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.conversation_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
trace_id = string_to_trace_id128(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
span = self.tracer.start_span(
name=TraceTaskName.GENERATE_NAME_TRACE.value,
@ -650,34 +663,22 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
"end_time": trace_info.end_time.isoformat() if trace_info.end_time else "",
},
start_time=datetime_to_nanos(trace_info.start_time),
context=root_span_context,
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
if trace_info.message_data.error:
set_span_status(span, trace_info.message_data.error)
else:
set_span_status(span)
span.add_event(
"exception",
attributes={
"exception.message": trace_info.message_data.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.message_data.error,
},
)
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
def ensure_root_span(self, dify_trace_id: str | None):
"""Ensure a unique root span exists for the given Dify trace ID."""
if str(dify_trace_id) not in self.dify_trace_ids:
self.carrier: dict[str, str] = {}
root_span = self.tracer.start_span(name="Dify")
root_span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, OpenInferenceSpanKindValues.CHAIN.value)
root_span.set_attribute("dify_project_name", str(self.project))
root_span.set_attribute("dify_trace_id", str(dify_trace_id))
with use_span(root_span, end_on_exit=False):
self.propagator.inject(carrier=self.carrier)
set_span_status(root_span)
root_span.end()
self.dify_trace_ids.add(str(dify_trace_id))
def api_check(self):
try:
with self.tracer.start_span("api_check") as span:
@ -697,6 +698,26 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
logger.info("[Arize/Phoenix] Get run url failed: %s", str(e), exc_info=True)
raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")
def _get_workflow_nodes(self, workflow_run_id: str):
"""Helper method to get workflow nodes"""
workflow_nodes = db.session.scalars(
select(
WorkflowNodeExecutionModel.id,
WorkflowNodeExecutionModel.tenant_id,
WorkflowNodeExecutionModel.app_id,
WorkflowNodeExecutionModel.title,
WorkflowNodeExecutionModel.node_type,
WorkflowNodeExecutionModel.status,
WorkflowNodeExecutionModel.inputs,
WorkflowNodeExecutionModel.outputs,
WorkflowNodeExecutionModel.created_at,
WorkflowNodeExecutionModel.elapsed_time,
WorkflowNodeExecutionModel.process_data,
WorkflowNodeExecutionModel.execution_metadata,
).where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
).all()
return workflow_nodes
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
"""Helper method to construct LLM attributes with passed prompts."""
attributes = {}

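The tracing changes in this file repeat one pattern: derive a deterministic 128-bit trace ID from the Dify trace or message ID, then parent each span on a NonRecordingSpan built from that ID. A compact sketch of that pattern, assuming an already configured tracer; start_child_span is a hypothetical helper:

import hashlib

from opentelemetry import trace
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
from opentelemetry.trace import NonRecordingSpan, SpanContext, TraceFlags, TraceState


def string_to_trace_id128(string: str | None) -> int:
    # Stable 128-bit ID: first 16 bytes of the SHA-256 digest of the input string.
    digest = hashlib.sha256((string or "").encode()).digest()[:16]
    return int.from_bytes(digest, byteorder="big")


def start_child_span(tracer: trace.Tracer, name: str, dify_trace_id: str) -> trace.Span:
    parent_context = SpanContext(
        trace_id=string_to_trace_id128(dify_trace_id),
        span_id=RandomIdGenerator().generate_span_id(),
        is_remote=False,
        trace_flags=TraceFlags(TraceFlags.SAMPLED),
        trace_state=TraceState(),
    )
    # Every span created this way for the same dify_trace_id lands in the same trace.
    return tracer.start_span(name, context=trace.set_span_in_context(NonRecordingSpan(parent_context)))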
View File

@ -147,8 +147,7 @@ class ElasticSearchVector(BaseVector):
def _get_version(self) -> str:
info = self._client.info()
# remove any suffix like "-SNAPSHOT" from the version string
return cast(str, info["version"]["number"]).split("-")[0]
return cast(str, info["version"]["number"])
def _check_version(self):
if parse_version(self._version) < parse_version("8.0.0"):

View File

@ -39,13 +39,11 @@ class WeaviateConfig(BaseModel):
Attributes:
endpoint: Weaviate server endpoint URL
grpc_endpoint: Optional Weaviate gRPC server endpoint URL
api_key: Optional API key for authentication
batch_size: Number of objects to batch per insert operation
"""
endpoint: str
grpc_endpoint: str | None = None
api_key: str | None = None
batch_size: int = 100
@ -90,22 +88,9 @@ class WeaviateVector(BaseVector):
http_secure = p.scheme == "https"
http_port = p.port or (443 if http_secure else 80)
# Parse gRPC configuration
if config.grpc_endpoint:
# Urls without scheme won't be parsed correctly in some python versions,
# see https://bugs.python.org/issue27657
grpc_endpoint_with_scheme = (
config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}"
)
grpc_p = urlparse(grpc_endpoint_with_scheme)
grpc_host = grpc_p.hostname or "localhost"
grpc_port = grpc_p.port or (443 if grpc_p.scheme == "grpcs" else 50051)
grpc_secure = grpc_p.scheme == "grpcs"
else:
# Infer from HTTP endpoint as fallback
grpc_host = host
grpc_secure = http_secure
grpc_port = 443 if grpc_secure else 50051
grpc_host = host
grpc_secure = http_secure
grpc_port = 443 if grpc_secure else 50051
client = weaviate.connect_to_custom(
http_host=host,
@ -447,7 +432,6 @@ class WeaviateVectorFactory(AbstractVectorFactory):
collection_name=collection_name,
config=WeaviateConfig(
endpoint=dify_config.WEAVIATE_ENDPOINT or "",
grpc_endpoint=dify_config.WEAVIATE_GRPC_ENDPOINT or "",
api_key=dify_config.WEAVIATE_API_KEY,
batch_size=dify_config.WEAVIATE_BATCH_SIZE,
),

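The scheme prefixing in the gRPC branch above works around urlparse's handling of bare host:port strings (https://bugs.python.org/issue27657); a minimal sketch with a hypothetical endpoint:

from urllib.parse import urlparse

raw = "weaviate.internal:50051"  # hypothetical gRPC endpoint without a scheme

# A bare "host:port" string is not parsed as a netloc, so hostname/port come back empty;
# prepending a scheme makes urlparse return the expected pieces.
endpoint = raw if "://" in raw else f"grpc://{raw}"
p = urlparse(endpoint)
grpc_host = p.hostname or "localhost"
grpc_port = p.port or (443 if p.scheme == "grpcs" else 50051)
grpc_secure = p.scheme == "grpcs"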
View File

@ -1,79 +0,0 @@
import json
from collections.abc import Sequence
from typing import Any
from pydantic import BaseModel, ValidationError
from extensions.ext_redis import redis_client
_DEFAULT_TASK_TTL = 60 * 60 # 1 hour
class TaskWrapper(BaseModel):
data: Any
def serialize(self) -> str:
return self.model_dump_json()
@classmethod
def deserialize(cls, serialized_data: str) -> "TaskWrapper":
return cls.model_validate_json(serialized_data)
class TenantIsolatedTaskQueue:
"""
Simple queue for tenant-isolated tasks, used to isolate RAG-related tasks per tenant.
It uses a Redis list to store tasks and a Redis key to store the task-waiting flag.
Supports any task that can be serialized as JSON.
"""
def __init__(self, tenant_id: str, unique_key: str):
self._tenant_id = tenant_id
self._unique_key = unique_key
self._queue = f"tenant_self_{unique_key}_task_queue:{tenant_id}"
self._task_key = f"tenant_{unique_key}_task:{tenant_id}"
def get_task_key(self):
return redis_client.get(self._task_key)
def set_task_waiting_time(self, ttl: int = _DEFAULT_TASK_TTL):
redis_client.setex(self._task_key, ttl, 1)
def delete_task_key(self):
redis_client.delete(self._task_key)
def push_tasks(self, tasks: Sequence[Any]):
serialized_tasks = []
for task in tasks:
# Store str list directly, maintaining full compatibility for pipeline scenarios
if isinstance(task, str):
serialized_tasks.append(task)
else:
# Use TaskWrapper to do JSON serialization for non-string tasks
wrapper = TaskWrapper(data=task)
serialized_data = wrapper.serialize()
serialized_tasks.append(serialized_data)
redis_client.lpush(self._queue, *serialized_tasks)
def pull_tasks(self, count: int = 1) -> Sequence[Any]:
if count <= 0:
return []
tasks = []
for _ in range(count):
serialized_task = redis_client.rpop(self._queue)
if not serialized_task:
break
if isinstance(serialized_task, bytes):
serialized_task = serialized_task.decode("utf-8")
try:
wrapper = TaskWrapper.deserialize(serialized_task)
tasks.append(wrapper.data)
except (json.JSONDecodeError, ValidationError, TypeError, ValueError):
# Fall back to raw string for legacy format or invalid JSON
tasks.append(serialized_task)
return tasks

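A short usage sketch for the queue defined in this file (the tenant ID, payloads, and process_now handler are hypothetical), showing how the waiting flag and the Redis list cooperate:

def process_now(task) -> None:
    ...  # hypothetical handler


queue = TenantIsolatedTaskQueue(tenant_id="tenant-123", unique_key="pipeline")

if queue.get_task_key():
    # A task for this tenant is already in flight: park the new work in the Redis list.
    queue.push_tasks([{"file_id": "abc"}, "legacy-plain-string-task"])
else:
    # Mark the tenant as busy (default TTL of one hour) and handle the work immediately.
    queue.set_task_waiting_time()
    process_now({"file_id": "abc"})

# Later, e.g. when the in-flight task finishes: drain waiting tasks and clear the flag.
for task in queue.pull_tasks(count=10):
    process_now(task)
queue.delete_task_key()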
View File

@ -210,13 +210,12 @@ class Tool(ABC):
meta=meta,
)
def create_json_message(self, object: dict, suppress_output: bool = False) -> ToolInvokeMessage:
def create_json_message(self, object: dict) -> ToolInvokeMessage:
"""
create a json message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.JSON,
message=ToolInvokeMessage.JsonMessage(json_object=object, suppress_output=suppress_output),
type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object)
)
def create_variable_message(

View File

@ -129,7 +129,6 @@ class ToolInvokeMessage(BaseModel):
class JsonMessage(BaseModel):
json_object: dict
suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string")
class BlobMessage(BaseModel):
blob: bytes

View File

@ -1,19 +1,16 @@
import base64
import json
import logging
from collections.abc import Generator
from typing import Any
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
from core.mcp.types import AudioContent, CallToolResult, ImageContent, TextContent
from core.mcp.types import CallToolResult, ImageContent, TextContent
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
from core.tools.errors import ToolInvokeError
logger = logging.getLogger(__name__)
class MCPTool(Tool):
def __init__(
@ -55,11 +52,6 @@ class MCPTool(Tool):
yield from self._process_text_content(content)
elif isinstance(content, ImageContent):
yield self._process_image_content(content)
elif isinstance(content, AudioContent):
yield self._process_audio_content(content)
else:
logger.warning("Unsupported content type=%s", type(content))
# handle MCP structured output
if self.entity.output_schema and result.structuredContent:
for k, v in result.structuredContent.items():
@ -105,10 +97,6 @@ class MCPTool(Tool):
"""Process image content and return a blob message."""
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
def _process_audio_content(self, content: AudioContent) -> ToolInvokeMessage:
"""Process audio content and return a blob message."""
return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType})
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
return MCPTool(
entity=self.entity,

View File

@ -245,9 +245,6 @@ class ToolEngine:
+ "you do not need to create it, just tell the user to check it now."
)
elif response.type == ToolInvokeMessage.MessageType.JSON:
json_message = cast(ToolInvokeMessage.JsonMessage, response.message)
if json_message.suppress_output:
continue
json_parts.append(
json.dumps(
safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object),

View File

@ -117,7 +117,7 @@ class WorkflowTool(Tool):
self._latest_usage = self._derive_usage_from_result(data)
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
yield self.create_json_message(outputs, suppress_output=True)
yield self.create_json_message(outputs)
@property
def latest_usage(self) -> LLMUsage:

View File

@ -16,6 +16,7 @@ from uuid import uuid4
from flask import Flask
from typing_extensions import override
from core.workflow.enums import NodeType
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
from core.workflow.nodes.base.node import Node
@ -107,8 +108,8 @@ class Worker(threading.Thread):
except Exception as e:
error_event = NodeRunFailedEvent(
id=str(uuid4()),
node_id=node.id,
node_type=node.node_type,
node_id="unknown",
node_type=NodeType.CODE,
in_iteration_id=None,
error=str(e),
start_at=datetime.now(),

View File

@ -153,11 +153,7 @@ class VariablePool(BaseModel):
return None
node_id, name = self._selector_to_keys(selector)
node_map = self.variable_dictionary.get(node_id)
if node_map is None:
return None
segment: Segment | None = node_map.get(name)
segment: Segment | None = self.variable_dictionary[node_id].get(name)
if segment is None:
return None

View File

@ -32,7 +32,7 @@ if [[ "${MODE}" == "worker" ]]; then
exec celery -A celery_entrypoint.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
--max-tasks-per-child ${MAX_TASKS_PER_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
-Q ${CELERY_QUEUES:-dataset,priority_dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline} \
-Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline} \
--prefetch-multiplier=1
elif [[ "${MODE}" == "beat" ]]; then

View File

@ -1,134 +0,0 @@
"""
Broadcast channel for Pub/Sub messaging.
"""
import types
from abc import abstractmethod
from collections.abc import Iterator
from contextlib import AbstractContextManager
from typing import Protocol, Self
class Subscription(AbstractContextManager["Subscription"], Protocol):
"""A subscription to a topic that provides an iterator over received messages.
The subscription can be used as a context manager and will automatically
close when exiting the context.
Note: `Subscription` instances are not thread-safe. Each thread should create its own
subscription.
"""
@abstractmethod
def __iter__(self) -> Iterator[bytes]:
"""`__iter__` returns an iterator used to consume the message from this subscription.
If the caller did not enter the context, `__iter__` may lazily perform the setup before
yielding messages; otherwise `__enter__` handles it.
If the subscription is closed, then the returned iterator exits without
raising any error.
"""
...
@abstractmethod
def close(self) -> None:
"""close closes the subscription, releases any resources associated with it."""
...
def __enter__(self) -> Self:
"""`__enter__` does the setup logic of the subscription (if any), and return itself."""
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> bool | None:
self.close()
return None
@abstractmethod
def receive(self, timeout: float | None = 0.1) -> bytes | None:
"""Receive the next message from the broadcast channel.
If `timeout` is specified, this method returns `None` if no message is
received within the given period. If `timeout` is `None`, the call blocks
until a message is received.
Calling receive with `timeout=None` is highly discouraged, as it is impossible to
cancel a blocking subscription.
:param timeout: timeout for receive message, in seconds.
Returns:
bytes: The received message as a byte string, or
None: If the timeout expires before a message is received.
Raises:
SubscriptionClosed: If the subscription has already been closed.
"""
...
class Producer(Protocol):
"""Producer is an interface for message publishing. It is already bound to a specific topic.
`Producer` implementations must be thread-safe and support concurrent use by multiple threads.
"""
@abstractmethod
def publish(self, payload: bytes) -> None:
"""Publish a message to the bounded topic."""
...
class Subscriber(Protocol):
"""Subscriber is an interface for subscription creation. It is already bound to a specific topic.
`Subscriber` implementations must be thread-safe and support concurrent use by multiple threads.
"""
@abstractmethod
def subscribe(self) -> Subscription:
pass
class Topic(Producer, Subscriber, Protocol):
"""A named channel for publishing and subscribing to messages.
Topics provide both read and write access. For restricted access,
use as_producer() for write-only view or as_subscriber() for read-only view.
`Topic` implementations must be thread-safe and support concurrent use by multiple threads.
"""
@abstractmethod
def as_producer(self) -> Producer:
"""as_producer creates a write-only view for this topic."""
...
@abstractmethod
def as_subscriber(self) -> Subscriber:
"""as_subscriber create a read-only view for this topic."""
...
class BroadcastChannel(Protocol):
"""A broadcasting channel is a channel supporting broadcasting semantics.
Each channel is identified by a topic; different topics are isolated and do not affect each other.
There can be multiple subscriptions to a specific topic. When a publisher publishes a message to
a specific topic, all subscriptions should receive the published message.
There is no restriction on message persistence. Once a subscription is created, it
should receive all subsequently published messages.
`BroadcastChannel` implementations must be thread-safe and support concurrent use by multiple threads.
"""
@abstractmethod
def topic(self, topic: str) -> "Topic":
"""topic returns a `Topic` instance for the given topic name."""
...

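A minimal usage sketch for these protocols; any conforming implementation works, for example the Redis Pub/Sub channel further down in this diff:

def fan_out_example(channel: BroadcastChannel) -> None:
    topic = channel.topic("workflow-events")

    # Split the topic into a write-only and a read-only view.
    producer = topic.as_producer()
    subscriber = topic.as_subscriber()

    with subscriber.subscribe() as subscription:
        producer.publish(b"hello")
        # Poll with a timeout; None means nothing arrived in time.
        message = subscription.receive(timeout=0.5)
        if message is not None:
            print(message.decode("utf-8"))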
View File

@ -1,12 +0,0 @@
class BroadcastChannelError(Exception):
"""`BroadcastChannelError` is the base class for all exceptions related
to `BroadcastChannel`."""
pass
class SubscriptionClosedError(BroadcastChannelError):
"""SubscriptionClosedError means that the subscription has been closed and
methods for consuming messages should not be called."""
pass

View File

@ -1,3 +0,0 @@
from .channel import BroadcastChannel
__all__ = ["BroadcastChannel"]

View File

@ -1,200 +0,0 @@
import logging
import queue
import threading
import types
from collections.abc import Generator, Iterator
from typing import Self
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from redis import Redis
from redis.client import PubSub
_logger = logging.getLogger(__name__)
class BroadcastChannel:
"""
Redis Pub/Sub based broadcast channel implementation.
Provides "at most once" delivery semantics for messages published to channels.
Uses Redis PUBLISH/SUBSCRIBE commands for real-time message delivery.
The `redis_client` used to construct BroadcastChannel should have `decode_responses` set to `False`.
"""
def __init__(
self,
redis_client: Redis,
):
self._client = redis_client
def topic(self, topic: str) -> "Topic":
return Topic(self._client, topic)
class Topic:
def __init__(self, redis_client: Redis, topic: str):
self._client = redis_client
self._topic = topic
def as_producer(self) -> Producer:
return self
def publish(self, payload: bytes) -> None:
self._client.publish(self._topic, payload)
def as_subscriber(self) -> Subscriber:
return self
def subscribe(self) -> Subscription:
return _RedisSubscription(
pubsub=self._client.pubsub(),
topic=self._topic,
)
class _RedisSubscription(Subscription):
def __init__(
self,
pubsub: PubSub,
topic: str,
):
# The _pubsub is None only if the subscription is closed.
self._pubsub: PubSub | None = pubsub
self._topic = topic
self._closed = threading.Event()
self._queue: queue.Queue[bytes] = queue.Queue(maxsize=1024)
self._dropped_count = 0
self._listener_thread: threading.Thread | None = None
self._start_lock = threading.Lock()
self._started = False
def _start_if_needed(self) -> None:
with self._start_lock:
if self._started:
return
if self._closed.is_set():
raise SubscriptionClosedError("The Redis subscription is closed")
if self._pubsub is None:
raise SubscriptionClosedError("The Redis subscription has been cleaned up")
self._pubsub.subscribe(self._topic)
_logger.debug("Subscribed to channel %s", self._topic)
self._listener_thread = threading.Thread(
target=self._listen,
name=f"redis-broadcast-{self._topic}",
daemon=True,
)
self._listener_thread.start()
self._started = True
def _listen(self) -> None:
pubsub = self._pubsub
assert pubsub is not None, "PubSub should not be None while starting listening."
while not self._closed.is_set():
raw_message = pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
if raw_message is None:
continue
if raw_message.get("type") != "message":
continue
channel_field = raw_message.get("channel")
if isinstance(channel_field, bytes):
channel_name = channel_field.decode("utf-8")
elif isinstance(channel_field, str):
channel_name = channel_field
else:
channel_name = str(channel_field)
if channel_name != self._topic:
_logger.warning("Ignoring message from unexpected channel %s", channel_name)
continue
payload_bytes: bytes | None = raw_message.get("data")
if not isinstance(payload_bytes, bytes):
_logger.error("Received invalid data from channel %s, type=%s", self._topic, type(payload_bytes))
continue
self._enqueue_message(payload_bytes)
_logger.debug("Listener thread stopped for channel %s", self._topic)
pubsub.unsubscribe(self._topic)
pubsub.close()
_logger.debug("PubSub closed for topic %s", self._topic)
self._pubsub = None
def _enqueue_message(self, payload: bytes) -> None:
while not self._closed.is_set():
try:
self._queue.put_nowait(payload)
return
except queue.Full:
try:
self._queue.get_nowait()
self._dropped_count += 1
_logger.debug(
"Dropped message from Redis subscription, topic=%s, total_dropped=%d",
self._topic,
self._dropped_count,
)
except queue.Empty:
continue
return
def _message_iterator(self) -> Generator[bytes, None, None]:
while not self._closed.is_set():
try:
item = self._queue.get(timeout=0.1)
except queue.Empty:
continue
yield item
def __iter__(self) -> Iterator[bytes]:
if self._closed.is_set():
raise SubscriptionClosedError("The Redis subscription is closed")
self._start_if_needed()
return iter(self._message_iterator())
def receive(self, timeout: float | None = None) -> bytes | None:
if self._closed.is_set():
raise SubscriptionClosedError("The Redis subscription is closed")
self._start_if_needed()
try:
item = self._queue.get(timeout=timeout)
except queue.Empty:
return None
return item
def __enter__(self) -> Self:
self._start_if_needed()
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> bool | None:
self.close()
return None
def close(self) -> None:
if self._closed.is_set():
return
self._closed.set()
# NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the `PubSub.get_message`
# method should NOT be called concurrently.
#
# Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
listener = self._listener_thread
if listener is not None:
listener.join(timeout=1.0)
self._listener_thread = None

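And a sketch of wiring the Redis-backed channel itself; the Redis URL is hypothetical, and per the class docstring the client must keep decode_responses=False:

from redis import Redis

# decode_responses must stay False so subscribers receive raw bytes.
raw_client = Redis.from_url("redis://localhost:6379/0", decode_responses=False)
channel = BroadcastChannel(raw_client)

topic = channel.topic("workflow-events")
with topic.subscribe() as subscription:
    topic.publish(b"ping")
    for payload in subscription:  # iterates until the subscription is closed
        print(payload)
        break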
View File

@ -110,7 +110,7 @@ class Account(UserMixin, TypeBase):
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False, onupdate=func.current_timestamp()
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
role: TenantAccountRole | None = field(default=None, init=False)
@ -250,9 +250,7 @@ class Tenant(TypeBase):
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), init=False, onupdate=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), init=False)
def get_accounts(self) -> list[Account]:
return list(
@ -291,7 +289,7 @@ class TenantAccountJoin(TypeBase):
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False, onupdate=func.current_timestamp()
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
@ -312,7 +310,7 @@ class AccountIntegrate(TypeBase):
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False, onupdate=func.current_timestamp()
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
@ -398,5 +396,5 @@ class TenantPluginAutoUpgradeStrategy(TypeBase):
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False, onupdate=func.current_timestamp()
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
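
The recurring `updated_at` hunks in these model files differ only in the `onupdate` argument: with `onupdate=func.current_timestamp()`, SQLAlchemy includes the function in every UPDATE it emits, so the column refreshes automatically, while with `server_default` alone the column is filled at INSERT time and any later refresh has to be made explicitly by application code. A minimal, self-contained sketch of the two behaviors (the model and field names below are illustrative, not taken from this diff):

from datetime import datetime

from sqlalchemy import DateTime, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class ExampleTimestamps(Base):  # illustrative model, not part of the codebase
    __tablename__ = "example_timestamps"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Refreshed automatically: the ORM adds CURRENT_TIMESTAMP to every UPDATE it emits for the row.
    auto_updated_at: Mapped[datetime] = mapped_column(
        DateTime, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )
    # Insert-time default only: keeps its value on UPDATE unless application code assigns a new one.
    manual_updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())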

View File

@ -61,20 +61,18 @@ class Dataset(Base):
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
embedding_model = mapped_column(sa.String(255), nullable=True)
embedding_model_provider = mapped_column(sa.String(255), nullable=True)
keyword_number = mapped_column(sa.Integer, nullable=True, server_default=sa.text("10"))
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
embedding_model = mapped_column(db.String(255), nullable=True)
embedding_model_provider = mapped_column(db.String(255), nullable=True)
keyword_number = mapped_column(sa.Integer, nullable=True, server_default=db.text("10"))
collection_binding_id = mapped_column(StringUUID, nullable=True)
retrieval_model = mapped_column(JSONB, nullable=True)
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
icon_info = mapped_column(JSONB, nullable=True)
runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'::character varying"))
runtime_mode = mapped_column(db.String(255), nullable=True, server_default=db.text("'general'::character varying"))
pipeline_id = mapped_column(StringUUID, nullable=True)
chunk_structure = mapped_column(sa.String(255), nullable=True)
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
chunk_structure = mapped_column(db.String(255), nullable=True)
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=db.text("true"))
@property
def total_documents(self):
@ -401,9 +399,7 @@ class Document(Base):
archived_reason = mapped_column(String(255), nullable=True)
archived_by = mapped_column(StringUUID, nullable=True)
archived_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
doc_type = mapped_column(String(40), nullable=True)
doc_metadata = mapped_column(JSONB, nullable=True)
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'::character varying"))
@ -720,9 +716,7 @@ class DocumentSegment(Base):
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
error = mapped_column(sa.Text, nullable=True)
@ -887,7 +881,7 @@ class ChildChunk(Base):
)
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"), onupdate=func.current_timestamp()
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
@ -1042,8 +1036,8 @@ class TidbAuthBinding(Base):
tenant_id = mapped_column(StringUUID, nullable=True)
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
status = mapped_column(String(255), nullable=False, server_default=sa.text("'CREATING'::character varying"))
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
status = mapped_column(String(255), nullable=False, server_default=db.text("'CREATING'::character varying"))
account: Mapped[str] = mapped_column(String(255), nullable=False)
password: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@ -1094,9 +1088,7 @@ class ExternalKnowledgeApis(Base):
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
def to_dict(self) -> dict[str, Any]:
return {
@ -1149,9 +1141,7 @@ class ExternalKnowledgeBindings(Base):
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class DatasetAutoDisableLog(Base):
@ -1207,7 +1197,7 @@ class DatasetMetadata(Base):
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"), onupdate=func.current_timestamp()
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
created_by = mapped_column(StringUUID, nullable=False)
updated_by = mapped_column(StringUUID, nullable=True)
@ -1234,48 +1224,44 @@ class DatasetMetadataBinding(Base):
class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
__tablename__ = "pipeline_built_in_templates"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
name = mapped_column(sa.String(255), nullable=False)
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
name = mapped_column(db.String(255), nullable=False)
description = mapped_column(sa.Text, nullable=False)
chunk_structure = mapped_column(sa.String(255), nullable=False)
chunk_structure = mapped_column(db.String(255), nullable=False)
icon = mapped_column(sa.JSON, nullable=False)
yaml_content = mapped_column(sa.Text, nullable=False)
copyright = mapped_column(sa.String(255), nullable=False)
privacy_policy = mapped_column(sa.String(255), nullable=False)
copyright = mapped_column(db.String(255), nullable=False)
privacy_policy = mapped_column(db.String(255), nullable=False)
position = mapped_column(sa.Integer, nullable=False)
install_count = mapped_column(sa.Integer, nullable=False, default=0)
language = mapped_column(sa.String(255), nullable=False)
language = mapped_column(db.String(255), nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
__tablename__ = "pipeline_customized_templates"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="pipeline_customized_template_pkey"),
sa.Index("pipeline_customized_template_tenant_idx", "tenant_id"),
db.PrimaryKeyConstraint("id", name="pipeline_customized_template_pkey"),
db.Index("pipeline_customized_template_tenant_idx", "tenant_id"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
tenant_id = mapped_column(StringUUID, nullable=False)
name = mapped_column(sa.String(255), nullable=False)
name = mapped_column(db.String(255), nullable=False)
description = mapped_column(sa.Text, nullable=False)
chunk_structure = mapped_column(sa.String(255), nullable=False)
chunk_structure = mapped_column(db.String(255), nullable=False)
icon = mapped_column(sa.JSON, nullable=False)
position = mapped_column(sa.Integer, nullable=False)
yaml_content = mapped_column(sa.Text, nullable=False)
install_count = mapped_column(sa.Integer, nullable=False, default=0)
language = mapped_column(sa.String(255), nullable=False)
language = mapped_column(db.String(255), nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
updated_by = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def created_user_name(self):
@ -1287,21 +1273,19 @@ class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
class Pipeline(Base): # type: ignore[name-defined]
__tablename__ = "pipelines"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_pkey"),)
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_pkey"),)
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name = mapped_column(sa.String(255), nullable=False)
description = mapped_column(sa.Text, nullable=False, server_default=sa.text("''::character varying"))
name = mapped_column(db.String(255), nullable=False)
description = mapped_column(sa.Text, nullable=False, server_default=db.text("''::character varying"))
workflow_id = mapped_column(StringUUID, nullable=True)
is_public = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
is_published = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
is_public = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
is_published = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
created_by = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
def retrieve_dataset(self, session: Session):
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
@ -1310,16 +1294,16 @@ class Pipeline(Base): # type: ignore[name-defined]
class DocumentPipelineExecutionLog(Base):
__tablename__ = "document_pipeline_execution_logs"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="document_pipeline_execution_log_pkey"),
sa.Index("document_pipeline_execution_logs_document_id_idx", "document_id"),
db.PrimaryKeyConstraint("id", name="document_pipeline_execution_log_pkey"),
db.Index("document_pipeline_execution_logs_document_id_idx", "document_id"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
pipeline_id = mapped_column(StringUUID, nullable=False)
document_id = mapped_column(StringUUID, nullable=False)
datasource_type = mapped_column(sa.String(255), nullable=False)
datasource_type = mapped_column(db.String(255), nullable=False)
datasource_info = mapped_column(sa.Text, nullable=False)
datasource_node_id = mapped_column(sa.String(255), nullable=False)
datasource_node_id = mapped_column(db.String(255), nullable=False)
input_data = mapped_column(sa.JSON, nullable=False)
created_by = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@ -1327,14 +1311,12 @@ class DocumentPipelineExecutionLog(Base):
class PipelineRecommendedPlugin(Base):
__tablename__ = "pipeline_recommended_plugins"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),)
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),)
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
plugin_id = mapped_column(sa.Text, nullable=False)
provider_name = mapped_column(sa.Text, nullable=False)
position = mapped_column(sa.Integer, nullable=False, default=0)
active = mapped_column(sa.Boolean, nullable=False, default=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())

View File

@ -95,9 +95,7 @@ class App(Base):
created_by = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
@property
@ -316,9 +314,7 @@ class AppModelConfig(Base):
created_by = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
opening_statement = mapped_column(sa.Text)
suggested_questions = mapped_column(sa.Text)
suggested_questions_after_answer = mapped_column(sa.Text)
@ -549,9 +545,7 @@ class RecommendedApp(Base):
install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying"))
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def app(self) -> App | None:
@ -650,9 +644,7 @@ class Conversation(Base):
read_account_id = mapped_column(StringUUID)
dialogue_count: Mapped[int] = mapped_column(default=0)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all")
message_annotations = db.relationship(
@ -956,9 +948,7 @@ class Message(Base):
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
from_account_id: Mapped[str | None] = mapped_column(StringUUID)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
workflow_run_id: Mapped[str | None] = mapped_column(StringUUID)
app_mode: Mapped[str | None] = mapped_column(String(255), nullable=True)
@ -1306,9 +1296,7 @@ class MessageFeedback(Base):
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
from_account_id: Mapped[str | None] = mapped_column(StringUUID)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def from_account(self) -> Account | None:
@ -1390,9 +1378,7 @@ class MessageAnnotation(Base):
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
account_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def account(self):
@ -1457,9 +1443,7 @@ class AppAnnotationSetting(Base):
created_user_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_user_id = mapped_column(StringUUID, nullable=False)
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def collection_binding_detail(self):
@ -1487,9 +1471,7 @@ class OperationLog(Base):
content = mapped_column(sa.JSON)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
created_ip: Mapped[str] = mapped_column(String(255), nullable=False)
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class DefaultEndUserSessionID(StrEnum):
@ -1528,9 +1510,7 @@ class EndUser(Base, UserMixin):
session_id: Mapped[str] = mapped_column()
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class AppMCPServer(Base):
@ -1550,9 +1530,7 @@ class AppMCPServer(Base):
parameters = mapped_column(sa.Text, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@staticmethod
def generate_server_code(n: int) -> str:
@ -1598,9 +1576,7 @@ class Site(Base):
created_by = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
code = mapped_column(String(255))
@property
@ -1802,7 +1778,7 @@ class MessageAgentThought(Base):
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
total_price = mapped_column(sa.Numeric, nullable=True)
currency = mapped_column(String, nullable=True)
currency: Mapped[str | None] = mapped_column()
latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
created_by_role = mapped_column(String, nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
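
The `currency` pair above illustrates SQLAlchemy 2.0's annotation-driven columns: when `mapped_column()` receives no explicit type, the column type and nullability are derived from the `Mapped[...]` annotation, with an optional annotation producing a nullable column. A small sketch of the two equivalent spellings (the model below is illustrative, not part of the codebase):

from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class ExampleTyped(Base):  # illustrative model, not part of the codebase
    __tablename__ = "example_typed_columns"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Explicit form: column type and nullability are spelled out on the column itself.
    currency_explicit: Mapped[str | None] = mapped_column(String, nullable=True)
    # Annotation-driven form: String and nullable=True are both inferred from Mapped[str | None].
    currency: Mapped[str | None] = mapped_column()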

View File

@ -1,66 +1,62 @@
from datetime import datetime
import sqlalchemy as sa
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from .base import Base
from .engine import db
from .types import StringUUID
class DatasourceOauthParamConfig(Base): # type: ignore[name-defined]
__tablename__ = "datasource_oauth_params"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="datasource_oauth_config_pkey"),
sa.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"),
db.PrimaryKeyConstraint("id", name="datasource_oauth_config_pkey"),
db.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
plugin_id: Mapped[str] = mapped_column(db.String(255), nullable=False)
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
system_credentials: Mapped[dict] = mapped_column(JSONB, nullable=False)
class DatasourceProvider(Base):
__tablename__ = "datasource_providers"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="datasource_provider_pkey"),
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"),
sa.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"),
db.PrimaryKeyConstraint("id", name="datasource_provider_pkey"),
db.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"),
db.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
tenant_id = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
auth_type: Mapped[str] = mapped_column(sa.String(255), nullable=False)
name: Mapped[str] = mapped_column(db.String(255), nullable=False)
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
plugin_id: Mapped[str] = mapped_column(db.String(255), nullable=False)
auth_type: Mapped[str] = mapped_column(db.String(255), nullable=False)
encrypted_credentials: Mapped[dict] = mapped_column(JSONB, nullable=False)
avatar_url: Mapped[str] = mapped_column(sa.Text, nullable=True, default="default")
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
expires_at: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="-1")
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
class DatasourceOauthTenantParamConfig(Base):
__tablename__ = "datasource_oauth_tenant_params"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="datasource_oauth_tenant_config_pkey"),
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"),
db.PrimaryKeyConstraint("id", name="datasource_oauth_tenant_config_pkey"),
db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
tenant_id = mapped_column(StringUUID, nullable=False)
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
plugin_id: Mapped[str] = mapped_column(db.String(255), nullable=False)
client_params: Mapped[dict] = mapped_column(JSONB, nullable=False, default={})
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
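
The `created_at`/`updated_at` lines above alternate between two kinds of defaults: `default=datetime.now` is a Python-side callable evaluated in the application when a row is flushed without a value, whereas `server_default=func.current_timestamp()` becomes part of the table DDL so the database fills the column, including for rows inserted outside the ORM. A minimal sketch of the distinction (model and field names are illustrative):

from datetime import datetime

from sqlalchemy import DateTime, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class ExampleDefaults(Base):  # illustrative model, not part of the codebase
    __tablename__ = "example_defaults"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Python-side default: the callable runs in the application process when the INSERT is flushed.
    created_in_app: Mapped[datetime] = mapped_column(DateTime, default=datetime.now)
    # Database-side default: the DDL gains DEFAULT CURRENT_TIMESTAMP, so the server fills the value.
    created_in_db: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())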

View File

@ -72,9 +72,7 @@ class Provider(Base):
quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, default=0)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
def __repr__(self):
return (
@ -137,9 +135,7 @@ class ProviderModel(Base):
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@cached_property
def credential(self):
@ -174,9 +170,7 @@ class TenantDefaultModel(Base):
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class TenantPreferredModelProvider(Base):
@ -191,9 +185,7 @@ class TenantPreferredModelProvider(Base):
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class ProviderOrder(Base):
@ -220,9 +212,7 @@ class ProviderOrder(Base):
pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime)
refunded_at: Mapped[datetime | None] = mapped_column(DateTime)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class ProviderModelSetting(Base):
@ -244,9 +234,7 @@ class ProviderModelSetting(Base):
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
load_balancing_enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class LoadBalancingModelConfig(Base):
@ -271,9 +259,7 @@ class LoadBalancingModelConfig(Base):
credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class ProviderCredential(Base):
@ -293,9 +279,7 @@ class ProviderCredential(Base):
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class ProviderModelCredential(Base):
@ -323,6 +307,4 @@ class ProviderModelCredential(Base):
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())

View File

@ -140,9 +140,8 @@ class Workflow(Base):
updated_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
default=func.current_timestamp(),
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
default=naive_utc_now(),
server_onupdate=func.current_timestamp(),
)
_environment_variables: Mapped[str] = mapped_column(
"environment_variables", sa.Text, nullable=False, server_default="{}"
@ -151,7 +150,7 @@ class Workflow(Base):
"conversation_variables", sa.Text, nullable=False, server_default="{}"
)
_rag_pipeline_variables: Mapped[str] = mapped_column(
"rag_pipeline_variables", sa.Text, nullable=False, server_default="{}"
"rag_pipeline_variables", db.Text, nullable=False, server_default="{}"
)
VERSION_DRAFT = "draft"

View File

@ -50,7 +50,6 @@ from models.model import UploadFile
from models.provider_ids import ModelProviderID
from models.source import DataSourceOauthBinding
from models.workflow import Workflow
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
from services.entities.knowledge_entities.knowledge_entities import (
ChildChunkUpdateArgs,
KnowledgeConfig,
@ -80,6 +79,7 @@ from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
from tasks.disable_segments_from_index_task import disable_segments_from_index_task
from tasks.document_indexing_task import document_indexing_task
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
from tasks.enable_segments_to_index_task import enable_segments_to_index_task
@ -1694,7 +1694,7 @@ class DocumentService:
# trigger async task
if document_ids:
DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay()
document_indexing_task.delay(dataset.id, document_ids)
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)

View File

@ -1,83 +0,0 @@
import logging
from collections.abc import Callable, Sequence
from dataclasses import asdict
from functools import cached_property
from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from services.feature_service import FeatureService
from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task
logger = logging.getLogger(__name__)
class DocumentIndexingTaskProxy:
def __init__(self, tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
self._tenant_id = tenant_id
self._dataset_id = dataset_id
self._document_ids = document_ids
self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
@cached_property
def features(self):
return FeatureService.get_features(self._tenant_id)
def _send_to_direct_queue(self, task_func: Callable[[str, str, Sequence[str]], None]):
logger.info("send dataset %s to direct queue", self._dataset_id)
task_func.delay( # type: ignore
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
)
def _send_to_tenant_queue(self, task_func: Callable[[str, str, Sequence[str]], None]):
logger.info("send dataset %s to tenant queue", self._dataset_id)
if self._tenant_isolated_task_queue.get_task_key():
# Add to waiting queue using List operations (lpush)
self._tenant_isolated_task_queue.push_tasks(
[
asdict(
DocumentTask(
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
)
)
]
)
logger.info("push tasks: %s - %s", self._dataset_id, self._document_ids)
else:
# Set flag and execute task
self._tenant_isolated_task_queue.set_task_waiting_time()
task_func.delay( # type: ignore
tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids
)
logger.info("init tasks: %s - %s", self._dataset_id, self._document_ids)
def _send_to_default_tenant_queue(self):
self._send_to_tenant_queue(normal_document_indexing_task)
def _send_to_priority_tenant_queue(self):
self._send_to_tenant_queue(priority_document_indexing_task)
def _send_to_priority_direct_queue(self):
self._send_to_direct_queue(priority_document_indexing_task)
def _dispatch(self):
logger.info(
"dispatch args: %s - %s - %s",
self._tenant_id,
self.features.billing.enabled,
self.features.billing.subscription.plan,
)
# dispatch to different indexing queue with tenant isolation when billing enabled
if self.features.billing.enabled:
if self.features.billing.subscription.plan == CloudPlan.SANDBOX:
# dispatch to normal pipeline queue with tenant self sub queue for sandbox plan
self._send_to_default_tenant_queue()
else:
# dispatch to priority pipeline queue with tenant self sub queue for other plans
self._send_to_priority_tenant_queue()
else:
# dispatch to priority queue without tenant isolation for others, e.g.: self-hosted or enterprise
self._send_to_priority_direct_queue()
def delay(self):
self._dispatch()

View File

@ -1,106 +0,0 @@
import json
import logging
from collections.abc import Callable, Sequence
from functools import cached_property
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from services.feature_service import FeatureService
from services.file_service import FileService
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
logger = logging.getLogger(__name__)
class RagPipelineTaskProxy:
# Default uploaded file name for rag pipeline invoke entities
_RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME = "rag_pipeline_invoke_entities.json"
def __init__(
self, dataset_tenant_id: str, user_id: str, rag_pipeline_invoke_entities: Sequence[RagPipelineInvokeEntity]
):
self._dataset_tenant_id = dataset_tenant_id
self._user_id = user_id
self._rag_pipeline_invoke_entities = rag_pipeline_invoke_entities
self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(dataset_tenant_id, "pipeline")
@cached_property
def features(self):
return FeatureService.get_features(self._dataset_tenant_id)
def _upload_invoke_entities(self) -> str:
text = [item.model_dump() for item in self._rag_pipeline_invoke_entities]
# Convert list to proper JSON string
json_text = json.dumps(text)
upload_file = FileService(db.engine).upload_text(
json_text, self._RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME, self._user_id, self._dataset_tenant_id
)
return upload_file.id
def _send_to_direct_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]):
logger.info("send file %s to direct queue", upload_file_id)
task_func.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file_id,
tenant_id=self._dataset_tenant_id,
)
def _send_to_tenant_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]):
logger.info("send file %s to tenant queue", upload_file_id)
if self._tenant_isolated_task_queue.get_task_key():
# Add to waiting queue using List operations (lpush)
self._tenant_isolated_task_queue.push_tasks([upload_file_id])
logger.info("push tasks: %s", upload_file_id)
else:
# Set flag and execute task
self._tenant_isolated_task_queue.set_task_waiting_time()
task_func.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file_id,
tenant_id=self._dataset_tenant_id,
)
logger.info("init tasks: %s", upload_file_id)
def _send_to_default_tenant_queue(self, upload_file_id: str):
self._send_to_tenant_queue(upload_file_id, rag_pipeline_run_task)
def _send_to_priority_tenant_queue(self, upload_file_id: str):
self._send_to_tenant_queue(upload_file_id, priority_rag_pipeline_run_task)
def _send_to_priority_direct_queue(self, upload_file_id: str):
self._send_to_direct_queue(upload_file_id, priority_rag_pipeline_run_task)
def _dispatch(self):
upload_file_id = self._upload_invoke_entities()
if not upload_file_id:
raise ValueError("upload_file_id is empty")
logger.info(
"dispatch args: %s - %s - %s",
self._dataset_tenant_id,
self.features.billing.enabled,
self.features.billing.subscription.plan,
)
# dispatch to different pipeline queue with tenant isolation when billing enabled
if self.features.billing.enabled:
if self.features.billing.subscription.plan == CloudPlan.SANDBOX:
# dispatch to normal pipeline queue with tenant isolation for sandbox plan
self._send_to_default_tenant_queue(upload_file_id)
else:
# dispatch to priority pipeline queue with tenant isolation for other plans
self._send_to_priority_tenant_queue(upload_file_id)
else:
# dispatch to priority pipeline queue without tenant isolation for others, e.g.: self-hosted or enterprise
self._send_to_priority_direct_queue(upload_file_id)
def delay(self):
if not self._rag_pipeline_invoke_entities:
logger.warning(
"Received empty rag pipeline invoke entities, no tasks delivered: %s %s",
self._dataset_tenant_id,
self._user_id,
)
return
self._dispatch()

View File

@ -126,7 +126,7 @@ workflow:
type: mixed
value: '{{#rag.1752491761974.jina_use_sitemap#}}'
plugin_id: langgenius/jina_datasource
provider_name: jinareader
provider_name: jina
provider_type: website_crawl
selected: false
title: Jina Reader

View File

@ -126,7 +126,7 @@ workflow:
type: mixed
value: '{{#rag.1752491761974.jina_use_sitemap#}}'
plugin_id: langgenius/jina_datasource
provider_name: jinareader
provider_name: jina
provider_type: website_crawl
selected: false
title: Jina Reader

View File

@ -419,7 +419,7 @@ workflow:
type: mixed
value: '{{#rag.1752491761974.jina_use_sitemap#}}'
plugin_id: langgenius/jina_datasource
provider_name: jinareader
provider_name: jina
provider_type: website_crawl
selected: false
title: Jina Reader

View File

@ -48,6 +48,7 @@ def add_document_to_index_task(dataset_document_id: str):
db.session.query(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == False,
DocumentSegment.status == "completed",
)
.order_by(DocumentSegment.position.asc())

View File

@ -1,14 +1,11 @@
import logging
import time
from collections.abc import Callable, Sequence
import click
from celery import shared_task
from configs import dify_config
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
@ -25,24 +22,8 @@ def document_indexing_task(dataset_id: str, document_ids: list):
:param dataset_id:
:param document_ids:
.. warning:: TO BE DEPRECATED
This function will be deprecated and removed in a future version.
Use normal_document_indexing_task or priority_document_indexing_task instead.
Usage: document_indexing_task.delay(dataset_id, document_ids)
"""
logger.warning("document indexing legacy mode received: %s - %s", dataset_id, document_ids)
_document_indexing(dataset_id, document_ids)
def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
"""
Process document for tasks
:param dataset_id:
:param document_ids:
Usage: _document_indexing(dataset_id, document_ids)
"""
documents = []
start_at = time.perf_counter()
@ -106,63 +87,3 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
finally:
db.session.close()
def _document_indexing_with_tenant_queue(
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None]
):
try:
_document_indexing(dataset_id, document_ids)
except Exception:
logger.exception("Error processing document indexing %s for tenant %s: %s", dataset_id, tenant_id)
finally:
tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
# Check if there are waiting tasks in the queue
# Use rpop to get the next task from the queue (FIFO order)
next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
logger.info("document indexing tenant isolation queue next tasks: %s", next_tasks)
if next_tasks:
for next_task in next_tasks:
document_task = DocumentTask(**next_task)
# Process the next waiting task
# Keep the flag set to indicate a task is running
tenant_isolated_task_queue.set_task_waiting_time()
task_func.delay( # type: ignore
tenant_id=document_task.tenant_id,
dataset_id=document_task.dataset_id,
document_ids=document_task.document_ids,
)
else:
# No more waiting tasks, clear the flag
tenant_isolated_task_queue.delete_task_key()
@shared_task(queue="dataset")
def normal_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
"""
Async process document
:param tenant_id:
:param dataset_id:
:param document_ids:
Usage: normal_document_indexing_task.delay(tenant_id, dataset_id, document_ids)
"""
logger.info("normal document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids)
_document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, normal_document_indexing_task)
@shared_task(queue="priority_dataset")
def priority_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]):
"""
Priority async process document
:param tenant_id:
:param dataset_id:
:param document_ids:
Usage: priority_document_indexing_task.delay(tenant_id, dataset_id, document_ids)
"""
logger.info("priority document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids)
_document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, priority_document_indexing_task)

View File

@ -12,10 +12,8 @@ from celery import shared_task # type: ignore
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from core.repositories.factory import DifyCoreRepositoryFactory
from extensions.ext_database import db
from models import Account, Tenant
@ -24,8 +22,6 @@ from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService
logger = logging.getLogger(__name__)
@shared_task(queue="priority_pipeline")
def priority_rag_pipeline_run_task(
@ -73,27 +69,6 @@ def priority_rag_pipeline_run_task(
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
raise
finally:
tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "pipeline")
# Check if there are waiting tasks in the queue
# Use rpop to get the next task from the queue (FIFO order)
next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
logger.info("priority rag pipeline tenant isolation queue next files: %s", next_file_ids)
if next_file_ids:
for next_file_id in next_file_ids:
# Process the next waiting task
# Keep the flag set to indicate a task is running
tenant_isolated_task_queue.set_task_waiting_time()
priority_rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
if isinstance(next_file_id, bytes)
else next_file_id,
tenant_id=tenant_id,
)
else:
# No more waiting tasks, clear the flag
tenant_isolated_task_queue.delete_task_key()
file_service = FileService(db.engine)
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
db.session.close()

View File

@ -12,20 +12,17 @@ from celery import shared_task # type: ignore
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from core.repositories.factory import DifyCoreRepositoryFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import Account, Tenant
from models.dataset import Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService
logger = logging.getLogger(__name__)
@shared_task(queue="pipeline")
def rag_pipeline_run_task(
@ -73,27 +70,26 @@ def rag_pipeline_run_task(
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
raise
finally:
tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "pipeline")
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{tenant_id}"
tenant_pipeline_task_key = f"tenant_pipeline_task:{tenant_id}"
# Check if there are waiting tasks in the queue
# Use rpop to get the next task from the queue (FIFO order)
next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY)
logger.info("rag pipeline tenant isolation queue next files: %s", next_file_ids)
next_file_id = redis_client.rpop(tenant_self_pipeline_task_queue)
if next_file_ids:
for next_file_id in next_file_ids:
# Process the next waiting task
# Keep the flag set to indicate a task is running
tenant_isolated_task_queue.set_task_waiting_time()
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
if isinstance(next_file_id, bytes)
else next_file_id,
tenant_id=tenant_id,
)
if next_file_id:
# Process the next waiting task
# Keep the flag set to indicate a task is running
redis_client.setex(tenant_pipeline_task_key, 60 * 60, 1)
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
if isinstance(next_file_id, bytes)
else next_file_id,
tenant_id=tenant_id,
)
else:
# No more waiting tasks, clear the flag
tenant_isolated_task_queue.delete_task_key()
redis_client.delete(tenant_pipeline_task_key)
file_service = FileService(db.engine)
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
db.session.close()
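
The `finally` block above implements a per-tenant chaining pattern: a Redis list holds waiting work items while a TTL-guarded flag key marks that one task is in flight; when a task finishes, it either pops the next item and re-dispatches itself or clears the flag so the next producer can start immediately. A schematic sketch of that pattern, with illustrative key and function names (`run_async` stands in for a Celery `.delay(...)` call and `redis_client` for the application's configured client):

from collections.abc import Callable

import redis

redis_client = redis.Redis()  # stand-in for the application's configured Redis client

FLAG_TTL_SECONDS = 60 * 60


def enqueue(tenant_id: str, item: str, run_async: Callable[[str], None]) -> None:
    """Producer side: start the task at once if none is in flight, otherwise park the item."""
    flag_key = f"tenant_task:{tenant_id}"
    queue_key = f"tenant_task_queue:{tenant_id}"
    if redis_client.get(flag_key):
        redis_client.lpush(queue_key, item)  # wait in line; lpush + rpop gives FIFO order
    else:
        redis_client.setex(flag_key, FLAG_TTL_SECONDS, 1)  # mark a task as running
        run_async(item)  # e.g. some_task.delay(item)


def on_task_finished(tenant_id: str, run_async: Callable[[str], None]) -> None:
    """Worker side (the `finally` block): chain the next waiting item or clear the flag."""
    flag_key = f"tenant_task:{tenant_id}"
    queue_key = f"tenant_task_queue:{tenant_id}"
    next_item = redis_client.rpop(queue_key)
    if next_item:
        redis_client.setex(flag_key, FLAG_TTL_SECONDS, 1)  # keep the flag armed for the next run
        run_async(next_item.decode() if isinstance(next_item, bytes) else next_item)
    else:
        redis_client.delete(flag_key)  # nothing waiting; the next enqueue may start immediately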

View File

@ -25,12 +25,7 @@ import pytest
from sqlalchemy import Engine, delete, select
from sqlalchemy.orm import Session
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import (
PauseStatePersistenceLayer,
WorkflowResumptionContext,
)
from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.enums import WorkflowExecutionStatus
@ -44,7 +39,7 @@ from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from models import Account
from models import WorkflowPause as WorkflowPauseModel
from models.model import AppMode, UploadFile
from models.model import UploadFile
from models.workflow import Workflow, WorkflowRun
from services.file_service import FileService
from services.workflow_run_service import WorkflowRunService
@ -231,39 +226,11 @@ class TestPauseStatePersistenceLayerTestContainers:
return ReadOnlyGraphRuntimeStateWrapper(graph_runtime_state)
def _create_generate_entity(
self,
workflow_execution_id: str | None = None,
user_id: str | None = None,
workflow_id: str | None = None,
) -> WorkflowAppGenerateEntity:
execution_id = workflow_execution_id or getattr(self, "test_workflow_run_id", str(uuid.uuid4()))
wf_id = workflow_id or getattr(self, "test_workflow_id", str(uuid.uuid4()))
tenant_id = getattr(self, "test_tenant_id", "tenant-123")
app_id = getattr(self, "test_app_id", "app-123")
app_config = WorkflowUIBasedAppConfig(
tenant_id=str(tenant_id),
app_id=str(app_id),
app_mode=AppMode.WORKFLOW,
workflow_id=str(wf_id),
)
return WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
inputs={},
files=[],
user_id=user_id or getattr(self, "test_user_id", str(uuid.uuid4())),
stream=False,
invoke_from=InvokeFrom.DEBUGGER,
workflow_execution_id=execution_id,
)
def _create_pause_state_persistence_layer(
self,
workflow_run: WorkflowRun | None = None,
workflow: Workflow | None = None,
state_owner_user_id: str | None = None,
generate_entity: WorkflowAppGenerateEntity | None = None,
) -> PauseStatePersistenceLayer:
"""Create PauseStatePersistenceLayer with real dependencies."""
owner_id = state_owner_user_id
@ -277,23 +244,10 @@ class TestPauseStatePersistenceLayerTestContainers:
assert owner_id is not None
owner_id = str(owner_id)
workflow_execution_id = (
workflow_run.id if workflow_run is not None else getattr(self, "test_workflow_run_id", None)
)
assert workflow_execution_id is not None
workflow_id = workflow.id if workflow is not None else getattr(self, "test_workflow_id", None)
assert workflow_id is not None
entity_user_id = getattr(self, "test_user_id", owner_id)
entity = generate_entity or self._create_generate_entity(
workflow_execution_id=str(workflow_execution_id),
user_id=entity_user_id,
workflow_id=str(workflow_id),
)
return PauseStatePersistenceLayer(
session_factory=self.session.get_bind(),
state_owner_user_id=owner_id,
generate_entity=entity,
)
def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers):
@ -343,15 +297,10 @@ class TestPauseStatePersistenceLayerTestContainers:
assert pause_model.resumed_at is None
storage_content = storage.load(pause_model.state_object_key).decode()
resumption_context = WorkflowResumptionContext.loads(storage_content)
assert resumption_context.version == "1"
assert resumption_context.serialized_graph_runtime_state == graph_runtime_state.dumps()
expected_state = json.loads(graph_runtime_state.dumps())
actual_state = json.loads(resumption_context.serialized_graph_runtime_state)
actual_state = json.loads(storage_content)
assert actual_state == expected_state
persisted_entity = resumption_context.get_generate_entity()
assert isinstance(persisted_entity, WorkflowAppGenerateEntity)
assert persisted_entity.workflow_execution_id == self.test_workflow_run_id
def test_state_persistence_and_retrieval(self, db_session_with_containers):
"""Test that pause state can be persisted and retrieved correctly."""
@ -392,15 +341,13 @@ class TestPauseStatePersistenceLayerTestContainers:
assert pause_entity.workflow_execution_id == self.test_workflow_run_id
state_bytes = pause_entity.get_state()
resumption_context = WorkflowResumptionContext.loads(state_bytes.decode())
retrieved_state = json.loads(resumption_context.serialized_graph_runtime_state)
retrieved_state = json.loads(state_bytes.decode())
expected_state = json.loads(graph_runtime_state.dumps())
assert retrieved_state == expected_state
assert retrieved_state["outputs"] == complex_outputs
assert retrieved_state["total_tokens"] == 250
assert retrieved_state["node_run_steps"] == 10
assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id
def test_database_transaction_handling(self, db_session_with_containers):
"""Test that database transactions are handled correctly."""
@ -463,9 +410,7 @@ class TestPauseStatePersistenceLayerTestContainers:
# Verify content in storage
storage_content = storage.load(pause_model.state_object_key).decode()
resumption_context = WorkflowResumptionContext.loads(storage_content)
assert resumption_context.serialized_graph_runtime_state == graph_runtime_state.dumps()
assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id
assert storage_content == graph_runtime_state.dumps()
def test_workflow_with_different_creators(self, db_session_with_containers):
"""Test pause state with workflows created by different users."""
@ -529,8 +474,6 @@ class TestPauseStatePersistenceLayerTestContainers:
# Verify the state owner is the workflow creator
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(different_workflow_run.id)
assert pause_entity is not None
resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
assert resumption_context.get_generate_entity().workflow_execution_id == different_workflow_run.id
def test_layer_ignores_non_pause_events(self, db_session_with_containers):
"""Test that layer ignores non-pause events."""

View File

@ -1,595 +0,0 @@
"""
Integration tests for TenantIsolatedTaskQueue using testcontainers.
These tests verify the Redis-based task queue functionality with real Redis instances,
testing tenant isolation, task serialization, and queue operations in a realistic environment.
Includes compatibility tests for migrating from legacy string-only queues.
All tests use generic naming to avoid coupling to specific business implementations.
"""
import time
from dataclasses import dataclass
from typing import Any
from uuid import uuid4
import pytest
from faker import Faker
from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
@dataclass
class TestTask:
"""Test task data structure for testing complex object serialization."""
task_id: str
tenant_id: str
data: dict[str, Any]
metadata: dict[str, Any]
class TestTenantIsolatedTaskQueueIntegration:
"""Integration tests for TenantIsolatedTaskQueue using testcontainers."""
@pytest.fixture
def fake(self):
"""Faker instance for generating test data."""
return Faker()
@pytest.fixture
def test_tenant_and_account(self, db_session_with_containers, fake):
"""Create test tenant and account for testing."""
# Create account
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant
tenant = Tenant(
name=fake.company(),
status="normal",
)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
return tenant, account
@pytest.fixture
def test_queue(self, test_tenant_and_account):
"""Create a generic test queue for testing."""
tenant, _ = test_tenant_and_account
return TenantIsolatedTaskQueue(tenant.id, "test_queue")
@pytest.fixture
def secondary_queue(self, test_tenant_and_account):
"""Create a secondary test queue for testing isolation."""
tenant, _ = test_tenant_and_account
return TenantIsolatedTaskQueue(tenant.id, "secondary_queue")
def test_queue_initialization(self, test_tenant_and_account):
"""Test queue initialization with correct key generation."""
tenant, _ = test_tenant_and_account
queue = TenantIsolatedTaskQueue(tenant.id, "test-key")
assert queue._tenant_id == tenant.id
assert queue._unique_key == "test-key"
assert queue._queue == f"tenant_self_test-key_task_queue:{tenant.id}"
assert queue._task_key == f"tenant_test-key_task:{tenant.id}"
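# A minimal sketch of the key scheme asserted above; make_queue_keys is an
# illustrative helper (an assumption for this example), not an API exposed by
# TenantIsolatedTaskQueue.
def make_queue_keys(tenant_id: str, unique_key: str) -> tuple[str, str]:
    queue_key = f"tenant_self_{unique_key}_task_queue:{tenant_id}"
    task_key = f"tenant_{unique_key}_task:{tenant_id}"
    return queue_key, task_key
# Example: make_queue_keys("t-1", "test-key")
# -> ("tenant_self_test-key_task_queue:t-1", "tenant_test-key_task:t-1")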
def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers, fake):
"""Test that different tenants have isolated queues."""
tenant1, _ = test_tenant_and_account
# Create second tenant
tenant2 = Tenant(
name=fake.company(),
status="normal",
)
db_session_with_containers.add(tenant2)
db_session_with_containers.commit()
queue1 = TenantIsolatedTaskQueue(tenant1.id, "same-key")
queue2 = TenantIsolatedTaskQueue(tenant2.id, "same-key")
assert queue1._queue != queue2._queue
assert queue1._task_key != queue2._task_key
assert queue1._queue == f"tenant_self_same-key_task_queue:{tenant1.id}"
assert queue2._queue == f"tenant_self_same-key_task_queue:{tenant2.id}"
def test_key_isolation(self, test_tenant_and_account):
"""Test that different keys have isolated queues."""
tenant, _ = test_tenant_and_account
queue1 = TenantIsolatedTaskQueue(tenant.id, "key1")
queue2 = TenantIsolatedTaskQueue(tenant.id, "key2")
assert queue1._queue != queue2._queue
assert queue1._task_key != queue2._task_key
assert queue1._queue == f"tenant_self_key1_task_queue:{tenant.id}"
assert queue2._queue == f"tenant_self_key2_task_queue:{tenant.id}"
def test_task_key_operations(self, test_queue):
"""Test task key operations (get, set, delete)."""
# Initially no task key should exist
assert test_queue.get_task_key() is None
# Set task waiting time with default TTL
test_queue.set_task_waiting_time()
task_key = test_queue.get_task_key()
# Redis may return bytes or str depending on client configuration, so accept either form
assert task_key in (b"1", "1")
# Set task waiting time with custom TTL
custom_ttl = 30
test_queue.set_task_waiting_time(custom_ttl)
task_key = test_queue.get_task_key()
assert task_key in (b"1", "1")
# Delete task key
test_queue.delete_task_key()
assert test_queue.get_task_key() is None
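# One plausible way to back the waiting-time flag exercised above with plain
# redis-py calls: SETEX stores the sentinel "1" and lets Redis expire it. The
# key name and the 30-second default here are illustrative assumptions, not
# the values used by TenantIsolatedTaskQueue.
def set_waiting_flag(client, task_key: str, ttl_seconds: int = 30) -> None:
    client.setex(task_key, ttl_seconds, 1)

def clear_waiting_flag(client, task_key: str) -> None:
    client.delete(task_key)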
def test_push_and_pull_string_tasks(self, test_queue):
"""Test pushing and pulling string tasks."""
tasks = ["task1", "task2", "task3"]
# Push tasks
test_queue.push_tasks(tasks)
# Pull tasks (FIFO order)
pulled_tasks = test_queue.pull_tasks(3)
# Should get tasks in FIFO order (lpush + rpop = FIFO)
assert pulled_tasks == ["task1", "task2", "task3"]
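# Minimal sketch of the LPUSH + RPOP FIFO behaviour the comment above relies
# on, using a bare redis-py client; "demo_fifo_queue" is an illustrative key.
def fifo_roundtrip(client) -> list[str]:
    key = "demo_fifo_queue"
    client.delete(key)
    client.lpush(key, "task1", "task2", "task3")  # new items enter at the head
    popped = []
    while (item := client.rpop(key)) is not None:  # oldest item leaves from the tail
        popped.append(item.decode() if isinstance(item, bytes) else item)
    return popped  # ["task1", "task2", "task3"]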
def test_push_and_pull_multiple_tasks(self, test_queue):
"""Test pushing and pulling multiple tasks at once."""
tasks = ["task1", "task2", "task3", "task4", "task5"]
# Push tasks
test_queue.push_tasks(tasks)
# Pull multiple tasks
pulled_tasks = test_queue.pull_tasks(3)
assert len(pulled_tasks) == 3
assert pulled_tasks == ["task1", "task2", "task3"]
# Pull remaining tasks
remaining_tasks = test_queue.pull_tasks(5)
assert len(remaining_tasks) == 2
assert remaining_tasks == ["task4", "task5"]
def test_push_and_pull_complex_objects(self, test_queue, fake):
"""Test pushing and pulling complex object tasks."""
# Create complex task objects as dictionaries (not dataclass instances)
tasks = [
{
"task_id": str(uuid4()),
"tenant_id": test_queue._tenant_id,
"data": {
"file_id": str(uuid4()),
"content": fake.text(),
"metadata": {"size": fake.random_int(1000, 10000)},
},
"metadata": {"created_at": fake.iso8601(), "tags": fake.words(3)},
},
{
"task_id": str(uuid4()),
"tenant_id": test_queue._tenant_id,
"data": {
"file_id": str(uuid4()),
"content": "测试中文内容",
"metadata": {"size": fake.random_int(1000, 10000)},
},
"metadata": {"created_at": fake.iso8601(), "tags": ["中文", "测试", "emoji🚀"]},
},
]
# Push complex tasks
test_queue.push_tasks(tasks)
# Pull tasks
pulled_tasks = test_queue.pull_tasks(2)
assert len(pulled_tasks) == 2
# Verify deserialized tasks match original (FIFO order)
for i, pulled_task in enumerate(pulled_tasks):
original_task = tasks[i] # FIFO order
assert isinstance(pulled_task, dict)
assert pulled_task["task_id"] == original_task["task_id"]
assert pulled_task["tenant_id"] == original_task["tenant_id"]
assert pulled_task["data"] == original_task["data"]
assert pulled_task["metadata"] == original_task["metadata"]
def test_mixed_task_types(self, test_queue, fake):
"""Test pushing and pulling mixed string and object tasks."""
string_task = "simple_string_task"
object_task = {
"task_id": str(uuid4()),
"dataset_id": str(uuid4()),
"document_ids": [str(uuid4()) for _ in range(3)],
}
tasks = [string_task, object_task, "another_string"]
# Push mixed tasks
test_queue.push_tasks(tasks)
# Pull all tasks
pulled_tasks = test_queue.pull_tasks(3)
assert len(pulled_tasks) == 3
# Verify types and content
assert pulled_tasks[0] == string_task
assert isinstance(pulled_tasks[1], dict)
assert pulled_tasks[1] == object_task
assert pulled_tasks[2] == "another_string"
def test_empty_queue_operations(self, test_queue):
"""Test operations on empty queue."""
# Pull from empty queue
tasks = test_queue.pull_tasks(5)
assert tasks == []
# Pull zero or negative count
assert test_queue.pull_tasks(0) == []
assert test_queue.pull_tasks(-1) == []
def test_task_ttl_expiration(self, test_queue):
"""Test task key TTL expiration."""
# Set task with short TTL
short_ttl = 2
test_queue.set_task_waiting_time(short_ttl)
# Verify task key exists
assert test_queue.get_task_key() in (b"1", "1")
# Wait for TTL to expire
time.sleep(short_ttl + 1)
# Verify task key has expired
assert test_queue.get_task_key() is None
def test_large_task_batch(self, test_queue, fake):
"""Test handling large batches of tasks."""
# Create large batch of tasks
large_batch = []
for i in range(100):
task = {
"task_id": str(uuid4()),
"index": i,
"data": fake.text(max_nb_chars=100),
"metadata": {"batch_id": str(uuid4())},
}
large_batch.append(task)
# Push large batch
test_queue.push_tasks(large_batch)
# Pull all tasks
pulled_tasks = test_queue.pull_tasks(100)
assert len(pulled_tasks) == 100
# Verify all tasks were retrieved correctly (FIFO order)
for i, task in enumerate(pulled_tasks):
assert isinstance(task, dict)
assert task["index"] == i # FIFO order
def test_queue_operations_isolation(self, test_tenant_and_account, fake):
"""Test concurrent operations on different queues."""
tenant, _ = test_tenant_and_account
# Create multiple queues for the same tenant
queue1 = TenantIsolatedTaskQueue(tenant.id, "queue1")
queue2 = TenantIsolatedTaskQueue(tenant.id, "queue2")
# Push tasks to different queues
queue1.push_tasks(["task1_queue1", "task2_queue1"])
queue2.push_tasks(["task1_queue2", "task2_queue2"])
# Verify queues are isolated
tasks1 = queue1.pull_tasks(2)
tasks2 = queue2.pull_tasks(2)
assert tasks1 == ["task1_queue1", "task2_queue1"]
assert tasks2 == ["task1_queue2", "task2_queue2"]
assert tasks1 != tasks2
def test_task_wrapper_serialization_roundtrip(self, test_queue, fake):
"""Test TaskWrapper serialization and deserialization roundtrip."""
# Create complex nested data
complex_data = {
"id": str(uuid4()),
"nested": {"deep": {"value": "test", "numbers": [1, 2, 3, 4, 5], "unicode": "测试中文", "emoji": "🚀"}},
"metadata": {"created_at": fake.iso8601(), "tags": ["tag1", "tag2", "tag3"]},
}
# Create wrapper and serialize
wrapper = TaskWrapper(data=complex_data)
serialized = wrapper.serialize()
# Verify serialization
assert isinstance(serialized, str)
assert "测试中文" in serialized
assert "🚀" in serialized
# Deserialize and verify
deserialized_wrapper = TaskWrapper.deserialize(serialized)
assert deserialized_wrapper.data == complex_data
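# Illustrative approximation of the envelope contract exercised above and in
# test_error_handling_invalid_json below: JSON payloads round-trip through a
# small wrapper, and anything that is not valid wrapper JSON falls back to the
# raw string. The {"data": ...} field layout is assumed for illustration; this
# is not the real TaskWrapper implementation.
import json

def _wrap_sketch(data) -> str:
    return json.dumps({"data": data}, ensure_ascii=False)

def _unwrap_sketch(raw: str):
    try:
        payload = json.loads(raw)
        if isinstance(payload, dict) and "data" in payload:
            return payload["data"]
    except ValueError:
        pass
    return raw  # malformed or legacy entries are returned unchanged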
def test_error_handling_invalid_json(self, test_queue):
"""Test error handling for invalid JSON in wrapped tasks."""
# Manually create invalid JSON task (not a valid TaskWrapper JSON)
invalid_json_task = "invalid json data"
# Push invalid task directly to Redis
redis_client.lpush(test_queue._queue, invalid_json_task)
# Pull task - should fall back to string since it's not valid JSON
task = test_queue.pull_tasks(1)
assert task[0] == invalid_json_task
def test_real_world_batch_processing_scenario(self, test_queue, fake):
"""Test realistic batch processing scenario."""
# Simulate batch processing tasks
batch_tasks = []
for i in range(3):
task = {
"file_id": str(uuid4()),
"tenant_id": test_queue._tenant_id,
"user_id": str(uuid4()),
"processing_config": {
"model": fake.random_element(["model_a", "model_b", "model_c"]),
"temperature": fake.random.uniform(0.1, 1.0),
"max_tokens": fake.random_int(1000, 4000),
},
"metadata": {
"source": fake.random_element(["upload", "api", "webhook"]),
"priority": fake.random_element(["low", "normal", "high"]),
},
}
batch_tasks.append(task)
# Push tasks
test_queue.push_tasks(batch_tasks)
# Process tasks in batches
batch_size = 2
processed_tasks = []
while True:
batch = test_queue.pull_tasks(batch_size)
if not batch:
break
processed_tasks.extend(batch)
# Verify all tasks were processed
assert len(processed_tasks) == 3
# Verify task structure
for task in processed_tasks:
assert isinstance(task, dict)
assert "file_id" in task
assert "tenant_id" in task
assert "processing_config" in task
assert "metadata" in task
assert task["tenant_id"] == test_queue._tenant_id
class TestTenantIsolatedTaskQueueCompatibility:
"""Compatibility tests for migrating from legacy string-only queues."""
@pytest.fixture
def fake(self):
"""Faker instance for generating test data."""
return Faker()
@pytest.fixture
def test_tenant_and_account(self, db_session_with_containers, fake):
"""Create test tenant and account for testing."""
# Create account
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant
tenant = Tenant(
name=fake.company(),
status="normal",
)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
return tenant, account
def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake):
"""
Test compatibility with legacy queues containing only string data.
This simulates the scenario where Redis queues already contain string data
from the old architecture, and we need to ensure the new code can read them.
"""
tenant, _ = test_tenant_and_account
queue = TenantIsolatedTaskQueue(tenant.id, "legacy_queue")
# Simulate legacy string data in Redis queue (using old format)
legacy_strings = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"]
# Manually push legacy strings directly to Redis (simulating old system)
for legacy_string in legacy_strings:
redis_client.lpush(queue._queue, legacy_string)
# Verify new code can read legacy string data
pulled_tasks = queue.pull_tasks(5)
assert len(pulled_tasks) == 5
# Verify all tasks are strings (not wrapped)
for task in pulled_tasks:
assert isinstance(task, str)
assert task.startswith("legacy_task_")
# Verify order (FIFO from Redis list)
expected_order = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"]
assert pulled_tasks == expected_order
def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake):
"""
Test complete migration scenario from legacy to new system.
This simulates the real-world scenario where:
1. Legacy system has string data in Redis
2. New system starts processing the same queue
3. Both legacy and new tasks coexist during migration
4. New system can handle both formats seamlessly
"""
tenant, _ = test_tenant_and_account
queue = TenantIsolatedTaskQueue(tenant.id, "migration_queue")
# Phase 1: Legacy system has data
legacy_tasks = [f"legacy_resource_{i}" for i in range(1, 6)]
redis_client.lpush(queue._queue, *legacy_tasks)
# Phase 2: New system starts processing legacy data
processed_legacy = []
while True:
tasks = queue.pull_tasks(1)
if not tasks:
break
processed_legacy.extend(tasks)
# Verify legacy data was processed correctly
assert len(processed_legacy) == 5
for task in processed_legacy:
assert isinstance(task, str)
assert task.startswith("legacy_resource_")
# Phase 3: New system adds new tasks (mixed types)
new_string_tasks = ["new_resource_1", "new_resource_2"]
new_object_tasks = [
{
"resource_id": str(uuid4()),
"tenant_id": tenant.id,
"processing_type": "new_system",
"metadata": {"version": "2.0", "features": ["ai", "ml"]},
},
{
"resource_id": str(uuid4()),
"tenant_id": tenant.id,
"processing_type": "new_system",
"metadata": {"version": "2.0", "features": ["ai", "ml"]},
},
]
# Push new tasks using new system
queue.push_tasks(new_string_tasks)
queue.push_tasks(new_object_tasks)
# Phase 4: Process all new tasks
processed_new = []
while True:
tasks = queue.pull_tasks(1)
if not tasks:
break
processed_new.extend(tasks)
# Verify new tasks were processed correctly
assert len(processed_new) == 4
string_tasks = [task for task in processed_new if isinstance(task, str)]
object_tasks = [task for task in processed_new if isinstance(task, dict)]
assert len(string_tasks) == 2
assert len(object_tasks) == 2
# Verify string tasks
for task in string_tasks:
assert task.startswith("new_resource_")
# Verify object tasks
for task in object_tasks:
assert isinstance(task, dict)
assert "resource_id" in task
assert "tenant_id" in task
assert task["tenant_id"] == tenant.id
assert task["processing_type"] == "new_system"
def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake):
"""
Test error recovery when legacy queue contains malformed data.
This ensures the new system can gracefully handle corrupted or
malformed legacy data without crashing.
"""
tenant, _ = test_tenant_and_account
queue = TenantIsolatedTaskQueue(tenant.id, "error_recovery_queue")
# Create mix of valid and malformed legacy data
mixed_legacy_data = [
"valid_legacy_task_1",
"valid_legacy_task_2",
"malformed_data_string", # This should be treated as string
"valid_legacy_task_3",
"invalid_json_not_taskwrapper_format", # This should fall back to string (not valid TaskWrapper JSON)
"valid_legacy_task_4",
]
# Manually push mixed data directly to Redis
redis_client.lpush(queue._queue, *mixed_legacy_data)
# Process all tasks
processed_tasks = []
while True:
tasks = queue.pull_tasks(1)
if not tasks:
break
processed_tasks.extend(tasks)
# Verify all tasks were processed (no crashes)
assert len(processed_tasks) == 6
# Verify all tasks are strings (malformed data falls back to string)
for task in processed_tasks:
assert isinstance(task, str)
# Verify valid tasks are preserved
valid_tasks = [task for task in processed_tasks if task.startswith("valid_legacy_task_")]
assert len(valid_tasks) == 4
# Verify malformed data is handled gracefully
malformed_tasks = [task for task in processed_tasks if not task.startswith("valid_legacy_task_")]
assert len(malformed_tasks) == 2
assert "malformed_data_string" in malformed_tasks
assert "invalid_json_not_taskwrapper_format" in malformed_tasks

View File

@ -1,311 +0,0 @@
"""
Integration tests for Redis broadcast channel implementation using TestContainers.
This test suite covers real Redis interactions including:
- Multiple producer/consumer scenarios
- Network failure scenarios
- Performance under load
- Real-world usage patterns
"""
import threading
import time
import uuid
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
import redis
from testcontainers.redis import RedisContainer
from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic
from libs.broadcast_channel.exc import SubscriptionClosedError
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
class TestRedisBroadcastChannelIntegration:
"""Integration tests for Redis broadcast channel with real Redis instance."""
@pytest.fixture(scope="class")
def redis_container(self) -> Iterator[RedisContainer]:
"""Create a Redis container for integration testing."""
with RedisContainer(image="redis:6-alpine") as container:
yield container
@pytest.fixture(scope="class")
def redis_client(self, redis_container: RedisContainer) -> redis.Redis:
"""Create a Redis client connected to the test container."""
host = redis_container.get_container_host_ip()
port = redis_container.get_exposed_port(6379)
return redis.Redis(host=host, port=port, decode_responses=False)
@pytest.fixture
def broadcast_channel(self, redis_client: redis.Redis) -> BroadcastChannel:
"""Create a BroadcastChannel instance with real Redis client."""
return RedisBroadcastChannel(redis_client)
@classmethod
def _get_test_topic_name(cls):
return f"test_topic_{uuid.uuid4()}"
# ==================== Basic Functionality Tests ====================
def test_close_an_active_subscription_should_stop_iteration(self, broadcast_channel):
topic_name = self._get_test_topic_name()
topic = broadcast_channel.topic(topic_name)
subscription = topic.subscribe()
consuming_event = threading.Event()
def consume():
msgs = []
consuming_event.set()
for msg in subscription:
msgs.append(msg)
return msgs
with ThreadPoolExecutor(max_workers=1) as executor:
consumer_future = executor.submit(consume)
consuming_event.wait()
subscription.close()
msgs = consumer_future.result(timeout=1)
assert msgs == []
def test_end_to_end_messaging(self, broadcast_channel: BroadcastChannel):
"""Test complete end-to-end messaging flow."""
topic_name = "test-topic"
message = b"hello world"
# Create producer and subscriber
topic = broadcast_channel.topic(topic_name)
producer = topic.as_producer()
subscription = topic.subscribe()
# Publish and receive message
def producer_thread():
time.sleep(0.1) # Small delay to ensure subscriber is ready
producer.publish(message)
time.sleep(0.1)
subscription.close()
def consumer_thread() -> list[bytes]:
received_messages = []
for msg in subscription:
received_messages.append(msg)
return received_messages
# Run producer and consumer
with ThreadPoolExecutor(max_workers=2) as executor:
producer_future = executor.submit(producer_thread)
consumer_future = executor.submit(consumer_thread)
# Wait for completion
producer_future.result(timeout=5.0)
received_messages = consumer_future.result(timeout=5.0)
assert len(received_messages) == 1
assert received_messages[0] == message
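# The channel abstraction used above presumably sits on top of Redis pub/sub;
# this is a minimal sketch of the underlying redis-py primitives. "demo-topic"
# is an illustrative channel name, not the BroadcastChannel wire format.
def raw_pubsub_roundtrip(client: redis.Redis) -> bytes | None:
    pubsub = client.pubsub()
    pubsub.subscribe("demo-topic")
    try:
        client.publish("demo-topic", b"hello world")
        # Poll until the published payload arrives or roughly one second passes.
        for _ in range(10):
            message = pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
            if message and message["type"] == "message":
                return message["data"]
        return None
    finally:
        pubsub.close()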
def test_multiple_subscribers_same_topic(self, broadcast_channel: BroadcastChannel):
"""Test message broadcasting to multiple subscribers."""
topic_name = "broadcast-topic"
message = b"broadcast message"
subscriber_count = 5
# Create producer and multiple subscribers
topic = broadcast_channel.topic(topic_name)
producer = topic.as_producer()
subscriptions = [topic.subscribe() for _ in range(subscriber_count)]
def producer_thread():
time.sleep(0.2) # Allow all subscribers to connect
producer.publish(message)
time.sleep(0.2)
for sub in subscriptions:
sub.close()
def consumer_thread(subscription: Subscription) -> list[bytes]:
received_msgs = []
while True:
try:
msg = subscription.receive(0.1)
except SubscriptionClosedError:
break
if msg is None:
continue
received_msgs.append(msg)
if len(received_msgs) >= 1:
break
return received_msgs
# Run producer and consumers
with ThreadPoolExecutor(max_workers=subscriber_count + 1) as executor:
producer_future = executor.submit(producer_thread)
consumer_futures = [executor.submit(consumer_thread, subscription) for subscription in subscriptions]
# Wait for completion
producer_future.result(timeout=10.0)
msgs_by_consumers = []
for future in as_completed(consumer_futures, timeout=10.0):
msgs_by_consumers.append(future.result())
# Close all subscriptions
for subscription in subscriptions:
subscription.close()
# Verify all subscribers received the message
for msgs in msgs_by_consumers:
assert len(msgs) == 1
assert msgs[0] == message
def test_topic_isolation(self, broadcast_channel: BroadcastChannel):
"""Test that different topics are isolated from each other."""
topic1_name = "topic1"
topic2_name = "topic2"
message1 = b"message for topic1"
message2 = b"message for topic2"
# Create producers and subscribers for different topics
topic1 = broadcast_channel.topic(topic1_name)
topic2 = broadcast_channel.topic(topic2_name)
def producer_thread():
time.sleep(0.1)
topic1.publish(message1)
topic2.publish(message2)
def consumer_by_thread(topic: Topic) -> list[bytes]:
subscription = topic.subscribe()
received = []
with subscription:
for msg in subscription:
received.append(msg)
if len(received) >= 1:
break
return received
# Run all threads
with ThreadPoolExecutor(max_workers=3) as executor:
producer_future = executor.submit(producer_thread)
consumer1_future = executor.submit(consumer_by_thread, topic1)
consumer2_future = executor.submit(consumer_by_thread, topic2)
# Wait for completion
producer_future.result(timeout=5.0)
received_by_topic1 = consumer1_future.result(timeout=5.0)
received_by_topic2 = consumer2_future.result(timeout=5.0)
# Verify topic isolation
assert len(received_by_topic1) == 1
assert len(received_by_topic2) == 1
assert received_by_topic1[0] == message1
assert received_by_topic2[0] == message2
# ==================== Performance Tests ====================
def test_concurrent_producers(self, broadcast_channel: BroadcastChannel):
"""Test multiple producers publishing to the same topic."""
topic_name = "concurrent-producers-topic"
producer_count = 5
messages_per_producer = 5
topic = broadcast_channel.topic(topic_name)
subscription = topic.subscribe()
expected_total = producer_count * messages_per_producer
consumer_ready = threading.Event()
def producer_thread(producer_idx: int) -> set[bytes]:
producer = topic.as_producer()
produced = set()
for i in range(messages_per_producer):
message = f"producer_{producer_idx}_msg_{i}".encode()
produced.add(message)
producer.publish(message)
time.sleep(0.001) # Small delay to avoid overwhelming the subscriber
return produced
def consumer_thread() -> set[bytes]:
received_msgs: set[bytes] = set()
with subscription:
consumer_ready.set()
while True:
try:
msg = subscription.receive(timeout=0.1)
except SubscriptionClosedError:
break
if msg is None:
if len(received_msgs) >= expected_total:
break
else:
continue
received_msgs.add(msg)
return received_msgs
# Run producers and consumer
with ThreadPoolExecutor(max_workers=producer_count + 1) as executor:
consumer_future = executor.submit(consumer_thread)
consumer_ready.wait()
producer_futures = [executor.submit(producer_thread, i) for i in range(producer_count)]
sent_msgs: set[bytes] = set()
# Wait for completion
for future in as_completed(producer_futures, timeout=30.0):
sent_msgs.update(future.result())
subscription.close()
consumer_received_msgs = consumer_future.result(timeout=30.0)
# Verify message content
assert sent_msgs == consumer_received_msgs
# ==================== Resource Management Tests ====================
def test_subscription_cleanup(self, broadcast_channel: BroadcastChannel, redis_client: redis.Redis):
"""Test proper cleanup of subscription resources."""
topic_name = "cleanup-test-topic"
# Create multiple subscriptions
topic = broadcast_channel.topic(topic_name)
def _consume(sub: Subscription):
for _ in sub:
pass
subscriptions = []
for i in range(5):
subscription = topic.subscribe()
subscriptions.append(subscription)
# Start all subscriptions
thread = threading.Thread(target=_consume, args=(subscription,))
thread.start()
time.sleep(0.01)
# Verify subscriptions are active
pubsub_info = redis_client.pubsub_numsub(topic_name)
# pubsub_numsub returns list of tuples, find our topic
topic_subscribers = 0
for channel, count in pubsub_info:
# the channel name returned by redis is bytes.
if channel == topic_name.encode():
topic_subscribers = count
break
assert topic_subscribers >= 5
# Close all subscriptions
for subscription in subscriptions:
subscription.close()
# Wait a bit for cleanup
time.sleep(1)
# Verify subscriptions are cleaned up
pubsub_info_after = redis_client.pubsub_numsub(topic_name)
topic_subscribers_after = 0
for channel, count in pubsub_info_after:
if channel == topic_name.encode():
topic_subscribers_after = count
break
assert topic_subscribers_after == 0

View File

@ -256,7 +256,7 @@ class TestAddDocumentToIndexTask:
"""
# Arrange: Use non-existent document ID
fake = Faker()
non_existent_id = str(fake.uuid4())
non_existent_id = fake.uuid4()
# Act: Execute the task with non-existent document
add_document_to_index_task(non_existent_id)
@ -282,7 +282,7 @@ class TestAddDocumentToIndexTask:
- Redis cache key not affected
"""
# Arrange: Create test data with invalid indexing status
_, document = self._create_test_dataset_and_document(
dataset, document = self._create_test_dataset_and_document(
db_session_with_containers, mock_external_service_dependencies
)
@ -417,15 +417,15 @@ class TestAddDocumentToIndexTask:
# Verify redis cache was cleared
assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_with_already_enabled_segments(
def test_add_document_to_index_with_no_segments_to_process(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test document indexing when segments are already enabled.
Test document indexing when no segments need processing.
This test verifies:
- Segments with status="completed" are processed regardless of enabled status
- Index processing occurs with all completed segments
- Proper handling when all segments are already enabled
- Index processing still occurs but with empty documents list
- Auto disable log deletion still occurs
- Redis cache is cleared
"""
@ -465,16 +465,15 @@ class TestAddDocumentToIndexTask:
# Act: Execute the task
add_document_to_index_task(document.id)
# Assert: Verify index processing occurred with all completed segments
# Assert: Verify index processing occurred but with empty documents list
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX)
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify the load method was called with all completed segments
# (implementation doesn't filter by enabled status, only by status="completed")
# Verify the load method was called with empty documents list
call_args = mock_external_service_dependencies["index_processor"].load.call_args
assert call_args is not None
documents = call_args[0][1] # Second argument should be documents list
assert len(documents) == 3 # All completed segments are processed
assert len(documents) == 0 # No segments to process
# Verify redis cache was cleared
assert redis_client.exists(indexing_cache_key) == 0
@ -500,7 +499,7 @@ class TestAddDocumentToIndexTask:
# Create some auto disable log entries
fake = Faker()
auto_disable_logs = []
for _ in range(2):
for i in range(2):
log_entry = DatasetAutoDisableLog(
id=fake.uuid4(),
tenant_id=document.tenant_id,
@ -596,11 +595,9 @@ class TestAddDocumentToIndexTask:
Test segment filtering with various edge cases.
This test verifies:
- Only segments with status="completed" are processed (regardless of enabled status)
- Segments with status!="completed" are NOT processed
- Only segments with enabled=False and status="completed" are processed
- Segments are ordered by position correctly
- Mixed segment states are handled properly
- All segments are updated to enabled=True after processing
- Redis cache key deletion
"""
# Arrange: Create test data
@ -631,8 +628,7 @@ class TestAddDocumentToIndexTask:
db.session.add(segment1)
segments.append(segment1)
# Segment 2: Should be processed (enabled=True, status="completed")
# Note: Implementation doesn't filter by enabled status, only by status="completed"
# Segment 2: Should NOT be processed (enabled=True, status="completed")
segment2 = DocumentSegment(
id=fake.uuid4(),
tenant_id=document.tenant_id,
@ -644,7 +640,7 @@ class TestAddDocumentToIndexTask:
tokens=len(fake.text(max_nb_chars=200).split()) * 2,
index_node_id="node_1",
index_node_hash="hash_1",
enabled=True, # Already enabled, but will still be processed
enabled=True, # Already enabled
status="completed",
created_by=document.created_by,
)
@ -706,14 +702,11 @@ class TestAddDocumentToIndexTask:
call_args = mock_external_service_dependencies["index_processor"].load.call_args
assert call_args is not None
documents = call_args[0][1] # Second argument should be documents list
assert len(documents) == 3 # 3 segments with status="completed" should be processed
assert len(documents) == 2 # Only 2 segments should be processed
# Verify correct segments were processed (by position order)
# Segments 1, 2, 4 should be processed (positions 0, 1, 3)
# Segment 3 is skipped (position 2, status="processing")
assert documents[0].metadata["doc_id"] == "node_0" # segment1, position 0
assert documents[1].metadata["doc_id"] == "node_1" # segment2, position 1
assert documents[2].metadata["doc_id"] == "node_3" # segment4, position 3
assert documents[0].metadata["doc_id"] == "node_0" # position 0
assert documents[1].metadata["doc_id"] == "node_3" # position 3
# Verify database state changes
db.session.refresh(document)
@ -724,7 +717,7 @@ class TestAddDocumentToIndexTask:
# All segments should be enabled because the task updates ALL segments for the document
assert segment1.enabled is True
assert segment2.enabled is True # Was already enabled, stays True
assert segment2.enabled is True # Was already enabled, now updated to True
assert segment3.enabled is True # Was not processed but still updated to True
assert segment4.enabled is True

View File

@ -1,33 +1,17 @@
from dataclasses import asdict
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from core.entities.document_task import DocumentTask
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from tasks.document_indexing_task import (
_document_indexing, # Core function
_document_indexing_with_tenant_queue, # Tenant queue wrapper function
document_indexing_task, # Deprecated old interface
normal_document_indexing_task, # New normal task
priority_document_indexing_task, # New priority task
)
from tasks.document_indexing_task import document_indexing_task
class TestDocumentIndexingTasks:
"""Integration tests for document indexing tasks using testcontainers.
This test class covers:
- Core _document_indexing function
- Deprecated document_indexing_task function
- New normal_document_indexing_task function
- New priority_document_indexing_task function
- Tenant queue wrapper _document_indexing_with_tenant_queue function
"""
class TestDocumentIndexingTask:
"""Integration tests for document_indexing_task using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
@ -240,7 +224,7 @@ class TestDocumentIndexingTasks:
document_ids = [doc.id for doc in documents]
# Act: Execute the task
_document_indexing(dataset.id, document_ids)
document_indexing_task(dataset.id, document_ids)
# Assert: Verify the expected outcomes
# Verify indexing runner was called correctly
@ -248,11 +232,10 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were updated to parsing status
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
for document in documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None
# Verify the run method was called with correct documents
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
@ -278,7 +261,7 @@ class TestDocumentIndexingTasks:
document_ids = [fake.uuid4() for _ in range(3)]
# Act: Execute the task with non-existent dataset
_document_indexing(non_existent_dataset_id, document_ids)
document_indexing_task(non_existent_dataset_id, document_ids)
# Assert: Verify no processing occurred
mock_external_service_dependencies["indexing_runner"].assert_not_called()
@ -308,18 +291,17 @@ class TestDocumentIndexingTasks:
all_document_ids = existing_document_ids + non_existent_document_ids
# Act: Execute the task with mixed document IDs
_document_indexing(dataset.id, all_document_ids)
document_indexing_task(dataset.id, all_document_ids)
# Assert: Verify only existing documents were processed
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify only existing documents were updated
# Re-query documents from database since _document_indexing uses a different session
for doc_id in existing_document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
for document in documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None
# Verify the run method was called with only existing documents
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
@ -351,7 +333,7 @@ class TestDocumentIndexingTasks:
)
# Act: Execute the task
_document_indexing(dataset.id, document_ids)
document_indexing_task(dataset.id, document_ids)
# Assert: Verify exception was handled gracefully
# The task should complete without raising exceptions
@ -359,11 +341,10 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _document_indexing close the session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
for document in documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None
def test_document_indexing_task_mixed_document_states(
self, db_session_with_containers, mock_external_service_dependencies
@ -426,18 +407,17 @@ class TestDocumentIndexingTasks:
document_ids = [doc.id for doc in all_documents]
# Act: Execute the task with mixed document states
_document_indexing(dataset.id, document_ids)
document_indexing_task(dataset.id, document_ids)
# Assert: Verify processing
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify all documents were updated to parsing status
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
for document in all_documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None
# Verify the run method was called with all documents
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
@ -490,16 +470,15 @@ class TestDocumentIndexingTasks:
document_ids = [doc.id for doc in all_documents]
# Act: Execute the task with too many documents for sandbox plan
_document_indexing(dataset.id, document_ids)
document_indexing_task(dataset.id, document_ids)
# Assert: Verify error handling
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "error"
assert updated_document.error is not None
assert "batch upload" in updated_document.error
assert updated_document.stopped_at is not None
for document in all_documents:
db.session.refresh(document)
assert document.indexing_status == "error"
assert document.error is not None
assert "batch upload" in document.error
assert document.stopped_at is not None
# Verify no indexing runner was called
mock_external_service_dependencies["indexing_runner"].assert_not_called()
@ -524,18 +503,17 @@ class TestDocumentIndexingTasks:
document_ids = [doc.id for doc in documents]
# Act: Execute the task with billing disabled
_document_indexing(dataset.id, document_ids)
document_indexing_task(dataset.id, document_ids)
# Assert: Verify successful processing
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were updated to parsing status
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
for document in documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None
def test_document_indexing_task_document_is_paused_error(
self, db_session_with_containers, mock_external_service_dependencies
@ -563,7 +541,7 @@ class TestDocumentIndexingTasks:
)
# Act: Execute the task
_document_indexing(dataset.id, document_ids)
document_indexing_task(dataset.id, document_ids)
# Assert: Verify exception was handled gracefully
# The task should complete without raising exceptions
@ -571,317 +549,7 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
# ==================== NEW TESTS FOR REFACTORED FUNCTIONS ====================
def test_old_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test document_indexing_task basic functionality.
This test verifies:
- Task function calls the wrapper correctly
- Basic parameter passing works
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=1
)
document_ids = [doc.id for doc in documents]
# Act: Execute the deprecated task (it only takes 2 parameters)
document_indexing_task(dataset.id, document_ids)
# Assert: Verify processing occurred (core logic is tested in _document_indexing tests)
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
def test_normal_document_indexing_task_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test normal_document_indexing_task basic functionality.
This test verifies:
- Task function calls the wrapper correctly
- Basic parameter passing works
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=1
)
document_ids = [doc.id for doc in documents]
tenant_id = dataset.tenant_id
# Act: Execute the new normal task
normal_document_indexing_task(tenant_id, dataset.id, document_ids)
# Assert: Verify processing occurred (core logic is tested in _document_indexing tests)
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
def test_priority_document_indexing_task_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test priority_document_indexing_task basic functionality.
This test verifies:
- Task function calls the wrapper correctly
- Basic parameter passing works
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=1
)
document_ids = [doc.id for doc in documents]
tenant_id = dataset.tenant_id
# Act: Execute the new priority task
priority_document_indexing_task(tenant_id, dataset.id, document_ids)
# Assert: Verify processing occurred (core logic is tested in _document_indexing tests)
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
def test_document_indexing_with_tenant_queue_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test _document_indexing_with_tenant_queue function with no waiting tasks.
This test verifies:
- Core indexing logic execution (same as _document_indexing)
- Tenant queue cleanup when no waiting tasks
- Task function parameter passing
- Queue management after processing
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=2
)
document_ids = [doc.id for doc in documents]
tenant_id = dataset.tenant_id
# Mock the task function
from unittest.mock import MagicMock
mock_task_func = MagicMock()
# Act: Execute the wrapper function
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
# Assert: Verify core processing occurred (same as _document_indexing)
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were updated (same as _document_indexing)
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
# Verify the run method was called with correct documents
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
assert call_args is not None
processed_documents = call_args[0][0]
assert len(processed_documents) == 2
# Verify task function was not called (no waiting tasks)
mock_task_func.delay.assert_not_called()
def test_document_indexing_with_tenant_queue_with_waiting_tasks(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test _document_indexing_with_tenant_queue function with waiting tasks in queue using real Redis.
This test verifies:
- Core indexing logic execution
- Real Redis-based tenant queue processing of waiting tasks
- Task function calls for waiting tasks
- Queue management with multiple tasks using actual Redis operations
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=1
)
document_ids = [doc.id for doc in documents]
tenant_id = dataset.tenant_id
dataset_id = dataset.id
# Mock the task function
from unittest.mock import MagicMock
mock_task_func = MagicMock()
# Use real Redis for TenantIsolatedTaskQueue
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
# Create real queue instance
queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
# Add waiting tasks to the real Redis queue
waiting_tasks = [
DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-1"]),
DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-2"]),
]
# Convert DocumentTask objects to dictionaries for serialization
waiting_task_dicts = [asdict(task) for task in waiting_tasks]
queue.push_tasks(waiting_task_dicts)
# Act: Execute the wrapper function
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
# Assert: Verify core processing occurred
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify the task function was re-dispatched for the next waiting task (one pull per run)
assert mock_task_func.delay.call_count == 1
# Verify correct parameters for each call
calls = mock_task_func.delay.call_args_list
assert calls[0][1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
# Verify one waiting task remains in the queue (only the first was pulled and dispatched)
remaining_tasks = queue.pull_tasks(count=10) # Pull more than we added
assert len(remaining_tasks) == 1
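# Hedged sketch of the drain-and-redispatch pattern these wrapper tests rely
# on: after the current batch is indexed, the next waiting task is pulled from
# the tenant queue and handed back to Celery via task_func.delay, and the
# waiting flag is dropped once nothing is left. Illustrative only; the real
# _document_indexing_with_tenant_queue may differ in batch size and cleanup.
from core.rag.pipeline.queue import TenantIsolatedTaskQueue

def _drain_tenant_queue_sketch(queue: TenantIsolatedTaskQueue, task_func) -> None:
    waiting = queue.pull_tasks(1)
    if not waiting:
        queue.delete_task_key()  # nothing queued for this tenant any more
        return
    for task in waiting:
        task_func.delay(
            tenant_id=task["tenant_id"],
            dataset_id=task["dataset_id"],
            document_ids=task["document_ids"],
        )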
def test_document_indexing_with_tenant_queue_error_handling(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error handling in _document_indexing_with_tenant_queue using real Redis.
This test verifies:
- Exception handling during core processing
- Tenant queue cleanup even on errors using real Redis
- Proper error logging
- Function completes without raising exceptions
- Queue management continues despite core processing errors
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=1
)
document_ids = [doc.id for doc in documents]
tenant_id = dataset.tenant_id
dataset_id = dataset.id
# Mock IndexingRunner to raise an exception
mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = Exception("Test error")
# Mock the task function
from unittest.mock import MagicMock
mock_task_func = MagicMock()
# Use real Redis for TenantIsolatedTaskQueue
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
# Create real queue instance
queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing")
# Add waiting task to the real Redis queue
waiting_task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-1"])
queue.push_tasks([asdict(waiting_task)])
# Act: Execute the wrapper function
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
# Assert: Verify error was handled gracefully
# The function should not raise exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
# Verify waiting task was still processed despite core processing error
mock_task_func.delay.assert_called_once()
# Verify correct parameters for the call
call = mock_task_func.delay.call_args
assert call[1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 0
def test_document_indexing_with_tenant_queue_tenant_isolation(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test tenant isolation in _document_indexing_with_tenant_queue using real Redis.
This test verifies:
- Different tenants have isolated queues
- Tasks from one tenant don't affect another tenant's queue
- Queue operations are properly scoped to tenant
"""
# Arrange: Create test data for two different tenants
dataset1, documents1 = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=1
)
dataset2, documents2 = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=1
)
tenant1_id = dataset1.tenant_id
tenant2_id = dataset2.tenant_id
dataset1_id = dataset1.id
dataset2_id = dataset2.id
document_ids1 = [doc.id for doc in documents1]
document_ids2 = [doc.id for doc in documents2]
# Mock the task function
from unittest.mock import MagicMock
mock_task_func = MagicMock()
# Use real Redis for TenantIsolatedTaskQueue
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
# Create queue instances for both tenants
queue1 = TenantIsolatedTaskQueue(tenant1_id, "document_indexing")
queue2 = TenantIsolatedTaskQueue(tenant2_id, "document_indexing")
# Add waiting tasks to both queues
waiting_task1 = DocumentTask(tenant_id=tenant1_id, dataset_id=dataset1.id, document_ids=["tenant1-doc-1"])
waiting_task2 = DocumentTask(tenant_id=tenant2_id, dataset_id=dataset2.id, document_ids=["tenant2-doc-1"])
queue1.push_tasks([asdict(waiting_task1)])
queue2.push_tasks([asdict(waiting_task2)])
# Act: Execute the wrapper function for tenant1 only
_document_indexing_with_tenant_queue(tenant1_id, dataset1.id, document_ids1, mock_task_func)
# Assert: Verify core processing occurred for tenant1
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify only tenant1's waiting task was processed
mock_task_func.delay.assert_called_once()
call = mock_task_func.delay.call_args
assert call[1] == {"tenant_id": tenant1_id, "dataset_id": dataset1_id, "document_ids": ["tenant1-doc-1"]}
# Verify tenant1's queue is empty
remaining_tasks1 = queue1.pull_tasks(count=10)
assert len(remaining_tasks1) == 0
# Verify tenant2's queue still has its task (isolation)
remaining_tasks2 = queue2.pull_tasks(count=10)
assert len(remaining_tasks2) == 1
# Verify queue keys are different
assert queue1._queue != queue2._queue
assert queue1._task_key != queue2._task_key
for document in documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None

View File

@ -1,936 +0,0 @@
import json
import uuid
from unittest.mock import patch
import pytest
from faker import Faker
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Pipeline
from models.workflow import Workflow
from tasks.rag_pipeline.priority_rag_pipeline_run_task import (
priority_rag_pipeline_run_task,
run_single_rag_pipeline_task,
)
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
class TestRagPipelineRunTasks:
"""Integration tests for RAG pipeline run tasks using testcontainers.
This test class covers:
- priority_rag_pipeline_run_task function
- rag_pipeline_run_task function
- run_single_rag_pipeline_task function
- Real Redis-based TenantIsolatedTaskQueue operations
- PipelineGenerator._generate method mocking and parameter validation
- File operations and cleanup
- Error handling and queue management
"""
@pytest.fixture
def mock_pipeline_generator(self):
"""Mock PipelineGenerator._generate method."""
with patch("core.app.apps.pipeline.pipeline_generator.PipelineGenerator._generate") as mock_generate:
# Mock the _generate method to return a simple response
mock_generate.return_value = {"answer": "Test response", "metadata": {"test": "data"}}
yield mock_generate
@pytest.fixture
def mock_file_service(self):
"""Mock FileService for file operations."""
with (
patch("services.file_service.FileService.get_file_content") as mock_get_content,
patch("services.file_service.FileService.delete_file") as mock_delete_file,
):
yield {
"get_content": mock_get_content,
"delete_file": mock_delete_file,
}
def _create_test_pipeline_and_workflow(self, db_session_with_containers):
"""
Helper method to create test pipeline and workflow for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
Returns:
tuple: (account, tenant, pipeline, workflow) - Created entities
"""
fake = Faker()
# Create account and tenant
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
# Create tenant-account join
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
# Create workflow
workflow = Workflow(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
app_id=str(uuid.uuid4()),
type="workflow",
version="draft",
graph="{}",
features="{}",
marked_name=fake.company(),
marked_comment=fake.text(max_nb_chars=100),
created_by=account.id,
environment_variables=[],
conversation_variables=[],
rag_pipeline_variables=[],
)
db.session.add(workflow)
db.session.commit()
# Create pipeline
pipeline = Pipeline(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
workflow_id=workflow.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
created_by=account.id,
)
db.session.add(pipeline)
db.session.commit()
# Refresh entities to ensure they're properly loaded
db.session.refresh(account)
db.session.refresh(tenant)
db.session.refresh(workflow)
db.session.refresh(pipeline)
return account, tenant, pipeline, workflow
def _create_rag_pipeline_invoke_entities(self, account, tenant, pipeline, workflow, count=2):
"""
Helper method to create RAG pipeline invoke entities for testing.
Args:
account: Account instance
tenant: Tenant instance
pipeline: Pipeline instance
workflow: Workflow instance
count: Number of entities to create
Returns:
list: List of RagPipelineInvokeEntity instances
"""
fake = Faker()
entities = []
for i in range(count):
# Create application generate entity
app_config = {
"app_id": str(uuid.uuid4()),
"app_name": fake.company(),
"mode": "workflow",
"workflow_id": workflow.id,
"tenant_id": tenant.id,
"app_mode": "workflow",
}
application_generate_entity = {
"task_id": str(uuid.uuid4()),
"app_config": app_config,
"inputs": {"query": f"Test query {i}"},
"files": [],
"user_id": account.id,
"stream": False,
"invoke_from": "published",
"workflow_execution_id": str(uuid.uuid4()),
"pipeline_config": {
"app_id": str(uuid.uuid4()),
"app_name": fake.company(),
"mode": "workflow",
"workflow_id": workflow.id,
"tenant_id": tenant.id,
"app_mode": "workflow",
},
"datasource_type": "upload_file",
"datasource_info": {},
"dataset_id": str(uuid.uuid4()),
"batch": "test_batch",
}
entity = RagPipelineInvokeEntity(
pipeline_id=pipeline.id,
application_generate_entity=application_generate_entity,
user_id=account.id,
tenant_id=tenant.id,
workflow_id=workflow.id,
streaming=False,
workflow_execution_id=str(uuid.uuid4()),
workflow_thread_pool_id=str(uuid.uuid4()),
)
entities.append(entity)
return entities
def _create_file_content_for_entities(self, entities):
"""
Helper method to create file content for RAG pipeline invoke entities.
Args:
entities: List of RagPipelineInvokeEntity instances
Returns:
str: JSON string containing serialized entities
"""
entities_data = [entity.model_dump() for entity in entities]
return json.dumps(entities_data)
def test_priority_rag_pipeline_run_task_success(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test successful priority RAG pipeline run task execution.
This test verifies:
- Task execution with multiple RAG pipeline invoke entities
- File content retrieval and parsing
- PipelineGenerator._generate method calls with correct parameters
- Thread pool execution
- File cleanup after execution
- Queue management with no waiting tasks
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=2)
file_content = self._create_file_content_for_entities(entities)
# Mock file service
file_id = str(uuid.uuid4())
mock_file_service["get_content"].return_value = file_content
# Act: Execute the priority task
priority_rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify expected outcomes
# Verify file operations
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_file_service["delete_file"].assert_called_once_with(file_id)
# Verify PipelineGenerator._generate was called for each entity
assert mock_pipeline_generator.call_count == 2
# Verify call parameters for each entity
calls = mock_pipeline_generator.call_args_list
for call in calls:
call_kwargs = call[1] # Get keyword arguments
assert call_kwargs["pipeline"].id == pipeline.id
assert call_kwargs["workflow_id"] == workflow.id
assert call_kwargs["user"].id == account.id
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
assert call_kwargs["streaming"] == False
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
def test_rag_pipeline_run_task_success(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test successful regular RAG pipeline run task execution.
This test verifies:
- Task execution with multiple RAG pipeline invoke entities
- File content retrieval and parsing
- PipelineGenerator._generate method calls with correct parameters
- Thread pool execution
- File cleanup after execution
- Queue management with no waiting tasks
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=3)
file_content = self._create_file_content_for_entities(entities)
# Mock file service
file_id = str(uuid.uuid4())
mock_file_service["get_content"].return_value = file_content
# Act: Execute the regular task
rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify expected outcomes
# Verify file operations
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_file_service["delete_file"].assert_called_once_with(file_id)
# Verify PipelineGenerator._generate was called for each entity
assert mock_pipeline_generator.call_count == 3
# Verify call parameters for each entity
calls = mock_pipeline_generator.call_args_list
for call in calls:
call_kwargs = call[1] # Get keyword arguments
assert call_kwargs["pipeline"].id == pipeline.id
assert call_kwargs["workflow_id"] == workflow.id
assert call_kwargs["user"].id == account.id
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
assert call_kwargs["streaming"] == False
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
def test_priority_rag_pipeline_run_task_with_waiting_tasks(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test priority RAG pipeline run task with waiting tasks in queue using real Redis.
This test verifies:
- Core task execution
- Real Redis-based tenant queue processing of waiting tasks
- Task function calls for waiting tasks
- Queue management with multiple tasks using actual Redis operations
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
file_content = self._create_file_content_for_entities(entities)
# Mock file service
file_id = str(uuid.uuid4())
mock_file_service["get_content"].return_value = file_content
# Use real Redis for TenantIsolatedTaskQueue
queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")
# Add waiting tasks to the real Redis queue
waiting_file_ids = [str(uuid.uuid4()) for _ in range(2)]
queue.push_tasks(waiting_file_ids)
# Mock the task function calls
with patch(
"tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
) as mock_delay:
# Act: Execute the priority task
priority_rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify core processing occurred
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting tasks were processed; only 1 task is pulled at a time by default
assert mock_delay.call_count == 1
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
assert call_kwargs.get("tenant_id") == tenant.id
# Verify queue still has remaining tasks (only 1 was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 1 # 2 original - 1 pulled = 1 remaining
def test_rag_pipeline_run_task_legacy_compatibility(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test regular RAG pipeline run task with legacy Redis queue format for backward compatibility.
This test simulates the scenario where:
- Old code writes file IDs directly to a Redis list using lpush
- New worker processes these legacy queue entries
- Ensures backward compatibility during the deployment transition
Legacy format: redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
New format: TenantIsolatedTaskQueue.push_tasks([file_id])
Both formats are illustrated in the sketch after this docstring.
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
file_content = self._create_file_content_for_entities(entities)
# Mock file service
file_id = str(uuid.uuid4())
mock_file_service["get_content"].return_value = file_content
# Simulate legacy Redis queue format - direct file IDs in Redis list
from extensions.ext_redis import redis_client
# Legacy queue key format (old code)
legacy_queue_key = f"tenant_self_pipeline_task_queue:{tenant.id}"
legacy_task_key = f"tenant_pipeline_task:{tenant.id}"
# Add legacy format data to Redis (simulating old code behavior)
legacy_file_ids = [str(uuid.uuid4()) for _ in range(3)]
for file_id_legacy in legacy_file_ids:
redis_client.lpush(legacy_queue_key, file_id_legacy)
# Set the task key to indicate there are waiting tasks (legacy behavior)
redis_client.set(legacy_task_key, 1, ex=60 * 60)
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Act: Execute the priority task with new code but legacy queue data
rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify core processing occurred
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting tasks were processed; only 1 task is pulled at a time by default
assert mock_delay.call_count == 1
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
assert call_kwargs.get("tenant_id") == tenant.id
# Verify that new code can process legacy queue entries
# The new TenantIsolatedTaskQueue should be able to read from the legacy format
queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")
# Verify queue still has remaining tasks (only 1 was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 2 # 3 original - 1 pulled = 2 remaining
# Cleanup: Remove legacy test data
redis_client.delete(legacy_queue_key)
redis_client.delete(legacy_task_key)
def test_rag_pipeline_run_task_with_waiting_tasks(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test regular RAG pipeline run task with waiting tasks in queue using real Redis.
This test verifies:
- Core task execution
- Real Redis-based tenant queue processing of waiting tasks
- Task function calls for waiting tasks
- Queue management with multiple tasks using actual Redis operations
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
file_content = self._create_file_content_for_entities(entities)
# Mock file service
file_id = str(uuid.uuid4())
mock_file_service["get_content"].return_value = file_content
# Use real Redis for TenantIsolatedTaskQueue
queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")
# Add waiting tasks to the real Redis queue
waiting_file_ids = [str(uuid.uuid4()) for _ in range(3)]
queue.push_tasks(waiting_file_ids)
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Act: Execute the regular task
rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify core processing occurred
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting tasks were processed; only 1 task is pulled at a time by default
assert mock_delay.call_count == 1
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
assert call_kwargs.get("tenant_id") == tenant.id
# Verify queue still has remaining tasks (only 1 was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 2 # 3 original - 1 pulled = 2 remaining
def test_priority_rag_pipeline_run_task_error_handling(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test error handling in priority RAG pipeline run task using real Redis.
This test verifies:
- Exception handling during core processing
- Tenant queue cleanup even on errors using real Redis
- Proper error logging
- Function completes without raising exceptions
- Queue management continues despite core processing errors
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
file_content = self._create_file_content_for_entities(entities)
# Mock file service
file_id = str(uuid.uuid4())
mock_file_service["get_content"].return_value = file_content
# Mock PipelineGenerator to raise an exception
mock_pipeline_generator.side_effect = Exception("Pipeline generation failed")
# Use real Redis for TenantIsolatedTaskQueue
queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")
# Add waiting task to the real Redis queue
waiting_file_id = str(uuid.uuid4())
queue.push_tasks([waiting_file_id])
# Mock the task function calls
with patch(
"tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
) as mock_delay:
# Act: Execute the priority task (should not raise exception)
priority_rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify error was handled gracefully
# The function should not raise exceptions
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting task was still processed despite core processing error
mock_delay.assert_called_once()
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
assert call_kwargs.get("tenant_id") == tenant.id
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 0
def test_rag_pipeline_run_task_error_handling(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test error handling in regular RAG pipeline run task using real Redis.
This test verifies:
- Exception handling during core processing
- Tenant queue cleanup even on errors using real Redis
- Proper error logging
- Function completes without raising exceptions
- Queue management continues despite core processing errors
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
file_content = self._create_file_content_for_entities(entities)
# Mock file service
file_id = str(uuid.uuid4())
mock_file_service["get_content"].return_value = file_content
# Mock PipelineGenerator to raise an exception
mock_pipeline_generator.side_effect = Exception("Pipeline generation failed")
# Use real Redis for TenantIsolatedTaskQueue
queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")
# Add waiting task to the real Redis queue
waiting_file_id = str(uuid.uuid4())
queue.push_tasks([waiting_file_id])
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Act: Execute the regular task (should not raise exception)
rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify error was handled gracefully
# The function should not raise exceptions
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting task was still processed despite core processing error
mock_delay.assert_called_once()
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
assert call_kwargs.get("tenant_id") == tenant.id
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 0
def test_priority_rag_pipeline_run_task_tenant_isolation(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test tenant isolation in priority RAG pipeline run task using real Redis.
This test verifies:
- Different tenants have isolated queues
- Tasks from one tenant don't affect another tenant's queue
- Queue operations are properly scoped to tenant
"""
# Arrange: Create test data for two different tenants
account1, tenant1, pipeline1, workflow1 = self._create_test_pipeline_and_workflow(db_session_with_containers)
account2, tenant2, pipeline2, workflow2 = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities1 = self._create_rag_pipeline_invoke_entities(account1, tenant1, pipeline1, workflow1, count=1)
entities2 = self._create_rag_pipeline_invoke_entities(account2, tenant2, pipeline2, workflow2, count=1)
file_content1 = self._create_file_content_for_entities(entities1)
file_content2 = self._create_file_content_for_entities(entities2)
# Mock file service
file_id1 = str(uuid.uuid4())
file_id2 = str(uuid.uuid4())
mock_file_service["get_content"].side_effect = [file_content1, file_content2]
# Use real Redis for TenantIsolatedTaskQueue
queue1 = TenantIsolatedTaskQueue(tenant1.id, "pipeline")
queue2 = TenantIsolatedTaskQueue(tenant2.id, "pipeline")
# Add waiting tasks to both queues
waiting_file_id1 = str(uuid.uuid4())
waiting_file_id2 = str(uuid.uuid4())
queue1.push_tasks([waiting_file_id1])
queue2.push_tasks([waiting_file_id2])
# Mock the task function calls
with patch(
"tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
) as mock_delay:
# Act: Execute the priority task for tenant1 only
priority_rag_pipeline_run_task(file_id1, tenant1.id)
# Assert: Verify core processing occurred for tenant1
assert mock_file_service["get_content"].call_count == 1
assert mock_file_service["delete_file"].call_count == 1
assert mock_pipeline_generator.call_count == 1
# Verify only tenant1's waiting task was processed
mock_delay.assert_called_once()
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
assert call_kwargs.get("tenant_id") == tenant1.id
# Verify tenant1's queue is empty
remaining_tasks1 = queue1.pull_tasks(count=10)
assert len(remaining_tasks1) == 0
# Verify tenant2's queue still has its task (isolation)
remaining_tasks2 = queue2.pull_tasks(count=10)
assert len(remaining_tasks2) == 1
# Verify queue keys are different
assert queue1._queue != queue2._queue
assert queue1._task_key != queue2._task_key
def test_rag_pipeline_run_task_tenant_isolation(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test tenant isolation in regular RAG pipeline run task using real Redis.
This test verifies:
- Different tenants have isolated queues
- Tasks from one tenant don't affect another tenant's queue
- Queue operations are properly scoped to tenant
"""
# Arrange: Create test data for two different tenants
account1, tenant1, pipeline1, workflow1 = self._create_test_pipeline_and_workflow(db_session_with_containers)
account2, tenant2, pipeline2, workflow2 = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities1 = self._create_rag_pipeline_invoke_entities(account1, tenant1, pipeline1, workflow1, count=1)
entities2 = self._create_rag_pipeline_invoke_entities(account2, tenant2, pipeline2, workflow2, count=1)
file_content1 = self._create_file_content_for_entities(entities1)
file_content2 = self._create_file_content_for_entities(entities2)
# Mock file service
file_id1 = str(uuid.uuid4())
file_id2 = str(uuid.uuid4())
mock_file_service["get_content"].side_effect = [file_content1, file_content2]
# Use real Redis for TenantIsolatedTaskQueue
queue1 = TenantIsolatedTaskQueue(tenant1.id, "pipeline")
queue2 = TenantIsolatedTaskQueue(tenant2.id, "pipeline")
# Add waiting tasks to both queues
waiting_file_id1 = str(uuid.uuid4())
waiting_file_id2 = str(uuid.uuid4())
queue1.push_tasks([waiting_file_id1])
queue2.push_tasks([waiting_file_id2])
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Act: Execute the regular task for tenant1 only
rag_pipeline_run_task(file_id1, tenant1.id)
# Assert: Verify core processing occurred for tenant1
assert mock_file_service["get_content"].call_count == 1
assert mock_file_service["delete_file"].call_count == 1
assert mock_pipeline_generator.call_count == 1
# Verify only tenant1's waiting task was processed
mock_delay.assert_called_once()
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
assert call_kwargs.get("tenant_id") == tenant1.id
# Verify tenant1's queue is empty
remaining_tasks1 = queue1.pull_tasks(count=10)
assert len(remaining_tasks1) == 0
# Verify tenant2's queue still has its task (isolation)
remaining_tasks2 = queue2.pull_tasks(count=10)
assert len(remaining_tasks2) == 1
# Verify queue keys are different
assert queue1._queue != queue2._queue
assert queue1._task_key != queue2._task_key
def test_run_single_rag_pipeline_task_success(
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
):
"""
Test successful run_single_rag_pipeline_task execution.
This test verifies:
- Single RAG pipeline task execution within Flask app context
- Entity validation and database queries
- PipelineGenerator._generate method call with correct parameters
- Proper Flask context handling
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
entity_data = entities[0].model_dump()
# Act: Execute the single task
with flask_app_with_containers.app_context():
run_single_rag_pipeline_task(entity_data, flask_app_with_containers)
# Assert: Verify expected outcomes
# Verify PipelineGenerator._generate was called
assert mock_pipeline_generator.call_count == 1
# Verify call parameters
call = mock_pipeline_generator.call_args
call_kwargs = call[1] # Get keyword arguments
assert call_kwargs["pipeline"].id == pipeline.id
assert call_kwargs["workflow_id"] == workflow.id
assert call_kwargs["user"].id == account.id
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
assert call_kwargs["streaming"] == False
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
def test_run_single_rag_pipeline_task_entity_validation_error(
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
):
"""
Test run_single_rag_pipeline_task with invalid entity data.
This test verifies:
- Proper error handling for invalid entity data
- Exception logging
- Function raises ValueError for missing entities
"""
# Arrange: Create entity data with valid UUIDs but non-existent entities
invalid_entity_data = {
"pipeline_id": str(uuid.uuid4()),
"application_generate_entity": {
"app_config": {
"app_id": str(uuid.uuid4()),
"app_name": "Test App",
"mode": "workflow",
"workflow_id": str(uuid.uuid4()),
},
"inputs": {"query": "Test query"},
"query": "Test query",
"response_mode": "blocking",
"user": str(uuid.uuid4()),
"files": [],
"conversation_id": str(uuid.uuid4()),
},
"user_id": str(uuid.uuid4()),
"tenant_id": str(uuid.uuid4()),
"workflow_id": str(uuid.uuid4()),
"streaming": False,
"workflow_execution_id": str(uuid.uuid4()),
"workflow_thread_pool_id": str(uuid.uuid4()),
}
# Act & Assert: Execute the single task with non-existent entities (should raise ValueError)
with flask_app_with_containers.app_context():
with pytest.raises(ValueError, match="Account .* not found"):
run_single_rag_pipeline_task(invalid_entity_data, flask_app_with_containers)
# Assert: Pipeline generator should not be called
mock_pipeline_generator.assert_not_called()
def test_run_single_rag_pipeline_task_database_entity_not_found(
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
):
"""
Test run_single_rag_pipeline_task with non-existent database entities.
This test verifies:
- Proper error handling for missing database entities
- Exception logging
- Function raises ValueError for missing entities
"""
# Arrange: Create test data with non-existent IDs
entity_data = {
"pipeline_id": str(uuid.uuid4()),
"application_generate_entity": {
"app_config": {
"app_id": str(uuid.uuid4()),
"app_name": "Test App",
"mode": "workflow",
"workflow_id": str(uuid.uuid4()),
},
"inputs": {"query": "Test query"},
"query": "Test query",
"response_mode": "blocking",
"user": str(uuid.uuid4()),
"files": [],
"conversation_id": str(uuid.uuid4()),
},
"user_id": str(uuid.uuid4()),
"tenant_id": str(uuid.uuid4()),
"workflow_id": str(uuid.uuid4()),
"streaming": False,
"workflow_execution_id": str(uuid.uuid4()),
"workflow_thread_pool_id": str(uuid.uuid4()),
}
# Act & Assert: Execute the single task with non-existent entities (should raise ValueError)
with flask_app_with_containers.app_context():
with pytest.raises(ValueError, match="Account .* not found"):
run_single_rag_pipeline_task(entity_data, flask_app_with_containers)
# Assert: Pipeline generator should not be called
mock_pipeline_generator.assert_not_called()
def test_priority_rag_pipeline_run_task_file_not_found(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test priority RAG pipeline run task with non-existent file.
This test verifies:
- Proper error handling for missing files
- Exception logging
- Function raises Exception for file errors
- Queue management continues despite file errors
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
# Mock file service to raise exception
file_id = str(uuid.uuid4())
mock_file_service["get_content"].side_effect = Exception("File not found")
# Use real Redis for TenantIsolatedTaskQueue
queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")
# Add waiting task to the real Redis queue
waiting_file_id = str(uuid.uuid4())
queue.push_tasks([waiting_file_id])
# Mock the task function calls
with patch(
"tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
) as mock_delay:
# Act & Assert: Execute the priority task (should raise Exception)
with pytest.raises(Exception, match="File not found"):
priority_rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify error was handled gracefully
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_pipeline_generator.assert_not_called()
# Verify waiting task was still processed despite file error
mock_delay.assert_called_once()
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
assert call_kwargs.get("tenant_id") == tenant.id
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 0
def test_rag_pipeline_run_task_file_not_found(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test regular RAG pipeline run task with non-existent file.
This test verifies:
- Proper error handling for missing files
- Exception logging
- Function raises Exception for file errors
- Queue management continues despite file errors
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
# Mock file service to raise exception
file_id = str(uuid.uuid4())
mock_file_service["get_content"].side_effect = Exception("File not found")
# Use real Redis for TenantIsolatedTaskQueue
queue = TenantIsolatedTaskQueue(tenant.id, "pipeline")
# Add waiting task to the real Redis queue
waiting_file_id = str(uuid.uuid4())
queue.push_tasks([waiting_file_id])
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Act & Assert: Execute the regular task (should raise Exception)
with pytest.raises(Exception, match="File not found"):
rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify error was handled gracefully
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_pipeline_generator.assert_not_called()
# Verify waiting task was still processed despite file error
mock_delay.assert_called_once()
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
assert call_kwargs.get("tenant_id") == tenant.id
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 0

View File

@ -4,14 +4,7 @@ from unittest.mock import Mock
import pytest
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import (
PauseStatePersistenceLayer,
WorkflowResumptionContext,
_AdvancedChatAppGenerateEntityWrapper,
_WorkflowGenerateEntityWrapper,
)
from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
from core.variables.segments import Segment
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
@ -22,7 +15,6 @@ from core.workflow.graph_events.graph import (
GraphRunSucceededEvent,
)
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool
from models.model import AppMode
from repositories.factory import DifyAPIRepositoryFactory
@ -178,25 +170,6 @@ class MockCommandChannel:
class TestPauseStatePersistenceLayer:
"""Unit tests for PauseStatePersistenceLayer."""
@staticmethod
def _create_generate_entity(workflow_execution_id: str = "run-123") -> WorkflowAppGenerateEntity:
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant-123",
app_id="app-123",
app_mode=AppMode.WORKFLOW,
workflow_id="workflow-123",
)
return WorkflowAppGenerateEntity(
task_id="task-123",
app_config=app_config,
inputs={},
files=[],
user_id="user-123",
stream=False,
invoke_from=InvokeFrom.DEBUGGER,
workflow_execution_id=workflow_execution_id,
)
def test_init_with_dependency_injection(self):
session_factory = Mock(name="session_factory")
state_owner_user_id = "user-123"
@ -204,7 +177,6 @@ class TestPauseStatePersistenceLayer:
layer = PauseStatePersistenceLayer(
session_factory=session_factory,
state_owner_user_id=state_owner_user_id,
generate_entity=self._create_generate_entity(),
)
assert layer._session_maker is session_factory
@ -214,11 +186,7 @@ class TestPauseStatePersistenceLayer:
def test_initialize_sets_dependencies(self):
session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(
session_factory=session_factory,
state_owner_user_id="owner",
generate_entity=self._create_generate_entity(),
)
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner")
graph_runtime_state = MockReadOnlyGraphRuntimeState()
command_channel = MockCommandChannel()
@ -230,12 +198,7 @@ class TestPauseStatePersistenceLayer:
def test_on_event_with_graph_run_paused_event(self, monkeypatch: pytest.MonkeyPatch):
session_factory = Mock(name="session_factory")
generate_entity = self._create_generate_entity(workflow_execution_id="run-123")
layer = PauseStatePersistenceLayer(
session_factory=session_factory,
state_owner_user_id="owner-123",
generate_entity=generate_entity,
)
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
mock_repo = Mock()
mock_factory = Mock(return_value=mock_repo)
@ -258,20 +221,12 @@ class TestPauseStatePersistenceLayer:
mock_repo.create_workflow_pause.assert_called_once_with(
workflow_run_id="run-123",
state_owner_user_id="owner-123",
state=mock_repo.create_workflow_pause.call_args.kwargs["state"],
state=expected_state,
)
serialized_state = mock_repo.create_workflow_pause.call_args.kwargs["state"]
resumption_context = WorkflowResumptionContext.loads(serialized_state)
assert resumption_context.serialized_graph_runtime_state == expected_state
assert resumption_context.get_generate_entity().model_dump() == generate_entity.model_dump()
def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch):
session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(
session_factory=session_factory,
state_owner_user_id="owner-123",
generate_entity=self._create_generate_entity(),
)
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
mock_repo = Mock()
mock_factory = Mock(return_value=mock_repo)
@ -295,11 +250,7 @@ class TestPauseStatePersistenceLayer:
def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self):
session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(
session_factory=session_factory,
state_owner_user_id="owner-123",
generate_entity=self._create_generate_entity(),
)
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
event = TestDataFactory.create_graph_run_paused_event()
@ -308,11 +259,7 @@ class TestPauseStatePersistenceLayer:
def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch):
session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(
session_factory=session_factory,
state_owner_user_id="owner-123",
generate_entity=self._create_generate_entity(),
)
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
mock_repo = Mock()
mock_factory = Mock(return_value=mock_repo)
@ -329,82 +276,3 @@ class TestPauseStatePersistenceLayer:
mock_factory.assert_not_called()
mock_repo.create_workflow_pause.assert_not_called()
def _build_workflow_generate_entity_for_roundtrip() -> WorkflowResumptionContext:
"""Create a WorkflowAppGenerateEntity with realistic data for WorkflowResumptionContext tests."""
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant-roundtrip",
app_id="app-roundtrip",
app_mode=AppMode.WORKFLOW,
workflow_id="workflow-roundtrip",
)
serialized_state = json.dumps({"state": "workflow"})
return WorkflowResumptionContext(
serialized_graph_runtime_state=serialized_state,
generate_entity=_WorkflowGenerateEntityWrapper(
entity=WorkflowAppGenerateEntity(
task_id="workflow-task",
app_config=app_config,
inputs={"input_key": "input_value"},
files=[],
user_id="user-roundtrip",
stream=False,
invoke_from=InvokeFrom.DEBUGGER,
workflow_execution_id="workflow-exec-roundtrip",
)
),
)
def _build_advanced_chat_generate_entity_for_roundtrip() -> WorkflowResumptionContext:
"""Create an AdvancedChatAppGenerateEntity with realistic data for WorkflowResumptionContext tests."""
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant-advanced",
app_id="app-advanced",
app_mode=AppMode.ADVANCED_CHAT,
workflow_id="workflow-advanced",
)
serialized_state = json.dumps({"state": "workflow"})
return WorkflowResumptionContext(
serialized_graph_runtime_state=serialized_state,
generate_entity=_AdvancedChatAppGenerateEntityWrapper(
entity=AdvancedChatAppGenerateEntity(
task_id="advanced-task",
app_config=app_config,
inputs={"topic": "roundtrip"},
files=[],
user_id="advanced-user",
stream=False,
invoke_from=InvokeFrom.DEBUGGER,
workflow_run_id="advanced-run-id",
query="Explain serialization behavior",
)
),
)
@pytest.mark.parametrize(
"state",
[
pytest.param(
_build_advanced_chat_generate_entity_for_roundtrip(),
id="advanced_chat",
),
pytest.param(
_build_workflow_generate_entity_for_roundtrip(),
id="workflow",
),
],
)
def test_workflow_resumption_context_dumps_loads_roundtrip(state: WorkflowResumptionContext):
"""WorkflowResumptionContext roundtrip preserves workflow generate entity metadata."""
dumped = state.dumps()
loaded = WorkflowResumptionContext.loads(dumped)
assert loaded == state
assert loaded.serialized_graph_runtime_state == state.serialized_graph_runtime_state
restored_entity = loaded.get_generate_entity()
assert isinstance(restored_entity, type(state.generate_entity.entity))

View File

@ -1,12 +0,0 @@
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer
def test_get_runner_script():
code = JavascriptCodeProvider.get_default_code()
inputs = {"arg1": "hello, ", "arg2": "world!"}
script = NodeJsTemplateTransformer.assemble_runner_script(code, inputs)
script_lines = script.splitlines()
code_lines = code.splitlines()
# Check that the first lines of the script are exactly the same as the code
assert script_lines[: len(code_lines)] == code_lines

View File

@ -1,12 +0,0 @@
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
def test_get_runner_script():
code = Python3CodeProvider.get_default_code()
inputs = {"arg1": "hello, ", "arg2": "world!"}
script = Python3TemplateTransformer.assemble_runner_script(code, inputs)
script_lines = script.splitlines()
code_lines = code.splitlines()
# Check that the first lines of the script are exactly the same as the code
assert script_lines[: len(code_lines)] == code_lines

View File

@ -1,301 +0,0 @@
"""
Unit tests for TenantIsolatedTaskQueue.
These tests verify the Redis-based task queue functionality for tenant-specific
task management with proper serialization and deserialization.
"""
import json
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from pydantic import ValidationError
from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue
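# Editor's illustrative sketch (not part of the original module): the round trip
# these tests exercise, using only the TaskWrapper methods shown below.
def _task_wrapper_roundtrip_sketch() -> None:
    payload = {"document_id": "doc-1", "note": "测试中文"}
    serialized = TaskWrapper(data=payload).serialize()  # JSON string, Unicode preserved
    restored = TaskWrapper.deserialize(serialized)      # raises pydantic.ValidationError on bad input
    assert restored.data == payload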
class TestTaskWrapper:
"""Test cases for TaskWrapper serialization/deserialization."""
def test_serialize_simple_data(self):
"""Test serialization of simple data types."""
data = {"key": "value", "number": 42, "list": [1, 2, 3]}
wrapper = TaskWrapper(data=data)
serialized = wrapper.serialize()
assert isinstance(serialized, str)
# Verify it's valid JSON
parsed = json.loads(serialized)
assert parsed["data"] == data
def test_serialize_complex_data(self):
"""Test serialization of complex nested data."""
data = {
"nested": {"deep": {"value": "test", "numbers": [1, 2, 3, 4, 5]}},
"unicode": "测试中文",
"special_chars": "!@#$%^&*()",
}
wrapper = TaskWrapper(data=data)
serialized = wrapper.serialize()
parsed = json.loads(serialized)
assert parsed["data"] == data
def test_deserialize_valid_data(self):
"""Test deserialization of valid JSON data."""
original_data = {"key": "value", "number": 42}
# Serialize using TaskWrapper to get the correct format
wrapper = TaskWrapper(data=original_data)
serialized = wrapper.serialize()
wrapper = TaskWrapper.deserialize(serialized)
assert wrapper.data == original_data
def test_deserialize_invalid_json(self):
"""Test deserialization handles invalid JSON gracefully."""
invalid_json = "{invalid json}"
# Pydantic will raise ValidationError for invalid JSON
with pytest.raises(ValidationError):
TaskWrapper.deserialize(invalid_json)
def test_serialize_ensure_ascii_false(self):
"""Test that serialization preserves Unicode characters."""
data = {"chinese": "中文测试", "emoji": "🚀"}
wrapper = TaskWrapper(data=data)
serialized = wrapper.serialize()
assert "中文测试" in serialized
assert "🚀" in serialized
class TestTenantIsolatedTaskQueue:
"""Test cases for TenantIsolatedTaskQueue functionality."""
@pytest.fixture
def mock_redis_client(self):
"""Mock Redis client for testing."""
mock_redis = MagicMock()
return mock_redis
@pytest.fixture
def sample_queue(self, mock_redis_client):
"""Create a sample TenantIsolatedTaskQueue instance."""
return TenantIsolatedTaskQueue("tenant-123", "test-key")
def test_initialization(self, sample_queue):
"""Test queue initialization with correct key generation."""
assert sample_queue._tenant_id == "tenant-123"
assert sample_queue._unique_key == "test-key"
assert sample_queue._queue == "tenant_self_test-key_task_queue:tenant-123"
assert sample_queue._task_key == "tenant_test-key_task:tenant-123"
@patch("core.rag.pipeline.queue.redis_client")
def test_get_task_key_exists(self, mock_redis, sample_queue):
"""Test getting task key when it exists."""
mock_redis.get.return_value = "1"
result = sample_queue.get_task_key()
assert result == "1"
mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123")
@patch("core.rag.pipeline.queue.redis_client")
def test_get_task_key_not_exists(self, mock_redis, sample_queue):
"""Test getting task key when it doesn't exist."""
mock_redis.get.return_value = None
result = sample_queue.get_task_key()
assert result is None
mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123")
@patch("core.rag.pipeline.queue.redis_client")
def test_set_task_waiting_time_default_ttl(self, mock_redis, sample_queue):
"""Test setting task waiting flag with default TTL."""
sample_queue.set_task_waiting_time()
mock_redis.setex.assert_called_once_with(
"tenant_test-key_task:tenant-123",
3600, # DEFAULT_TASK_TTL
1,
)
@patch("core.rag.pipeline.queue.redis_client")
def test_set_task_waiting_time_custom_ttl(self, mock_redis, sample_queue):
"""Test setting task waiting flag with custom TTL."""
custom_ttl = 1800
sample_queue.set_task_waiting_time(custom_ttl)
mock_redis.setex.assert_called_once_with("tenant_test-key_task:tenant-123", custom_ttl, 1)
@patch("core.rag.pipeline.queue.redis_client")
def test_delete_task_key(self, mock_redis, sample_queue):
"""Test deleting task key."""
sample_queue.delete_task_key()
mock_redis.delete.assert_called_once_with("tenant_test-key_task:tenant-123")
@patch("core.rag.pipeline.queue.redis_client")
def test_push_tasks_string_list(self, mock_redis, sample_queue):
"""Test pushing string tasks directly."""
tasks = ["task1", "task2", "task3"]
sample_queue.push_tasks(tasks)
mock_redis.lpush.assert_called_once_with(
"tenant_self_test-key_task_queue:tenant-123", "task1", "task2", "task3"
)
@patch("core.rag.pipeline.queue.redis_client")
def test_push_tasks_mixed_types(self, mock_redis, sample_queue):
"""Test pushing mixed string and object tasks."""
tasks = ["string_task", {"object_task": "data", "id": 123}, "another_string"]
sample_queue.push_tasks(tasks)
# Verify lpush was called
mock_redis.lpush.assert_called_once()
call_args = mock_redis.lpush.call_args
# Check queue name
assert call_args[0][0] == "tenant_self_test-key_task_queue:tenant-123"
# Check serialized tasks
serialized_tasks = call_args[0][1:]
assert len(serialized_tasks) == 3
assert serialized_tasks[0] == "string_task"
assert serialized_tasks[2] == "another_string"
# Check object task is serialized as TaskWrapper JSON (without prefix)
# It should be a valid JSON string that can be deserialized by TaskWrapper
wrapper = TaskWrapper.deserialize(serialized_tasks[1])
assert wrapper.data == {"object_task": "data", "id": 123}
@patch("core.rag.pipeline.queue.redis_client")
def test_push_tasks_empty_list(self, mock_redis, sample_queue):
"""Test pushing empty task list."""
sample_queue.push_tasks([])
mock_redis.lpush.assert_called_once_with("tenant_self_test-key_task_queue:tenant-123")
@patch("core.rag.pipeline.queue.redis_client")
def test_pull_tasks_default_count(self, mock_redis, sample_queue):
"""Test pulling tasks with default count (1)."""
mock_redis.rpop.side_effect = ["task1", None]
result = sample_queue.pull_tasks()
assert result == ["task1"]
assert mock_redis.rpop.call_count == 1
@patch("core.rag.pipeline.queue.redis_client")
def test_pull_tasks_custom_count(self, mock_redis, sample_queue):
"""Test pulling tasks with custom count."""
# First test: pull 3 tasks
mock_redis.rpop.side_effect = ["task1", "task2", "task3", None]
result = sample_queue.pull_tasks(3)
assert result == ["task1", "task2", "task3"]
assert mock_redis.rpop.call_count == 3
# Reset mock for second test
mock_redis.reset_mock()
mock_redis.rpop.side_effect = ["task1", "task2", None]
result = sample_queue.pull_tasks(3)
assert result == ["task1", "task2"]
assert mock_redis.rpop.call_count == 3
@patch("core.rag.pipeline.queue.redis_client")
def test_pull_tasks_zero_count(self, mock_redis, sample_queue):
"""Test pulling tasks with zero count returns empty list."""
result = sample_queue.pull_tasks(0)
assert result == []
mock_redis.rpop.assert_not_called()
@patch("core.rag.pipeline.queue.redis_client")
def test_pull_tasks_negative_count(self, mock_redis, sample_queue):
"""Test pulling tasks with negative count returns empty list."""
result = sample_queue.pull_tasks(-1)
assert result == []
mock_redis.rpop.assert_not_called()
@patch("core.rag.pipeline.queue.redis_client")
def test_pull_tasks_with_wrapped_objects(self, mock_redis, sample_queue):
"""Test pulling tasks that include wrapped objects."""
# Create a wrapped task
task_data = {"task_id": 123, "data": "test"}
wrapper = TaskWrapper(data=task_data)
wrapped_task = wrapper.serialize()
mock_redis.rpop.side_effect = [
"string_task",
wrapped_task.encode("utf-8"), # Simulate bytes from Redis
None,
]
result = sample_queue.pull_tasks(2)
assert len(result) == 2
assert result[0] == "string_task"
assert result[1] == {"task_id": 123, "data": "test"}
@patch("core.rag.pipeline.queue.redis_client")
def test_pull_tasks_with_invalid_wrapped_data(self, mock_redis, sample_queue):
"""Test pulling tasks with invalid JSON falls back to string."""
# Invalid JSON string that cannot be deserialized
invalid_json = "invalid json data"
mock_redis.rpop.side_effect = [invalid_json, None]
result = sample_queue.pull_tasks(1)
assert result == [invalid_json]
@patch("core.rag.pipeline.queue.redis_client")
def test_pull_tasks_bytes_decoding(self, mock_redis, sample_queue):
"""Test pulling tasks handles bytes from Redis correctly."""
mock_redis.rpop.side_effect = [
b"task1", # bytes
"task2", # string
None,
]
result = sample_queue.pull_tasks(2)
assert result == ["task1", "task2"]
@patch("core.rag.pipeline.queue.redis_client")
def test_complex_object_serialization_roundtrip(self, mock_redis, sample_queue):
"""Test complex object serialization and deserialization roundtrip."""
complex_task = {
"id": uuid4().hex,
"data": {"nested": {"deep": [1, 2, 3], "unicode": "测试中文", "special": "!@#$%^&*()"}},
"metadata": {"created_at": "2024-01-01T00:00:00Z", "tags": ["tag1", "tag2", "tag3"]},
}
# Push the complex task
sample_queue.push_tasks([complex_task])
# Verify it was serialized as TaskWrapper JSON
call_args = mock_redis.lpush.call_args
wrapped_task = call_args[0][1]
# Verify it's a valid TaskWrapper JSON (starts with {"data":)
assert wrapped_task.startswith('{"data":')
# Verify it can be deserialized
wrapper = TaskWrapper.deserialize(wrapped_task)
assert wrapper.data == complex_task
# Simulate pulling it back
mock_redis.rpop.return_value = wrapped_task
result = sample_queue.pull_tasks(1)
assert len(result) == 1
assert result[0] == complex_task

View File

@ -111,26 +111,3 @@ class TestVariablePoolGetAndNestedAttribute:
assert segment_false is not None
assert isinstance(segment_false, BooleanSegment)
assert segment_false.value is False
class TestVariablePoolGetNotModifyVariableDictionary:
_NODE_ID = "start"
_VAR_NAME = "name"
def test_convert_to_template_should_not_introduce_extra_keys(self):
pool = VariablePool.empty()
pool.add([self._NODE_ID, self._VAR_NAME], 0)
pool.convert_template("The start.name is {{#start.name#}}")
assert "The start" not in pool.variable_dictionary
def test_get_should_not_modify_variable_dictionary(self):
pool = VariablePool.empty()
pool.get([self._NODE_ID, self._VAR_NAME])
assert len(pool.variable_dictionary) == 1 # only contains `sys` node id
assert "start" not in pool.variable_dictionary
pool = VariablePool.empty()
pool.add([self._NODE_ID, self._VAR_NAME], "Joe")
pool.get([self._NODE_ID, "count"])
start_subdict = pool.variable_dictionary[self._NODE_ID]
assert "count" not in start_subdict

View File

@ -1,514 +0,0 @@
"""
Comprehensive unit tests for Redis broadcast channel implementation.
This test suite covers all aspects of the Redis broadcast channel including:
- Basic functionality and contract compliance
- Error handling and edge cases
- Thread safety and concurrency
- Resource management and cleanup
- Performance and reliability scenarios
"""
import dataclasses
import threading
import time
from collections.abc import Generator
from unittest.mock import MagicMock, patch
import pytest
from libs.broadcast_channel.exc import BroadcastChannelError, SubscriptionClosedError
from libs.broadcast_channel.redis.channel import (
BroadcastChannel as RedisBroadcastChannel,
)
from libs.broadcast_channel.redis.channel import (
Topic,
_RedisSubscription,
)
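# Editor's illustrative sketch (not part of the original module): the
# produce/subscribe flow this suite covers. publish() and the subscription's
# context-manager and iterator behaviour come straight from the tests below;
# the exact signature and return value of subscribe() are assumptions inferred
# from the Subscriber protocol check.
def _broadcast_channel_sketch(redis_client) -> None:
    topic = RedisBroadcastChannel(redis_client).topic("events")

    # Producer side: a thin wrapper over Redis PUBLISH on the topic name.
    topic.as_producer().publish(b"hello")

    # Subscriber side (assumed API): iterate raw payload bytes, with close()
    # handled by the context manager.
    with topic.as_subscriber().subscribe() as subscription:
        for payload in subscription:
            print(payload)
            break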
class TestBroadcastChannel:
"""Test cases for the main BroadcastChannel class."""
@pytest.fixture
def mock_redis_client(self) -> MagicMock:
"""Create a mock Redis client for testing."""
client = MagicMock()
client.pubsub.return_value = MagicMock()
return client
@pytest.fixture
def broadcast_channel(self, mock_redis_client: MagicMock) -> RedisBroadcastChannel:
"""Create a BroadcastChannel instance with mock Redis client."""
return RedisBroadcastChannel(mock_redis_client)
def test_topic_creation(self, broadcast_channel: RedisBroadcastChannel, mock_redis_client: MagicMock):
"""Test that topic() method returns a Topic instance with correct parameters."""
topic_name = "test-topic"
topic = broadcast_channel.topic(topic_name)
assert isinstance(topic, Topic)
assert topic._client == mock_redis_client
assert topic._topic == topic_name
def test_topic_isolation(self, broadcast_channel: RedisBroadcastChannel):
"""Test that different topic names create isolated Topic instances."""
topic1 = broadcast_channel.topic("topic1")
topic2 = broadcast_channel.topic("topic2")
assert topic1 is not topic2
assert topic1._topic == "topic1"
assert topic2._topic == "topic2"
class TestTopic:
"""Test cases for the Topic class."""
@pytest.fixture
def mock_redis_client(self) -> MagicMock:
"""Create a mock Redis client for testing."""
client = MagicMock()
client.pubsub.return_value = MagicMock()
return client
@pytest.fixture
def topic(self, mock_redis_client: MagicMock) -> Topic:
"""Create a Topic instance for testing."""
return Topic(mock_redis_client, "test-topic")
def test_as_producer_returns_self(self, topic: Topic):
"""Test that as_producer() returns self as Producer interface."""
producer = topic.as_producer()
assert producer is topic
# Producer is a Protocol, check duck typing instead
assert hasattr(producer, "publish")
def test_as_subscriber_returns_self(self, topic: Topic):
"""Test that as_subscriber() returns self as Subscriber interface."""
subscriber = topic.as_subscriber()
assert subscriber is topic
# Subscriber is a Protocol, check duck typing instead
assert hasattr(subscriber, "subscribe")
def test_publish_calls_redis_publish(self, topic: Topic, mock_redis_client: MagicMock):
"""Test that publish() calls Redis PUBLISH with correct parameters."""
payload = b"test message"
topic.publish(payload)
mock_redis_client.publish.assert_called_once_with("test-topic", payload)
@dataclasses.dataclass(frozen=True)
class SubscriptionTestCase:
"""Test case data for subscription tests."""
name: str
buffer_size: int
payload: bytes
expected_messages: list[bytes]
should_drop: bool = False
description: str = ""
class TestRedisSubscription:
"""Test cases for the _RedisSubscription class."""
@pytest.fixture
def mock_pubsub(self) -> MagicMock:
"""Create a mock PubSub instance for testing."""
pubsub = MagicMock()
pubsub.subscribe = MagicMock()
pubsub.unsubscribe = MagicMock()
pubsub.close = MagicMock()
pubsub.get_message = MagicMock()
return pubsub
@pytest.fixture
def subscription(self, mock_pubsub: MagicMock) -> Generator[_RedisSubscription, None, None]:
"""Create a _RedisSubscription instance for testing."""
subscription = _RedisSubscription(
pubsub=mock_pubsub,
topic="test-topic",
)
yield subscription
subscription.close()
@pytest.fixture
def started_subscription(self, subscription: _RedisSubscription) -> _RedisSubscription:
"""Create a subscription that has been started."""
subscription._start_if_needed()
return subscription
# ==================== Lifecycle Tests ====================
def test_subscription_initialization(self, mock_pubsub: MagicMock):
"""Test that subscription is properly initialized."""
subscription = _RedisSubscription(
pubsub=mock_pubsub,
topic="test-topic",
)
assert subscription._pubsub is mock_pubsub
assert subscription._topic == "test-topic"
assert not subscription._closed.is_set()
assert subscription._dropped_count == 0
assert subscription._listener_thread is None
assert not subscription._started
def test_start_if_needed_first_call(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
"""Test that _start_if_needed() properly starts subscription on first call."""
subscription._start_if_needed()
mock_pubsub.subscribe.assert_called_once_with("test-topic")
assert subscription._started is True
assert subscription._listener_thread is not None
def test_start_if_needed_subsequent_calls(self, started_subscription: _RedisSubscription):
"""Test that _start_if_needed() doesn't start subscription on subsequent calls."""
original_thread = started_subscription._listener_thread
started_subscription._start_if_needed()
# Should not create new thread or generator
assert started_subscription._listener_thread is original_thread
def test_start_if_needed_when_closed(self, subscription: _RedisSubscription):
"""Test that _start_if_needed() raises error when subscription is closed."""
subscription.close()
with pytest.raises(SubscriptionClosedError, match="The Redis subscription is closed"):
subscription._start_if_needed()
def test_start_if_needed_when_cleaned_up(self, subscription: _RedisSubscription):
"""Test that _start_if_needed() raises error when pubsub is None."""
subscription._pubsub = None
with pytest.raises(SubscriptionClosedError, match="The Redis subscription has been cleaned up"):
subscription._start_if_needed()
def test_context_manager_usage(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
"""Test that subscription works as context manager."""
with subscription as sub:
assert sub is subscription
assert subscription._started is True
mock_pubsub.subscribe.assert_called_once_with("test-topic")
def test_close_idempotent(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
"""Test that close() is idempotent and can be called multiple times."""
subscription._start_if_needed()
# Close multiple times
subscription.close()
subscription.close()
subscription.close()
# Should only cleanup once
mock_pubsub.unsubscribe.assert_called_once_with("test-topic")
mock_pubsub.close.assert_called_once()
assert subscription._pubsub is None
assert subscription._closed.is_set()
def test_close_cleanup(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
"""Test that close() properly cleans up all resources."""
subscription._start_if_needed()
thread = subscription._listener_thread
subscription.close()
# Verify cleanup
mock_pubsub.unsubscribe.assert_called_once_with("test-topic")
mock_pubsub.close.assert_called_once()
assert subscription._pubsub is None
assert subscription._listener_thread is None
# Wait for thread to finish (with timeout)
if thread and thread.is_alive():
thread.join(timeout=1.0)
assert not thread.is_alive()
# ==================== Message Processing Tests ====================
def test_message_iterator_with_messages(self, started_subscription: _RedisSubscription):
"""Test message iterator behavior with messages in queue."""
test_messages = [b"msg1", b"msg2", b"msg3"]
# Add messages to queue
for msg in test_messages:
started_subscription._queue.put_nowait(msg)
# Iterate through messages
iterator = iter(started_subscription)
received_messages = []
for msg in iterator:
received_messages.append(msg)
if len(received_messages) >= len(test_messages):
break
assert received_messages == test_messages
def test_message_iterator_when_closed(self, subscription: _RedisSubscription):
"""Test that iterator raises error when subscription is closed."""
subscription.close()
with pytest.raises(BroadcastChannelError, match="The Redis subscription is closed"):
iter(subscription)
# ==================== Message Enqueue Tests ====================
def test_enqueue_message_success(self, started_subscription: _RedisSubscription):
"""Test successful message enqueue."""
payload = b"test message"
started_subscription._enqueue_message(payload)
assert started_subscription._queue.qsize() == 1
assert started_subscription._queue.get_nowait() == payload
def test_enqueue_message_when_closed(self, subscription: _RedisSubscription):
"""Test message enqueue when subscription is closed."""
subscription.close()
payload = b"test message"
# Should not raise an exception, and should not enqueue the payload
subscription._enqueue_message(payload)
assert subscription._queue.empty()
def test_enqueue_message_with_full_queue(self, started_subscription: _RedisSubscription):
"""Test message enqueue with full queue (dropping behavior)."""
# Fill the queue
for i in range(started_subscription._queue.maxsize):
started_subscription._queue.put_nowait(f"old_msg_{i}".encode())
# Try to enqueue new message (should drop oldest)
new_message = b"new_message"
started_subscription._enqueue_message(new_message)
# Should have dropped one message and added new one
assert started_subscription._dropped_count == 1
# New message should be in queue
messages = []
while not started_subscription._queue.empty():
messages.append(started_subscription._queue.get_nowait())
assert new_message in messages
# ==================== Listener Thread Tests ====================
@patch("time.sleep", side_effect=lambda x: None) # Speed up test
def test_listener_thread_normal_operation(
self, mock_sleep, subscription: _RedisSubscription, mock_pubsub: MagicMock
):
"""Test listener thread normal operation."""
# Mock message from Redis
mock_message = {"type": "message", "channel": "test-topic", "data": b"test payload"}
mock_pubsub.get_message.return_value = mock_message
# Start listener
subscription._start_if_needed()
# Wait a bit for processing
time.sleep(0.1)
# Verify message was processed
assert not subscription._queue.empty()
assert subscription._queue.get_nowait() == b"test payload"
def test_listener_thread_ignores_subscribe_messages(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
"""Test that listener thread ignores subscribe/unsubscribe messages."""
mock_message = {"type": "subscribe", "channel": "test-topic", "data": 1}
mock_pubsub.get_message.return_value = mock_message
subscription._start_if_needed()
time.sleep(0.1)
# Should not enqueue subscribe messages
assert subscription._queue.empty()
def test_listener_thread_ignores_wrong_channel(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
"""Test that listener thread ignores messages from wrong channels."""
mock_message = {"type": "message", "channel": "wrong-topic", "data": b"test payload"}
mock_pubsub.get_message.return_value = mock_message
subscription._start_if_needed()
time.sleep(0.1)
# Should not enqueue messages from wrong channels
assert subscription._queue.empty()
def test_listener_thread_handles_redis_exceptions(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
"""Test that listener thread handles Redis exceptions gracefully."""
mock_pubsub.get_message.side_effect = Exception("Redis error")
subscription._start_if_needed()
# Wait for thread to handle exception
time.sleep(0.2)
# Thread object still exists, but the listener has exited after the unhandled error
assert subscription._listener_thread is not None
assert not subscription._listener_thread.is_alive()
def test_listener_thread_stops_when_closed(self, subscription: _RedisSubscription, mock_pubsub: MagicMock):
"""Test that listener thread stops when subscription is closed."""
subscription._start_if_needed()
thread = subscription._listener_thread
# Close subscription
subscription.close()
# Wait for thread to finish
if thread is not None and thread.is_alive():
thread.join(timeout=1.0)
assert thread is None or not thread.is_alive()
# ==================== Table-driven Tests ====================
@pytest.mark.parametrize(
"test_case",
[
SubscriptionTestCase(
name="basic_message",
buffer_size=5,
payload=b"hello world",
expected_messages=[b"hello world"],
description="Basic message publishing and receiving",
),
SubscriptionTestCase(
name="empty_message",
buffer_size=5,
payload=b"",
expected_messages=[b""],
description="Empty message handling",
),
SubscriptionTestCase(
name="large_message",
buffer_size=5,
payload=b"x" * 10000,
expected_messages=[b"x" * 10000],
description="Large message handling",
),
SubscriptionTestCase(
name="unicode_message",
buffer_size=5,
payload="你好世界".encode(),
expected_messages=["你好世界".encode()],
description="Unicode message handling",
),
],
)
def test_subscription_scenarios(self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock):
"""Test various subscription scenarios using table-driven approach."""
subscription = _RedisSubscription(
pubsub=mock_pubsub,
topic="test-topic",
)
# Simulate receiving message
mock_message = {"type": "message", "channel": "test-topic", "data": test_case.payload}
mock_pubsub.get_message.return_value = mock_message
try:
with subscription:
# Wait for message processing
time.sleep(0.1)
# Collect received messages
received = []
for msg in subscription:
received.append(msg)
if len(received) >= len(test_case.expected_messages):
break
assert received == test_case.expected_messages, f"Failed: {test_case.description}"
finally:
subscription.close()
def test_concurrent_close_and_enqueue(self, started_subscription: _RedisSubscription):
"""Test concurrent close and enqueue operations."""
errors = []
def close_subscription():
try:
time.sleep(0.05) # Small delay
started_subscription.close()
except Exception as e:
errors.append(e)
def enqueue_messages():
try:
for i in range(50):
started_subscription._enqueue_message(f"msg_{i}".encode())
time.sleep(0.001)
except Exception as e:
errors.append(e)
# Start threads
close_thread = threading.Thread(target=close_subscription)
enqueue_thread = threading.Thread(target=enqueue_messages)
close_thread.start()
enqueue_thread.start()
# Wait for completion
close_thread.join(timeout=2.0)
enqueue_thread.join(timeout=2.0)
# Should not have any errors (operations should be safe)
assert len(errors) == 0
# ==================== Error Handling Tests ====================
def test_iterator_after_close(self, subscription: _RedisSubscription):
"""Test iterator behavior after close."""
subscription.close()
with pytest.raises(SubscriptionClosedError, match="The Redis subscription is closed"):
iter(subscription)
def test_start_after_close(self, subscription: _RedisSubscription):
"""Test start attempts after close."""
subscription.close()
with pytest.raises(SubscriptionClosedError, match="The Redis subscription is closed"):
subscription._start_if_needed()
def test_pubsub_none_operations(self, subscription: _RedisSubscription):
"""Test operations when pubsub is None."""
subscription._pubsub = None
with pytest.raises(SubscriptionClosedError, match="The Redis subscription has been cleaned up"):
subscription._start_if_needed()
# Close should still work
subscription.close() # Should not raise
def test_channel_name_variations(self, mock_pubsub: MagicMock):
"""Test various channel name formats."""
channel_names = [
"simple",
"with-dashes",
"with_underscores",
"with.numbers",
"WITH.UPPERCASE",
"mixed-CASE_name",
"very.long.channel.name.with.multiple.parts",
]
for channel_name in channel_names:
subscription = _RedisSubscription(
pubsub=mock_pubsub,
topic=channel_name,
)
subscription._start_if_needed()
mock_pubsub.subscribe.assert_called_with(channel_name)
subscription.close()
def test_received_on_closed_subscription(self, subscription: _RedisSubscription):
subscription.close()
with pytest.raises(SubscriptionClosedError):
subscription.receive()
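For orientation, here is a minimal usage sketch of the subscription object these tests exercise. It assumes a running Redis instance and the redis-py client; the import path, topic name, and payload are illustrative, and the surrounding broadcast-channel wiring is omitted.

# Minimal sketch, assuming redis-py and an importable _RedisSubscription.
# The import path below is hypothetical; adjust it to the real module.
import redis

from core.broadcast_channel.redis import _RedisSubscription  # hypothetical path

client = redis.Redis(host="localhost", port=6379)

# Each subscription owns its own PubSub handle; it subscribes lazily on first use
# and unsubscribes/closes the handle when close() runs (or the with-block exits).
with _RedisSubscription(pubsub=client.pubsub(), topic="demo-topic") as subscription:
    client.publish("demo-topic", b"hello world")
    for payload in subscription:  # yields the raw bytes published on the topic
        print(payload)
        break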

View File

@ -1,317 +0,0 @@
from unittest.mock import Mock, patch
from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
class DocumentIndexingTaskProxyTestDataFactory:
"""Factory class for creating test data and mock objects for DocumentIndexingTaskProxy tests."""
@staticmethod
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
"""Create mock features with billing configuration."""
features = Mock()
features.billing = Mock()
features.billing.enabled = billing_enabled
features.billing.subscription = Mock()
features.billing.subscription.plan = plan
return features
@staticmethod
def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
"""Create mock TenantIsolatedTaskQueue."""
queue = Mock(spec=TenantIsolatedTaskQueue)
queue.get_task_key.return_value = "task_key" if has_task_key else None
queue.push_tasks = Mock()
queue.set_task_waiting_time = Mock()
return queue
@staticmethod
def create_document_task_proxy(
tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None
) -> DocumentIndexingTaskProxy:
"""Create DocumentIndexingTaskProxy instance for testing."""
if document_ids is None:
document_ids = ["doc-1", "doc-2", "doc-3"]
return DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
class TestDocumentIndexingTaskProxy:
"""Test cases for DocumentIndexingTaskProxy class."""
def test_initialization(self):
"""Test DocumentIndexingTaskProxy initialization."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = ["doc-1", "doc-2", "doc-3"]
# Act
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Assert
assert proxy._tenant_id == tenant_id
assert proxy._dataset_id == dataset_id
assert proxy._document_ids == document_ids
assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue)
assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id
assert proxy._tenant_isolated_task_queue._unique_key == "document_indexing"
@patch("services.document_indexing_task_proxy.FeatureService")
def test_features_property(self, mock_feature_service):
"""Test cached_property features."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features()
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
# Act
features1 = proxy.features
features2 = proxy.features # Second call should use cached property
# Assert
assert features1 == mock_features
assert features2 == mock_features
assert features1 is features2 # Should be the same instance due to caching
mock_feature_service.get_features.assert_called_once_with("tenant-123")
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_direct_queue(self, mock_task):
"""Test _send_to_direct_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
mock_task.delay = Mock()
# Act
proxy._send_to_direct_queue(mock_task)
# Assert
mock_task.delay.assert_called_once_with(
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
)
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
"""Test _send_to_tenant_queue when task key exists."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
has_task_key=True
)
mock_task.delay = Mock()
# Act
proxy._send_to_tenant_queue(mock_task)
# Assert
proxy._tenant_isolated_task_queue.push_tasks.assert_called_once()
pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0]
assert len(pushed_tasks) == 1
assert isinstance(DocumentTask(**pushed_tasks[0]), DocumentTask)
assert pushed_tasks[0]["tenant_id"] == "tenant-123"
assert pushed_tasks[0]["dataset_id"] == "dataset-456"
assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"]
mock_task.delay.assert_not_called()
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_tenant_queue_without_task_key(self, mock_task):
"""Test _send_to_tenant_queue when no task key exists."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._tenant_isolated_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
has_task_key=False
)
mock_task.delay = Mock()
# Act
proxy._send_to_tenant_queue(mock_task)
# Assert
proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once()
mock_task.delay.assert_called_once_with(
tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"]
)
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
@patch("services.document_indexing_task_proxy.normal_document_indexing_task")
def test_send_to_default_tenant_queue(self, mock_task):
"""Test _send_to_default_tenant_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_tenant_queue = Mock()
# Act
proxy._send_to_default_tenant_queue()
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
def test_send_to_priority_tenant_queue(self, mock_task):
"""Test _send_to_priority_tenant_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_tenant_queue = Mock()
# Act
proxy._send_to_priority_tenant_queue()
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
@patch("services.document_indexing_task_proxy.priority_document_indexing_task")
def test_send_to_priority_direct_queue(self, mock_task):
"""Test _send_to_priority_direct_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_direct_queue = Mock()
# Act
proxy._send_to_priority_direct_queue()
# Assert
proxy._send_to_direct_queue.assert_called_once_with(mock_task)
@patch("services.document_indexing_task_proxy.FeatureService")
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
"""Test _dispatch method when billing is enabled with sandbox plan."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.SANDBOX
)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_default_tenant_queue = Mock()
# Act
proxy._dispatch()
# Assert
proxy._send_to_default_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service):
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.TEAM
)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
# Act
proxy._dispatch()
# If billing is enabled with a non-sandbox plan, should send to the priority tenant queue
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
def test_dispatch_with_billing_disabled(self, mock_feature_service):
"""Test _dispatch method when billing is disabled."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_priority_direct_queue = Mock()
# Act
proxy._dispatch()
# If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue
proxy._send_to_priority_direct_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
def test_delay_method(self, mock_feature_service):
"""Test delay method integration."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.SANDBOX
)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_default_tenant_queue = Mock()
# Act
proxy.delay()
# Assert
# If billing enabled with sandbox plan, should send to default tenant queue
proxy._send_to_default_tenant_queue.assert_called_once()
def test_document_task_dataclass(self):
"""Test DocumentTask dataclass."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = ["doc-1", "doc-2"]
# Act
task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids)
# Assert
assert task.tenant_id == tenant_id
assert task.dataset_id == dataset_id
assert task.document_ids == document_ids
@patch("services.document_indexing_task_proxy.FeatureService")
def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
"""Test _dispatch method with empty plan string."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="")
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch("services.document_indexing_task_proxy.FeatureService")
def test_dispatch_edge_case_none_plan(self, mock_feature_service):
"""Test _dispatch method with None plan."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
def test_initialization_with_empty_document_ids(self):
"""Test initialization with empty document_ids list."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = []
# Act
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Assert
assert proxy._tenant_id == tenant_id
assert proxy._dataset_id == dataset_id
assert proxy._document_ids == document_ids
def test_initialization_with_single_document_id(self):
"""Test initialization with single document_id."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = ["doc-1"]
# Act
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Assert
assert proxy._tenant_id == tenant_id
assert proxy._dataset_id == dataset_id
assert proxy._document_ids == document_ids
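The removed tests above pin down a small routing decision. As a reference while reading them, here is a condensed sketch of that routing, reconstructed purely from the assertions; the real (now removed) DocumentIndexingTaskProxy._dispatch may have differed in detail.

# Condensed sketch of the dispatch routing the removed tests assert; not the actual implementation.
from enums.cloud_plan import CloudPlan

def dispatch(proxy) -> None:
    features = proxy.features  # cached per-tenant FeatureService lookup
    if features.billing.enabled:
        if features.billing.subscription.plan == CloudPlan.SANDBOX:
            # Sandbox tenants share the default tenant-isolated queue.
            proxy._send_to_default_tenant_queue()
        else:
            # Any non-sandbox plan (including an empty or None plan) goes to the priority tenant queue.
            proxy._send_to_priority_tenant_queue()
    else:
        # Billing disabled (self-hosted or enterprise): bypass tenant isolation entirely.
        proxy._send_to_priority_direct_queue()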

View File

@ -1,483 +0,0 @@
import json
from unittest.mock import Mock, patch
import pytest
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
class RagPipelineTaskProxyTestDataFactory:
"""Factory class for creating test data and mock objects for RagPipelineTaskProxy tests."""
@staticmethod
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
"""Create mock features with billing configuration."""
features = Mock()
features.billing = Mock()
features.billing.enabled = billing_enabled
features.billing.subscription = Mock()
features.billing.subscription.plan = plan
return features
@staticmethod
def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
"""Create mock TenantIsolatedTaskQueue."""
queue = Mock(spec=TenantIsolatedTaskQueue)
queue.get_task_key.return_value = "task_key" if has_task_key else None
queue.push_tasks = Mock()
queue.set_task_waiting_time = Mock()
return queue
@staticmethod
def create_rag_pipeline_invoke_entity(
pipeline_id: str = "pipeline-123",
user_id: str = "user-456",
tenant_id: str = "tenant-789",
workflow_id: str = "workflow-101",
streaming: bool = True,
workflow_execution_id: str | None = None,
workflow_thread_pool_id: str | None = None,
) -> RagPipelineInvokeEntity:
"""Create RagPipelineInvokeEntity instance for testing."""
return RagPipelineInvokeEntity(
pipeline_id=pipeline_id,
application_generate_entity={"key": "value"},
user_id=user_id,
tenant_id=tenant_id,
workflow_id=workflow_id,
streaming=streaming,
workflow_execution_id=workflow_execution_id,
workflow_thread_pool_id=workflow_thread_pool_id,
)
@staticmethod
def create_rag_pipeline_task_proxy(
dataset_tenant_id: str = "tenant-123",
user_id: str = "user-456",
rag_pipeline_invoke_entities: list[RagPipelineInvokeEntity] | None = None,
) -> RagPipelineTaskProxy:
"""Create RagPipelineTaskProxy instance for testing."""
if rag_pipeline_invoke_entities is None:
rag_pipeline_invoke_entities = [RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()]
return RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
@staticmethod
def create_mock_upload_file(file_id: str = "file-123") -> Mock:
"""Create mock upload file."""
upload_file = Mock()
upload_file.id = file_id
return upload_file
class TestRagPipelineTaskProxy:
"""Test cases for RagPipelineTaskProxy class."""
def test_initialization(self):
"""Test RagPipelineTaskProxy initialization."""
# Arrange
dataset_tenant_id = "tenant-123"
user_id = "user-456"
rag_pipeline_invoke_entities = [RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()]
# Act
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
# Assert
assert proxy._dataset_tenant_id == dataset_tenant_id
assert proxy._user_id == user_id
assert proxy._rag_pipeline_invoke_entities == rag_pipeline_invoke_entities
assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue)
assert proxy._tenant_isolated_task_queue._tenant_id == dataset_tenant_id
assert proxy._tenant_isolated_task_queue._unique_key == "pipeline"
def test_initialization_with_empty_entities(self):
"""Test initialization with empty rag_pipeline_invoke_entities."""
# Arrange
dataset_tenant_id = "tenant-123"
user_id = "user-456"
rag_pipeline_invoke_entities = []
# Act
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
# Assert
assert proxy._dataset_tenant_id == dataset_tenant_id
assert proxy._user_id == user_id
assert proxy._rag_pipeline_invoke_entities == []
def test_initialization_with_multiple_entities(self):
"""Test initialization with multiple rag_pipeline_invoke_entities."""
# Arrange
dataset_tenant_id = "tenant-123"
user_id = "user-456"
rag_pipeline_invoke_entities = [
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"),
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2"),
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-3"),
]
# Act
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
# Assert
assert len(proxy._rag_pipeline_invoke_entities) == 3
assert proxy._rag_pipeline_invoke_entities[0].pipeline_id == "pipeline-1"
assert proxy._rag_pipeline_invoke_entities[1].pipeline_id == "pipeline-2"
assert proxy._rag_pipeline_invoke_entities[2].pipeline_id == "pipeline-3"
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
def test_features_property(self, mock_feature_service):
"""Test cached_property features."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features()
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
# Act
features1 = proxy.features
features2 = proxy.features # Second call should use cached property
# Assert
assert features1 == mock_features
assert features2 == mock_features
assert features1 is features2 # Should be the same instance due to caching
mock_feature_service.get_features.assert_called_once_with("tenant-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_upload_invoke_entities(self, mock_db, mock_file_service_class):
"""Test _upload_invoke_entities method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
result = proxy._upload_invoke_entities()
# Assert
assert result == "file-123"
mock_file_service_class.assert_called_once_with(mock_db.engine)
# Verify upload_text was called with correct parameters
mock_file_service.upload_text.assert_called_once()
call_args = mock_file_service.upload_text.call_args
json_text, name, user_id, tenant_id = call_args[0]
assert name == "rag_pipeline_invoke_entities.json"
assert user_id == "user-456"
assert tenant_id == "tenant-123"
# Verify JSON content
parsed_json = json.loads(json_text)
assert len(parsed_json) == 1
assert parsed_json[0]["pipeline_id"] == "pipeline-123"
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_upload_invoke_entities_with_multiple_entities(self, mock_db, mock_file_service_class):
"""Test _upload_invoke_entities method with multiple entities."""
# Arrange
entities = [
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"),
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2"),
]
proxy = RagPipelineTaskProxy("tenant-123", "user-456", entities)
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-456")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
result = proxy._upload_invoke_entities()
# Assert
assert result == "file-456"
# Verify JSON content contains both entities
call_args = mock_file_service.upload_text.call_args
json_text = call_args[0][0]
parsed_json = json.loads(json_text)
assert len(parsed_json) == 2
assert parsed_json[0]["pipeline_id"] == "pipeline-1"
assert parsed_json[1]["pipeline_id"] == "pipeline-2"
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
def test_send_to_direct_queue(self, mock_task):
"""Test _send_to_direct_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue()
upload_file_id = "file-123"
mock_task.delay = Mock()
# Act
proxy._send_to_direct_queue(upload_file_id, mock_task)
# If sent to direct queue, tenant_isolated_task_queue should not be called
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
# Celery should be called directly
mock_task.delay.assert_called_once_with(
rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id="tenant-123"
)
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
"""Test _send_to_tenant_queue when task key exists."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue(
has_task_key=True
)
upload_file_id = "file-123"
mock_task.delay = Mock()
# Act
proxy._send_to_tenant_queue(upload_file_id, mock_task)
# If task key exists, should push tasks to the queue
proxy._tenant_isolated_task_queue.push_tasks.assert_called_once_with([upload_file_id])
# Celery should not be called directly
mock_task.delay.assert_not_called()
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
def test_send_to_tenant_queue_without_task_key(self, mock_task):
"""Test _send_to_tenant_queue when no task key exists."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._tenant_isolated_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue(
has_task_key=False
)
upload_file_id = "file-123"
mock_task.delay = Mock()
# Act
proxy._send_to_tenant_queue(upload_file_id, mock_task)
# If no task key, should set task waiting time key first
proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once()
mock_task.delay.assert_called_once_with(
rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id="tenant-123"
)
# The first task should be sent to celery directly, so push tasks should not be called
proxy._tenant_isolated_task_queue.push_tasks.assert_not_called()
@patch("services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task")
def test_send_to_default_tenant_queue(self, mock_task):
"""Test _send_to_default_tenant_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_tenant_queue = Mock()
upload_file_id = "file-123"
# Act
proxy._send_to_default_tenant_queue(upload_file_id)
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
@patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task")
def test_send_to_priority_tenant_queue(self, mock_task):
"""Test _send_to_priority_tenant_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_tenant_queue = Mock()
upload_file_id = "file-123"
# Act
proxy._send_to_priority_tenant_queue(upload_file_id)
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
@patch("services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task")
def test_send_to_priority_direct_queue(self, mock_task):
"""Test _send_to_priority_direct_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_direct_queue = Mock()
upload_file_id = "file-123"
# Act
proxy._send_to_priority_direct_queue(upload_file_id)
# Assert
proxy._send_to_direct_queue.assert_called_once_with(upload_file_id, mock_task)
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method when billing is enabled with sandbox plan."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.SANDBOX
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_default_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# If billing is enabled with sandbox plan, should send to default tenant queue
proxy._send_to_default_tenant_queue.assert_called_once_with("file-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_with_billing_enabled_non_sandbox_plan(
self, mock_db, mock_file_service_class, mock_feature_service
):
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.TEAM
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# If billing is enabled with non-sandbox plan, should send to priority tenant queue
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_with_billing_disabled(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method when billing is disabled."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=False)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_direct_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# If billing is disabled, for example: self-hosted or enterprise, should send to priority direct queue
proxy._send_to_priority_direct_queue.assert_called_once_with("file-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_with_empty_upload_file_id(self, mock_db, mock_file_service_class):
"""Test _dispatch method when upload_file_id is empty."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = Mock()
mock_upload_file.id = "" # Empty file ID
mock_file_service.upload_text.return_value = mock_upload_file
# Act & Assert
with pytest.raises(ValueError, match="upload_file_id is empty"):
proxy._dispatch()
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_edge_case_empty_plan(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method with empty plan string."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan="")
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_dispatch_edge_case_none_plan(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method with None plan."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(billing_enabled=True, plan=None)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FeatureService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.FileService")
@patch("services.rag_pipeline.rag_pipeline_task_proxy.db")
def test_delay_method(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test delay method integration."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=CloudPlan.SANDBOX
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._dispatch = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy.delay()
# Assert
proxy._dispatch.assert_called_once()
@patch("services.rag_pipeline.rag_pipeline_task_proxy.logger")
def test_delay_method_with_empty_entities(self, mock_logger):
"""Test delay method with empty rag_pipeline_invoke_entities."""
# Arrange
proxy = RagPipelineTaskProxy("tenant-123", "user-456", [])
# Act
proxy.delay()
# Assert
mock_logger.warning.assert_called_once_with(
"Received empty rag pipeline invoke entities, no tasks delivered: %s %s", "tenant-123", "user-456"
)

api/uv.lock (generated): 4597 changes

File diff suppressed because it is too large

View File

@ -0,0 +1,30 @@
id: remove-nullable-arg
language: python
rule:
pattern: $X = mapped_column($$$ARGS)
any:
- pattern: $X = mapped_column($$$BEFORE, sa.String, $$$MID, nullable=True, $$$AFTER)
- pattern: $X = mapped_column($$$BEFORE, sa.String, $$$MID, nullable=True)
rewriters:
- id: filter-string-nullable
rule:
pattern: $ARG
inside:
kind: argument_list
all:
- not:
pattern: String
- not:
pattern:
context: a(nullable=True)
selector: keyword_argument
fix: $ARG
transform:
NEWARGS:
rewrite:
rewriters: [filter-string-nullable]
source: $$$ARGS
joinBy: ', '
fix: |-
$X: Mapped[str | None] = mapped_column($NEWARGS)
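For readers unfamiliar with ast-grep rule syntax, here is a hedged before/after illustration of the rewrite this rule is meant to perform on a mapped column. The attribute name and extra arguments are made up, the exact handling of additional arguments depends on the filter-string-nullable rewriter, and Mapped is assumed to be imported from sqlalchemy.orm.

# Before (illustrative column on a SQLAlchemy model):
name = mapped_column(sa.String, nullable=True, index=True)

# After the rule rewrites it: the bare sa.String and nullable=True arguments are
# dropped from the call and the optional string type moves into the annotation.
name: Mapped[str | None] = mapped_column(index=True)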

View File

@ -7,4 +7,4 @@ cd "$SCRIPT_DIR/.."
uv --directory api run \
celery -A app.celery worker \
-P gevent -c 1 --loglevel INFO -Q dataset,priority_dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline
-P gevent -c 1 --loglevel INFO -Q dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline

View File

@ -492,7 +492,6 @@ VECTOR_INDEX_NAME_PREFIX=Vector_index
# The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
WEAVIATE_ENDPOINT=http://weaviate:8080
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
WEAVIATE_GRPC_ENDPOINT=grpc://weaviate:50051
# The Qdrant endpoint URL. Only available when VECTOR_STORE is `qdrant`.
QDRANT_URL=http://qdrant:6333
@ -1370,6 +1369,3 @@ ENABLE_CLEAN_MESSAGES=false
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
ENABLE_DATASETS_QUEUE_MONITOR=false
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
# Tenant isolated task queue configuration
TENANT_ISOLATED_TASK_CONCURRENCY=1

View File

@ -157,7 +157,6 @@ x-shared-env: &shared-api-worker-env
VECTOR_INDEX_NAME_PREFIX: ${VECTOR_INDEX_NAME_PREFIX:-Vector_index}
WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-http://weaviate:8080}
WEAVIATE_API_KEY: ${WEAVIATE_API_KEY:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih}
WEAVIATE_GRPC_ENDPOINT: ${WEAVIATE_GRPC_ENDPOINT:-grpc://weaviate:50051}
QDRANT_URL: ${QDRANT_URL:-http://qdrant:6333}
QDRANT_API_KEY: ${QDRANT_API_KEY:-difyai123456}
QDRANT_CLIENT_TIMEOUT: ${QDRANT_CLIENT_TIMEOUT:-20}
@ -614,7 +613,6 @@ x-shared-env: &shared-api-worker-env
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: ${ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK:-false}
ENABLE_DATASETS_QUEUE_MONITOR: ${ENABLE_DATASETS_QUEUE_MONITOR:-false}
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: ${ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK:-true}
TENANT_ISOLATED_TASK_CONCURRENCY: ${TENANT_ISOLATED_TASK_CONCURRENCY:-1}
services:
# API service

View File

@ -117,7 +117,7 @@ Tutte le offerte di Dify sono dotate di API corrispondenti, permettendovi di int
Avviate rapidamente Dify nel vostro ambiente con questa [guida di avvio rapido](#avvio-rapido). Utilizzate la nostra [documentazione](https://docs.dify.ai) per ulteriori informazioni e istruzioni dettagliate.
- **Dify per Aziende / Organizzazioni<br/>**
Offriamo funzionalità aggiuntive specifiche per le aziende. Potete [scriverci via email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) per discutere le vostre esigenze aziendali. <br/>
Offriamo funzionalità aggiuntive specifiche per le aziende. [Potete comunicarci le vostre domande tramite questo chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) o [inviateci un'email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) per discutere le vostre esigenze aziendali. <br/>
> Per startup e piccole imprese che utilizzano AWS, date un'occhiata a [Dify Premium su AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) e distribuitelo con un solo clic nel vostro AWS VPC. Si tratta di un'offerta AMI conveniente con l'opzione di creare app con logo e branding personalizzati.

View File

@ -91,7 +91,7 @@ Todas os recursos do Dify vêm com APIs correspondentes, permitindo que você in
Use nossa [documentação](https://docs.dify.ai) para referências adicionais e instruções mais detalhadas.
- **Dify para empresas/organizações</br>**
Oferecemos recursos adicionais voltados para empresas. Você pode [falar conosco por e-mail](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) para discutir necessidades empresariais. <br/>
Oferecemos recursos adicionais voltados para empresas. [Envie suas perguntas através deste chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) ou [envie-nos um e-mail](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) para discutir necessidades empresariais. </br>
> Para startups e pequenas empresas que utilizam AWS, confira o [Dify Premium no AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) e implemente no seu próprio AWS VPC com um clique. É uma oferta AMI acessível com a opção de criar aplicativos com logotipo e marca personalizados.

View File

@ -86,7 +86,7 @@ Tất cả các dịch vụ của Dify đều đi kèm với các API tương
Sử dụng [tài liệu](https://docs.dify.ai) của chúng tôi để tham khảo thêm và nhận hướng dẫn chi tiết hơn.
- **Dify cho doanh nghiệp / tổ chức</br>**
Chúng tôi cung cấp các tính năng bổ sung tập trung vào doanh nghiệp. [Gửi email cho chúng tôi](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) để thảo luận về nhu cầu doanh nghiệp. <br/>
Chúng tôi cung cấp các tính năng bổ sung tập trung vào doanh nghiệp. [Ghi lại câu hỏi của bạn cho chúng tôi thông qua chatbot này](https://udify.app/chat/22L1zSxg6yW1cWQg) hoặc [gửi email cho chúng tôi](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) để thảo luận về nhu cầu doanh nghiệp. </br>
> Đối với các công ty khởi nghiệp và doanh nghiệp nhỏ sử dụng AWS, hãy xem [Dify Premium trên AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) và triển khai nó vào AWS VPC của riêng bạn chỉ với một cú nhấp chuột. Đây là một AMI giá cả phải chăng với tùy chọn tạo ứng dụng với logo và thương hiệu tùy chỉnh.

View File

@ -44,32 +44,9 @@ fi
if $web_modified; then
echo "Running ESLint on web module"
if git diff --cached --quiet -- 'web/**/*.ts' 'web/**/*.tsx'; then
web_ts_modified=false
else
ts_diff_status=$?
if [ $ts_diff_status -eq 1 ]; then
web_ts_modified=true
else
echo "Unable to determine staged TypeScript changes (git exit code: $ts_diff_status)."
exit $ts_diff_status
fi
fi
cd ./web || exit 1
lint-staged
if $web_ts_modified; then
echo "Running TypeScript type-check"
if ! pnpm run type-check; then
echo "Type check failed. Please run 'pnpm run type-check' to fix the errors."
exit 1
fi
else
echo "No staged TypeScript changes detected, skipping type-check"
fi
echo "Running unit tests check"
modified_files=$(git diff --cached --name-only -- utils | grep -v '\.spec\.ts$' || true)

View File

@ -5,22 +5,15 @@ import quarterOfYear from 'dayjs/plugin/quarterOfYear'
import { useTranslation } from 'react-i18next'
import type { PeriodParams } from '@/app/components/app/overview/app-chart'
import { AvgResponseTime, AvgSessionInteractions, AvgUserInteractions, ConversationsChart, CostChart, EndUsersChart, MessagesChart, TokenPerSecond, UserSatisfactionRate, WorkflowCostChart, WorkflowDailyTerminalsChart, WorkflowMessagesChart } from '@/app/components/app/overview/app-chart'
import type { Item } from '@/app/components/base/select'
import { SimpleSelect } from '@/app/components/base/select'
import { TIME_PERIOD_MAPPING } from '@/app/components/app/log/filter'
import { useStore as useAppStore } from '@/app/components/app/store'
import TimeRangePicker from './time-range-picker'
import { TIME_PERIOD_MAPPING as LONG_TIME_PERIOD_MAPPING } from '@/app/components/app/log/filter'
import { IS_CLOUD_EDITION } from '@/config'
import LongTimeRangePicker from './long-time-range-picker'
dayjs.extend(quarterOfYear)
const today = dayjs()
const TIME_PERIOD_MAPPING = [
{ value: 0, name: 'today' },
{ value: 7, name: 'last7days' },
{ value: 30, name: 'last30days' },
]
const queryDateFormat = 'YYYY-MM-DD HH:mm'
export type IChartViewProps = {
@ -33,10 +26,21 @@ export default function ChartView({ appId, headerRight }: IChartViewProps) {
const appDetail = useAppStore(state => state.appDetail)
const isChatApp = appDetail?.mode !== 'completion' && appDetail?.mode !== 'workflow'
const isWorkflow = appDetail?.mode === 'workflow'
const [period, setPeriod] = useState<PeriodParams>(IS_CLOUD_EDITION
? { name: t('appLog.filter.period.today'), query: { start: today.startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } }
: { name: t('appLog.filter.period.last7days'), query: { start: today.subtract(7, 'day').startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } },
)
const [period, setPeriod] = useState<PeriodParams>({ name: t('appLog.filter.period.last7days'), query: { start: today.subtract(7, 'day').startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } })
const onSelect = (item: Item) => {
if (item.value === -1) {
setPeriod({ name: item.name, query: undefined })
}
else if (item.value === 0) {
const startOfToday = today.startOf('day').format(queryDateFormat)
const endOfToday = today.endOf('day').format(queryDateFormat)
setPeriod({ name: item.name, query: { start: startOfToday, end: endOfToday } })
}
else {
setPeriod({ name: item.name, query: { start: today.subtract(item.value as number, 'day').startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } })
}
}
if (!appDetail)
return null
@ -46,20 +50,20 @@ export default function ChartView({ appId, headerRight }: IChartViewProps) {
<div className='mb-4'>
<div className='system-xl-semibold mb-2 text-text-primary'>{t('common.appMenus.overview')}</div>
<div className='flex items-center justify-between'>
{IS_CLOUD_EDITION ? (
<TimeRangePicker
ranges={TIME_PERIOD_MAPPING}
onSelect={setPeriod}
queryDateFormat={queryDateFormat}
<div className='flex flex-row items-center'>
<SimpleSelect
items={Object.entries(TIME_PERIOD_MAPPING).map(([k, v]) => ({ value: k, name: t(`appLog.filter.period.${v.name}`) }))}
className='mt-0 !w-40'
notClearable={true}
onSelect={(item) => {
const id = item.value
const value = TIME_PERIOD_MAPPING[id]?.value ?? '-1'
const name = item.name || t('appLog.filter.period.allTime')
onSelect({ value, name })
}}
defaultValue={'2'}
/>
) : (
<LongTimeRangePicker
periodMapping={LONG_TIME_PERIOD_MAPPING}
onSelect={setPeriod}
queryDateFormat={queryDateFormat}
/>
)}
</div>
{headerRight}
</div>
</div>

View File

@ -1,63 +0,0 @@
'use client'
import type { PeriodParams } from '@/app/components/app/overview/app-chart'
import type { FC } from 'react'
import React from 'react'
import type { Item } from '@/app/components/base/select'
import { SimpleSelect } from '@/app/components/base/select'
import { useTranslation } from 'react-i18next'
import dayjs from 'dayjs'
type Props = {
periodMapping: { [key: string]: { value: number; name: string } }
onSelect: (payload: PeriodParams) => void
queryDateFormat: string
}
const today = dayjs()
const LongTimeRangePicker: FC<Props> = ({
periodMapping,
onSelect,
queryDateFormat,
}) => {
const { t } = useTranslation()
const handleSelect = React.useCallback((item: Item) => {
const id = item.value
const value = periodMapping[id]?.value ?? '-1'
const name = item.name || t('appLog.filter.period.allTime')
if (value === -1) {
onSelect({ name: t('appLog.filter.period.allTime'), query: undefined })
}
else if (value === 0) {
const startOfToday = today.startOf('day').format(queryDateFormat)
const endOfToday = today.endOf('day').format(queryDateFormat)
onSelect({
name,
query: {
start: startOfToday,
end: endOfToday,
},
})
}
else {
onSelect({
name,
query: {
start: today.subtract(value as number, 'day').startOf('day').format(queryDateFormat),
end: today.endOf('day').format(queryDateFormat),
},
})
}
}, [onSelect, periodMapping, queryDateFormat, t])
return (
<SimpleSelect
items={Object.entries(periodMapping).map(([k, v]) => ({ value: k, name: t(`appLog.filter.period.${v.name}`) }))}
className='mt-0 !w-40'
notClearable={true}
onSelect={handleSelect}
defaultValue={'2'}
/>
)
}
export default React.memo(LongTimeRangePicker)

View File

@ -1,80 +0,0 @@
'use client'
import { RiCalendarLine } from '@remixicon/react'
import type { Dayjs } from 'dayjs'
import type { FC } from 'react'
import React, { useCallback } from 'react'
import cn from '@/utils/classnames'
import { formatToLocalTime } from '@/utils/format'
import { useI18N } from '@/context/i18n'
import Picker from '@/app/components/base/date-and-time-picker/date-picker'
import type { TriggerProps } from '@/app/components/base/date-and-time-picker/types'
import { noop } from 'lodash-es'
import dayjs from 'dayjs'
type Props = {
start: Dayjs
end: Dayjs
onStartChange: (date?: Dayjs) => void
onEndChange: (date?: Dayjs) => void
}
const today = dayjs()
const DatePicker: FC<Props> = ({
start,
end,
onStartChange,
onEndChange,
}) => {
const { locale } = useI18N()
const renderDate = useCallback(({ value, handleClickTrigger, isOpen }: TriggerProps) => {
return (
<div className={cn('system-sm-regular flex h-7 cursor-pointer items-center rounded-lg px-1 text-components-input-text-filled hover:bg-state-base-hover', isOpen && 'bg-state-base-hover')} onClick={handleClickTrigger}>
{value ? formatToLocalTime(value, locale, 'MMM D') : ''}
</div>
)
}, [locale])
const availableStartDate = end.subtract(30, 'day')
const startDateDisabled = useCallback((date: Dayjs) => {
if (date.isAfter(today, 'date'))
return true
return !((date.isAfter(availableStartDate, 'date') || date.isSame(availableStartDate, 'date')) && (date.isBefore(end, 'date') || date.isSame(end, 'date')))
}, [availableStartDate, end])
const availableEndDate = start.add(30, 'day')
const endDateDisabled = useCallback((date: Dayjs) => {
if (date.isAfter(today, 'date'))
return true
return !((date.isAfter(start, 'date') || date.isSame(start, 'date')) && (date.isBefore(availableEndDate, 'date') || date.isSame(availableEndDate, 'date')))
}, [availableEndDate, start])
return (
<div className='flex h-8 items-center space-x-0.5 rounded-lg bg-components-input-bg-normal px-2'>
<div className='p-px'>
<RiCalendarLine className='size-3.5 text-text-tertiary' />
</div>
<Picker
value={start}
onChange={onStartChange}
renderTrigger={renderDate}
needTimePicker={false}
onClear={noop}
noConfirm
getIsDateDisabled={startDateDisabled}
/>
<span className='system-sm-regular text-text-tertiary'>-</span>
<Picker
value={end}
onChange={onEndChange}
renderTrigger={renderDate}
needTimePicker={false}
onClear={noop}
noConfirm
getIsDateDisabled={endDateDisabled}
/>
</div>
)
}
export default React.memo(DatePicker)

View File

@ -1,86 +0,0 @@
'use client'
import type { PeriodParams, PeriodParamsWithTimeRange } from '@/app/components/app/overview/app-chart'
import type { FC } from 'react'
import React, { useCallback, useState } from 'react'
import type { Dayjs } from 'dayjs'
import { HourglassShape } from '@/app/components/base/icons/src/vender/other'
import RangeSelector from './range-selector'
import DatePicker from './date-picker'
import dayjs from 'dayjs'
import { useI18N } from '@/context/i18n'
import { formatToLocalTime } from '@/utils/format'
const today = dayjs()
type Props = {
ranges: { value: number; name: string }[]
onSelect: (payload: PeriodParams) => void
queryDateFormat: string
}
const TimeRangePicker: FC<Props> = ({
ranges,
onSelect,
queryDateFormat,
}) => {
const { locale } = useI18N()
const [isCustomRange, setIsCustomRange] = useState(false)
const [start, setStart] = useState<Dayjs>(today)
const [end, setEnd] = useState<Dayjs>(today)
const handleRangeChange = useCallback((payload: PeriodParamsWithTimeRange) => {
setIsCustomRange(false)
setStart(payload.query!.start)
setEnd(payload.query!.end)
onSelect({
name: payload.name,
query: {
start: payload.query!.start.format(queryDateFormat),
end: payload.query!.end.format(queryDateFormat),
},
})
}, [onSelect, queryDateFormat])
const handleDateChange = useCallback((type: 'start' | 'end') => {
return (date?: Dayjs) => {
if (!date) return
if (type === 'start' && date.isSame(start)) return
if (type === 'end' && date.isSame(end)) return
if (type === 'start')
setStart(date)
else
setEnd(date)
const currStart = type === 'start' ? date : start
const currEnd = type === 'end' ? date : end
onSelect({
name: `${formatToLocalTime(currStart, locale, 'MMM D')} - ${formatToLocalTime(currEnd, locale, 'MMM D')}`,
query: {
start: currStart.format(queryDateFormat),
end: currEnd.format(queryDateFormat),
},
})
setIsCustomRange(true)
}
}, [start, end, onSelect, locale, queryDateFormat])
return (
<div className='flex items-center'>
<RangeSelector
isCustomRange={isCustomRange}
ranges={ranges}
onSelect={handleRangeChange}
/>
<HourglassShape className='h-3.5 w-2 text-components-input-bg-normal' />
<DatePicker
start={start}
end={end}
onStartChange={handleDateChange('start')}
onEndChange={handleDateChange('end')}
/>
</div>
)
}
export default React.memo(TimeRangePicker)

View File

@ -1,81 +0,0 @@
'use client'
import type { PeriodParamsWithTimeRange, TimeRange } from '@/app/components/app/overview/app-chart'
import type { FC } from 'react'
import React, { useCallback } from 'react'
import { SimpleSelect } from '@/app/components/base/select'
import type { Item } from '@/app/components/base/select'
import dayjs from 'dayjs'
import { RiArrowDownSLine, RiCheckLine } from '@remixicon/react'
import cn from '@/utils/classnames'
import { useTranslation } from 'react-i18next'
const today = dayjs()
type Props = {
isCustomRange: boolean
ranges: { value: number; name: string }[]
onSelect: (payload: PeriodParamsWithTimeRange) => void
}
const RangeSelector: FC<Props> = ({
isCustomRange,
ranges,
onSelect,
}) => {
const { t } = useTranslation()
const handleSelectRange = useCallback((item: Item) => {
const { name, value } = item
let period: TimeRange | null = null
if (value === 0) {
const startOfToday = today.startOf('day')
const endOfToday = today.endOf('day')
period = { start: startOfToday, end: endOfToday }
}
else {
period = { start: today.subtract(item.value as number, 'day').startOf('day'), end: today.endOf('day') }
}
onSelect({ query: period!, name })
}, [onSelect])
const renderTrigger = useCallback((item: Item | null, isOpen: boolean) => {
return (
<div className={cn('flex h-8 cursor-pointer items-center space-x-1.5 rounded-lg bg-components-input-bg-normal pl-3 pr-2', isOpen && 'bg-state-base-hover-alt')}>
<div className='system-sm-regular text-components-input-text-filled'>{isCustomRange ? t('appLog.filter.period.custom') : item?.name}</div>
<RiArrowDownSLine className={cn('size-4 text-text-quaternary', isOpen && 'text-text-secondary')} />
</div>
)
}, [isCustomRange])
const renderOption = useCallback(({ item, selected }: { item: Item; selected: boolean }) => {
return (
<>
{selected && (
<span
className={cn(
'absolute left-2 top-[9px] flex items-center text-text-accent',
)}
>
<RiCheckLine className="h-4 w-4" aria-hidden="true" />
</span>
)}
<span className={cn('system-md-regular block truncate')}>{item.name}</span>
</>
)
}, [])
return (
<SimpleSelect
items={ranges.map(v => ({ ...v, name: t(`appLog.filter.period.${v.name}`) }))}
className='mt-0 !w-40'
notClearable={true}
onSelect={handleSelectRange}
defaultValue={0}
wrapperClassName='h-8'
optionWrapClassName='w-[200px] translate-x-[-24px]'
renderTrigger={renderTrigger}
optionClassName='flex items-center py-0 pl-7 pr-2 h-8'
renderOption={renderOption}
/>
)
}
export default React.memo(RangeSelector)

View File

@ -10,10 +10,6 @@ import { ProviderContextProvider } from '@/context/provider-context'
import { ModalContextProvider } from '@/context/modal-context'
import GotoAnything from '@/app/components/goto-anything'
import Zendesk from '@/app/components/base/zendesk'
import Splash from '../components/splash'
import Test from '@edition/test'
import SubSubIndex from '@edition/sub/sub-sub/index'
import SubSub from '@edition/sub/sub-sub'
const Layout = ({ children }: { children: ReactNode }) => {
return (
@ -24,15 +20,11 @@ const Layout = ({ children }: { children: ReactNode }) => {
<EventEmitterContextProvider>
<ProviderContextProvider>
<ModalContextProvider>
<Test />
<SubSubIndex />
<SubSub />
<HeaderWrapper>
<Header />
</HeaderWrapper>
{children}
<GotoAnything />
<Splash />
</ModalContextProvider>
</ProviderContextProvider>
</EventEmitterContextProvider>

View File

@ -124,7 +124,7 @@ const AppOperations = ({ operations, gap }: {
<span className='system-xs-medium text-components-button-secondary-text'>{t('common.operation.more')}</span>
</Button>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className='z-[30]'>
<PortalToFollowElemContent className='z-[21]'>
<div className='flex min-w-[264px] flex-col rounded-[12px] border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-1 shadow-lg backdrop-blur-[5px]'>
{moreOperations.map(item => <div
key={item.id}

View File

@ -4,7 +4,6 @@ import React from 'react'
import ReactECharts from 'echarts-for-react'
import type { EChartsOption } from 'echarts'
import useSWR from 'swr'
import type { Dayjs } from 'dayjs'
import dayjs from 'dayjs'
import { get } from 'lodash-es'
import Decimal from 'decimal.js'
@ -79,16 +78,6 @@ export type PeriodParams = {
}
}
export type TimeRange = {
start: Dayjs
end: Dayjs
}
export type PeriodParamsWithTimeRange = {
name: string
query?: TimeRange
}
export type IBizChartProps = {
period: PeriodParams
id: string
@ -226,7 +215,9 @@ const Chart: React.FC<IChartProps> = ({
formatter(params) {
return `<div style='color:#6B7280;font-size:12px'>${params.name}</div>
<div style='font-size:14px;color:#1F2A37'>${valueFormatter((params.data as any)[yField])}
${!CHART_TYPE_CONFIG[chartType].showTokens ? '' : `<span style='font-size:12px'>
${!CHART_TYPE_CONFIG[chartType].showTokens
? ''
: `<span style='font-size:12px'>
<span style='margin-left:4px;color:#6B7280'>(</span>
<span style='color:#FF8A4C'>~$${get(params.data, 'total_price', 0)}</span>
<span style='color:#6B7280'>)</span>

View File

@ -49,7 +49,7 @@ const InputsFormContent = ({ showTip }: Props) => {
<div className='flex h-6 items-center gap-1'>
<div className='system-md-semibold text-text-secondary'>{form.label}</div>
{!form.required && (
<div className='system-xs-regular text-text-tertiary'>{t('workflow.panel.optional')}</div>
<div className='system-xs-regular text-text-tertiary'>{t('appDebug.variableTable.optional')}</div>
)}
</div>
)}

View File

@ -49,7 +49,7 @@ const InputsFormContent = ({ showTip }: Props) => {
<div className='flex h-6 items-center gap-1'>
<div className='system-md-semibold text-text-secondary'>{form.label}</div>
{!form.required && (
<div className='system-xs-regular text-text-tertiary'>{t('workflow.panel.optional')}</div>
<div className='system-xs-regular text-text-tertiary'>{t('appDebug.variableTable.optional')}</div>
)}
</div>
)}

View File

@ -8,10 +8,9 @@ const Calendar: FC<CalendarProps> = ({
selectedDate,
onDateClick,
wrapperClassName,
getIsDateDisabled,
}) => {
return <div className={wrapperClassName}>
<DaysOfWeek />
<DaysOfWeek/>
<div className='grid grid-cols-7 gap-0.5 p-2'>
{
days.map(day => <CalendarItem
@ -19,7 +18,6 @@ const Calendar: FC<CalendarProps> = ({
day={day}
selectedDate={selectedDate}
onClick={onDateClick}
isDisabled={getIsDateDisabled ? getIsDateDisabled(day.date) : false}
/>)
}
</div>

View File

@ -7,7 +7,6 @@ const Item: FC<CalendarItemProps> = ({
day,
selectedDate,
onClick,
isDisabled,
}) => {
const { date, isCurrentMonth } = day
const isSelected = selectedDate?.isSame(date, 'date')
@ -15,12 +14,11 @@ const Item: FC<CalendarItemProps> = ({
return (
<button type="button"
onClick={() => !isDisabled && onClick(date)}
onClick={() => onClick(date)}
className={cn(
'system-sm-medium relative flex items-center justify-center rounded-lg px-1 py-2',
isCurrentMonth ? 'text-text-secondary' : 'text-text-quaternary hover:text-text-secondary',
isSelected ? 'system-sm-medium bg-components-button-primary-bg text-components-button-primary-text' : 'hover:bg-state-base-hover',
isDisabled && 'cursor-not-allowed text-text-quaternary hover:bg-transparent',
)}
>
{date.date()}

View File

@ -36,8 +36,6 @@ const DatePicker = ({
renderTrigger,
triggerWrapClassName,
popupZIndexClassname = 'z-[11]',
noConfirm,
getIsDateDisabled,
}: DatePickerProps) => {
const { t } = useTranslation()
const [isOpen, setIsOpen] = useState(false)
@ -122,20 +120,11 @@ const DatePicker = ({
setCurrentDate(currentDate.clone().subtract(1, 'month'))
}, [currentDate])
const handleConfirmDate = useCallback((passedInSelectedDate?: Dayjs) => {
// passedInSelectedDate may be a click event when noConfirm is false
const nextDate = (dayjs.isDayjs(passedInSelectedDate) ? passedInSelectedDate : selectedDate)
onChange(nextDate ? nextDate.tz(timezone) : undefined)
setIsOpen(false)
}, [selectedDate, onChange, timezone])
const handleDateSelect = useCallback((day: Dayjs) => {
const newDate = cloneTime(day, selectedDate || getDateWithTimezone({ timezone }))
setCurrentDate(newDate)
setSelectedDate(newDate)
if (noConfirm)
handleConfirmDate(newDate)
}, [selectedDate, timezone, noConfirm, handleConfirmDate])
}, [selectedDate, timezone])
const handleSelectCurrentDate = () => {
const newDate = getDateWithTimezone({ timezone })
@ -145,6 +134,12 @@ const DatePicker = ({
setIsOpen(false)
}
const handleConfirmDate = () => {
// debugger
onChange(selectedDate ? selectedDate.tz(timezone) : undefined)
setIsOpen(false)
}
const handleClickTimePicker = () => {
if (view === ViewType.date) {
setView(ViewType.time)
@ -275,7 +270,6 @@ const DatePicker = ({
days={days}
selectedDate={selectedDate}
onDateClick={handleDateSelect}
getIsDateDisabled={getIsDateDisabled}
/>
) : view === ViewType.yearMonth ? (
<YearAndMonthPickerOptions
@ -296,7 +290,7 @@ const DatePicker = ({
{/* Footer */}
{
[ViewType.date, ViewType.time].includes(view) && !noConfirm && (
[ViewType.date, ViewType.time].includes(view) ? (
<DatePickerFooter
needTimePicker={needTimePicker}
displayTime={displayTime}
@ -305,10 +299,7 @@ const DatePicker = ({
handleSelectCurrentDate={handleSelectCurrentDate}
handleConfirmDate={handleConfirmDate}
/>
)
}
{
![ViewType.date, ViewType.time].includes(view) && (
) : (
<YearAndMonthPickerFooter
handleYearMonthCancel={handleYearMonthCancel}
handleYearMonthConfirm={handleYearMonthConfirm}

View File

@ -30,8 +30,6 @@ export type DatePickerProps = {
renderTrigger?: (props: TriggerProps) => React.ReactNode
minuteFilter?: (minutes: string[]) => string[]
popupZIndexClassname?: string
noConfirm?: boolean
getIsDateDisabled?: (date: Dayjs) => boolean
}
export type DatePickerHeaderProps = {
@ -82,14 +80,12 @@ export type CalendarProps = {
selectedDate: Dayjs | undefined
onDateClick: (date: Dayjs) => void
wrapperClassName?: string
getIsDateDisabled?: (date: Dayjs) => boolean
}
export type CalendarItemProps = {
day: Day
selectedDate: Dayjs | undefined
onClick: (date: Dayjs) => void
isDisabled: boolean
}
export type TimeOptionsProps = {

View File

@ -1,4 +1,4 @@
import React, { useCallback, useEffect, useMemo, useState } from 'react'
import React, { useCallback, useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useBoolean } from 'ahooks'
import { produce } from 'immer'
@ -45,13 +45,7 @@ const OpeningSettingModal = ({
const [isShowConfirmAddVar, { setTrue: showConfirmAddVar, setFalse: hideConfirmAddVar }] = useBoolean(false)
const [notIncludeKeys, setNotIncludeKeys] = useState<string[]>([])
const isSaveDisabled = useMemo(() => !tempValue.trim(), [tempValue])
const handleSave = useCallback((ignoreVariablesCheck?: boolean) => {
// Prevent saving if opening statement is empty
if (isSaveDisabled)
return
if (!ignoreVariablesCheck) {
const keys = getInputKeys(tempValue)
const promptKeys = promptVariables.map(item => item.key)
@ -81,7 +75,7 @@ const OpeningSettingModal = ({
}
})
onSave(newOpening)
}, [data, onSave, promptVariables, workflowVariables, showConfirmAddVar, tempSuggestedQuestions, tempValue, isSaveDisabled])
}, [data, onSave, promptVariables, workflowVariables, showConfirmAddVar, tempSuggestedQuestions, tempValue])
const cancelAutoAddVar = useCallback(() => {
hideConfirmAddVar()
@ -223,7 +217,6 @@ const OpeningSettingModal = ({
<Button
variant='primary'
onClick={() => handleSave()}
disabled={isSaveDisabled}
>
{t('common.operation.save')}
</Button>

View File

@ -1,3 +0,0 @@
<svg width="8" height="14" viewBox="0 0 8 14" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M8 14C8 11.7909 6.20914 10 4 10C1.79086 10 0 11.7909 0 14V0C8.05332e-08 2.20914 1.79086 4 4 4C6.20914 4 8 2.20914 8 0V14Z" fill="#C8CEDA" fill-opacity="1"/>
</svg>


View File

@ -1,27 +0,0 @@
{
"icon": {
"type": "element",
"isRootNode": true,
"name": "svg",
"attributes": {
"width": "8",
"height": "14",
"viewBox": "0 0 8 14",
"fill": "none",
"xmlns": "http://www.w3.org/2000/svg"
},
"children": [
{
"type": "element",
"name": "path",
"attributes": {
"d": "M8 14C8 11.7909 6.20914 10 4 10C1.79086 10 0 11.7909 0 14V0C8.05332e-08 2.20914 1.79086 4 4 4C6.20914 4 8 2.20914 8 0V14Z",
"fill": "currentColor",
"fill-opacity": "1"
},
"children": []
}
]
},
"name": "HourglassShape"
}

View File

@ -1,20 +0,0 @@
// GENERATED BY script
// DO NOT EDIT IT MANUALLY
import * as React from 'react'
import data from './HourglassShape.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconData } from '@/app/components/base/icons/IconBase'
const Icon = (
{
ref,
...props
}: React.SVGProps<SVGSVGElement> & {
ref?: React.RefObject<React.RefObject<HTMLOrSVGElement>>;
},
) => <IconBase {...props} ref={ref} data={data as IconData} />
Icon.displayName = 'HourglassShape'
export default Icon

View File

@ -1,7 +1,6 @@
export { default as AnthropicText } from './AnthropicText'
export { default as Generator } from './Generator'
export { default as Group } from './Group'
export { default as HourglassShape } from './HourglassShape'
export { default as Mcp } from './Mcp'
export { default as NoToolPlaceholder } from './NoToolPlaceholder'
export { default as Openai } from './Openai'

View File

@ -31,7 +31,7 @@ export type Item = {
export type ISelectProps = {
className?: string
wrapperClassName?: string
renderTrigger?: (value: Item | null, isOpen: boolean) => React.JSX.Element | null
renderTrigger?: (value: Item | null) => React.JSX.Element | null
items?: Item[]
defaultValue?: number | string
disabled?: boolean
@ -216,7 +216,7 @@ const SimpleSelect: FC<ISelectProps> = ({
>
{({ open }) => (
<div className={classNames('group/simple-select relative h-9', wrapperClassName)}>
{renderTrigger && <ListboxButton className='w-full'>{renderTrigger(selectedItem, open)}</ListboxButton>}
{renderTrigger && <ListboxButton className='w-full'>{renderTrigger(selectedItem)}</ListboxButton>}
{!renderTrigger && (
<ListboxButton onClick={() => {
onOpenChange?.(open)

View File

@ -74,8 +74,7 @@ Chat applications support session persistence, allowing previous chat history to
If set to `false`, the title can be generated asynchronously by calling the conversation rename API and setting `auto_generate` to `true`.
</Property>
<Property name='workflow_id' type='string' key='workflow_id'>
(Optional) Workflow ID to specify a specific version, if not provided, uses the default published version.<br/>
How to obtain: In the version history interface, click the copy icon on the right side of each version entry to copy the complete workflow ID.
(Optional) Workflow ID to specify a specific version, if not provided, uses the default published version.
</Property>
<Property name='trace_id' type='string' key='trace_id'>
(Optional) Trace ID. Used to integrate with the tracing components already in your business systems to achieve end-to-end distributed tracing. If not provided, the system will automatically generate a trace_id. It can be passed in any of the following three ways, listed in order of priority:<br/>
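To make these request parameters concrete, here is a minimal TypeScript sketch of a chat-messages call that passes `workflow_id` and `trace_id` in the request body. The base URL, API key handling, user identifier, and any field not named in the properties above are illustrative assumptions, not taken from this diff.

// Minimal sketch — base URL, env var name, and user id are assumptions for illustration.
const API_BASE = 'https://api.example.com/v1'   // hypothetical base URL
const API_KEY = process.env.APP_API_KEY ?? ''   // hypothetical env var holding the app API key

async function sendChatMessage(query: string): Promise<unknown> {
  const res = await fetch(`${API_BASE}/chat-messages`, {
    method: 'POST',
    headers: {
      'Authorization': `Bearer ${API_KEY}`,
      'Content-Type': 'application/json',
    },
    body: JSON.stringify({
      query,
      inputs: {},
      user: 'end-user-123',          // caller-defined end-user identifier (assumed)
      response_mode: 'blocking',
      workflow_id: '00000000-0000-0000-0000-000000000000', // optional: pin a specific published version (placeholder id)
      trace_id: 'trace-abc-123',     // optional: propagate your own trace id for distributed tracing
    }),
  })
  if (!res.ok)
    throw new Error(`chat-messages request failed: ${res.status}`)
  return res.json()
}

If `workflow_id` is omitted, the default published version is used, so pinning it is mainly useful when you need reproducible behavior across releases.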

View File

@ -74,8 +74,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
`false`に設定すると、会話のリネームAPIを呼び出し、`auto_generate`を`true`に設定することで非同期タイトル生成を実現できます。
</Property>
<Property name='workflow_id' type='string' key='workflow_id'>
オプションワークフローID、特定のバージョンを指定するために使用、提供されない場合はデフォルトの公開バージョンを使用。<br/>
取得方法バージョン履歴インターフェースで、各バージョンエントリの右側にあるコピーアイコンをクリックすると、完全なワークフローIDをコピーできます。
オプションワークフローID、特定のバージョンを指定するために使用、提供されない場合はデフォルトの公開バージョンを使用。
</Property>
<Property name='trace_id' type='string' key='trace_id'>
オプショントレースID。既存の業務システムのトレースコンポーネントと連携し、エンドツーエンドの分散トレーシングを実現するために使用します。指定がない場合、システムが自動的に trace_id を生成します。以下の3つの方法で渡すことができ、優先順位は次のとおりです<br/>

View File

@ -72,8 +72,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
(选填)自动生成标题,默认 `true`。 若设置为 `false`,则可通过调用会话重命名接口并设置 `auto_generate` 为 `true` 实现异步生成标题。
</Property>
<Property name='workflow_id' type='string' key='workflow_id'>
选填工作流ID用于指定特定版本如果不提供则使用默认的已发布版本。<br/>
获取方式:在版本历史界面,点击每个版本条目右侧的复制图标即可复制完整的工作流 ID。
选填工作流ID用于指定特定版本如果不提供则使用默认的已发布版本。
</Property>
<Property name='trace_id' type='string' key='trace_id'>
选填链路追踪ID。适用于与业务系统已有的trace组件打通实现端到端分布式追踪等场景。如果未指定系统会自动生成<code>trace_id</code>。支持以下三种方式传递,具体优先级依次为:<br/>

Some files were not shown because too many files have changed in this diff.