Mirror of https://github.com/langgenius/dify.git (synced 2026-02-12 22:35:46 +08:00)

Compare commits: refactor/w... → feat/hitl (12 commits)

| SHA1 |
|---|
| 3d0ff9463f |
| b893d2df82 |
| 79b6117d80 |
| d2ef434dec |
| aaf83c2b4c |
| d898bcff90 |
| b4cf146c85 |
| f21782a9a3 |
| e4455987e7 |
| b2ceb41dd6 |
| f614153f30 |
| 8ca020e179 |

`.github/workflows/deploy-hitl.yml` (vendored, 8 changed lines)
@@ -4,7 +4,8 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "build/feat/hitl"
- "feat/hitl-frontend"
- "feat/hitl-backend"
types:
- completed

@@ -13,7 +14,10 @@ jobs:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'build/feat/hitl'
(
github.event.workflow_run.head_branch == 'feat/hitl-frontend' ||
github.event.workflow_run.head_branch == 'feat/hitl-backend'
)
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v1

`.vscode/launch.json.template` (vendored, 2 changed lines)

@@ -37,7 +37,7 @@
"-c",
"1",
"-Q",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention",
"--loglevel",
"INFO"
],

@@ -553,8 +553,6 @@ WORKFLOW_LOG_CLEANUP_ENABLED=false
WORKFLOW_LOG_RETENTION_DAYS=30
# Batch size for workflow log cleanup operations (default: 100)
WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100
# Comma-separated list of workflow IDs to clean logs for
WORKFLOW_LOG_CLEANUP_SPECIFIC_WORKFLOW_IDS=

# App configuration
APP_MAX_EXECUTION_TIME=1200

@@ -717,7 +715,6 @@ ANNOTATION_IMPORT_MAX_CONCURRENT=5
# Sandbox expired records clean configuration
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL=200
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000

@@ -52,12 +52,14 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.graph_engine.manager -> extensions.ext_redis
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
# TODO(QuantumGhost): use DI to avoid depending on global DB.
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database

@@ -102,7 +104,10 @@ forbidden_modules =
core.schemas
core.tools
core.trigger
core.variables
ignore_imports =
core.workflow.nodes.agent.agent_node -> core.db.session_factory
core.workflow.nodes.agent.agent_node -> models.tools
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.workflow_entry -> core.app.workflow.layers.observability

@@ -123,6 +128,11 @@ ignore_imports =
core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.datasource.retrieval_service
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.dataset_retrieval
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> models.dataset
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> services.feature_service
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_runtime.model_providers.__base.large_language_model
core.workflow.nodes.llm.llm_utils -> configs
core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
core.workflow.nodes.llm.llm_utils -> core.file.models

@@ -143,6 +153,7 @@ ignore_imports =
core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform

@@ -158,6 +169,9 @@ ignore_imports =
core.workflow.workflow_entry -> core.app.workflow.node_factory
core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager
core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.agent_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.model_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_manager
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager

@@ -172,9 +186,11 @@ ignore_imports =
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.file.models
core.workflow.nodes.list_operator.node -> core.file
core.workflow.nodes.llm.file_saver -> core.file
core.workflow.nodes.llm.llm_utils -> core.variables.segments
core.workflow.nodes.llm.node -> core.file
core.workflow.nodes.llm.node -> core.file.file_manager
core.workflow.nodes.llm.node -> core.file.models
core.workflow.nodes.loop.entities -> core.variables.types
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.file
core.workflow.nodes.protocols -> core.file
core.workflow.nodes.question_classifier.question_classifier_node -> core.file.models

@@ -189,14 +205,12 @@ ignore_imports =
core.workflow.utils.condition.processor -> core.file.file_manager
core.workflow.workflow_entry -> core.file.models
core.workflow.workflow_type_encoder -> core.file.models
core.workflow.variables.segments -> core.file
core.workflow.variables.types -> core.file.models
core.workflow.variables.variables -> core.helper.encrypter
core.workflow.nodes.agent.agent_node -> models.model
core.workflow.nodes.code.code_node -> core.helper.code_executor.code_node_provider
core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider
core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider
core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor
core.workflow.nodes.datasource.datasource_node -> core.variables.variables
core.workflow.nodes.http_request.executor -> core.helper.ssrf_proxy
core.workflow.nodes.http_request.node -> core.helper.ssrf_proxy
core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy

@@ -206,6 +220,7 @@ ignore_imports =
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output
core.workflow.nodes.llm.node -> core.model_manager
core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.prompt.simple_prompt_transform
core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities

@@ -221,15 +236,67 @@ ignore_imports =
core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service
core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.llm.node -> models.dataset
core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer
core.workflow.nodes.llm.file_saver -> core.tools.signature
core.workflow.nodes.llm.file_saver -> core.tools.tool_file_manager
core.workflow.nodes.tool.tool_node -> core.tools.errors
core.workflow.conversation_variable_updater -> core.variables
core.workflow.graph_engine.entities.commands -> core.variables.variables
core.workflow.nodes.agent.agent_node -> core.variables.segments
core.workflow.nodes.answer.answer_node -> core.variables
core.workflow.nodes.code.code_node -> core.variables.segments
core.workflow.nodes.code.code_node -> core.variables.types
core.workflow.nodes.code.entities -> core.variables.types
core.workflow.nodes.datasource.datasource_node -> core.variables.segments
core.workflow.nodes.document_extractor.node -> core.variables
core.workflow.nodes.document_extractor.node -> core.variables.segments
core.workflow.nodes.http_request.executor -> core.variables.segments
core.workflow.nodes.http_request.node -> core.variables.segments
core.workflow.nodes.human_input.entities -> core.variables.consts
core.workflow.nodes.iteration.iteration_node -> core.variables
core.workflow.nodes.iteration.iteration_node -> core.variables.segments
core.workflow.nodes.iteration.iteration_node -> core.variables.variables
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables.segments
core.workflow.nodes.list_operator.node -> core.variables
core.workflow.nodes.list_operator.node -> core.variables.segments
core.workflow.nodes.llm.node -> core.variables
core.workflow.nodes.loop.loop_node -> core.variables
core.workflow.nodes.parameter_extractor.entities -> core.variables.types
core.workflow.nodes.parameter_extractor.exc -> core.variables.types
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.variables.types
core.workflow.nodes.tool.tool_node -> core.variables.segments
core.workflow.nodes.tool.tool_node -> core.variables.variables
core.workflow.nodes.trigger_webhook.node -> core.variables.types
core.workflow.nodes.trigger_webhook.node -> core.variables.variables
core.workflow.nodes.variable_aggregator.entities -> core.variables.types
core.workflow.nodes.variable_aggregator.variable_aggregator_node -> core.variables.segments
core.workflow.nodes.variable_assigner.common.helpers -> core.variables
core.workflow.nodes.variable_assigner.common.helpers -> core.variables.consts
core.workflow.nodes.variable_assigner.common.helpers -> core.variables.types
core.workflow.nodes.variable_assigner.v1.node -> core.variables
core.workflow.nodes.variable_assigner.v2.helpers -> core.variables
core.workflow.nodes.variable_assigner.v2.node -> core.variables
core.workflow.nodes.variable_assigner.v2.node -> core.variables.consts
core.workflow.runtime.graph_runtime_state_protocol -> core.variables.segments
core.workflow.runtime.read_only_wrappers -> core.variables.segments
core.workflow.runtime.variable_pool -> core.variables
core.workflow.runtime.variable_pool -> core.variables.consts
core.workflow.runtime.variable_pool -> core.variables.segments
core.workflow.runtime.variable_pool -> core.variables.variables
core.workflow.utils.condition.processor -> core.variables
core.workflow.utils.condition.processor -> core.variables.segments
core.workflow.variable_loader -> core.variables
core.workflow.variable_loader -> core.variables.consts
core.workflow.workflow_type_encoder -> core.variables
core.workflow.graph_engine.manager -> extensions.ext_redis
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database

@@ -286,6 +353,7 @@ forbidden_modules =
core.schemas
core.tools
core.trigger
core.variables
core.workflow
ignore_imports =
core.model_runtime.model_providers.__base.ai_model -> configs

`api/.vscode/launch.json.example` (vendored, 2 changed lines)

@@ -54,7 +54,7 @@
"--loglevel",
"DEBUG",
"-Q",
"dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,workflow_based_app_execution,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
"dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
]
}
]

@@ -122,7 +122,7 @@ These commands assume you start from the repository root.

```bash
cd api
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
```

1. Optional: start Celery Beat (scheduled tasks, in a new terminal).

@@ -1180,16 +1180,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
default=0,
)

# API token last_used_at batch update
ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK: bool = Field(
description="Enable periodic batch update of API token last_used_at timestamps",
default=True,
)
API_TOKEN_LAST_USED_UPDATE_INTERVAL: int = Field(
description="Interval in minutes for batch updating API token last_used_at (default 30)",
default=30,
)

# Trigger provider refresh (simple version)
ENABLE_TRIGGER_PROVIDER_REFRESH_TASK: bool = Field(
description="Enable trigger provider refresh poller",

@@ -1314,9 +1304,6 @@ class WorkflowLogConfig(BaseSettings):
WORKFLOW_LOG_CLEANUP_BATCH_SIZE: int = Field(
default=100, description="Batch size for workflow run log cleanup operations"
)
WORKFLOW_LOG_CLEANUP_SPECIFIC_WORKFLOW_IDS: str = Field(
default="", description="Comma-separated list of workflow IDs to clean logs for"
)


class SwaggerUIConfig(BaseSettings):

@@ -1347,10 +1334,6 @@ class SandboxExpiredRecordsCleanConfig(BaseSettings):
description="Maximum number of records to process in each batch",
default=1000,
)
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL: PositiveInt = Field(
description="Maximum interval in milliseconds between batches",
default=200,
)
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
description="Retention days for sandbox expired workflow_run records and message records",
default=30,

@@ -259,20 +259,11 @@ class CeleryConfig(DatabaseConfig):
description="Password of the Redis Sentinel master.",
default=None,
)

CELERY_SENTINEL_SOCKET_TIMEOUT: PositiveFloat | None = Field(
description="Timeout for Redis Sentinel socket operations in seconds.",
default=0.1,
)

CELERY_TASK_ANNOTATIONS: dict[str, Any] | None = Field(
description=(
"Annotations for Celery tasks as a JSON mapping of task name -> options "
"(for example, rate limits or other task-specific settings)."
),
default=None,
)

@computed_field
def CELERY_RESULT_BACKEND(self) -> str | None:
if self.CELERY_BACKEND in ("database", "rabbitmq"):

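The `CELERY_TASK_ANNOTATIONS` setting above is fed into Celery's standard `task_annotations` option, which maps task names to per-task overrides. A minimal illustration of how Celery itself consumes such a mapping (the broker URL and task name here are made up, not Dify's):

```python
from celery import Celery

app = Celery("example", broker="redis://localhost:6379/0")

# task_annotations overrides task options by name; "*" applies to every task.
app.conf.task_annotations = {
    "tasks.send_mail": {"rate_limit": "10/m"},  # at most 10 runs per minute
    "*": {"acks_late": True},                   # applied to all tasks
}


@app.task(name="tasks.send_mail")
def send_mail(to: str) -> str:
    # Placeholder body, for illustration only.
    return f"queued mail to {to}"
```
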
@@ -21,7 +21,6 @@ language_timezone_mapping = {
"th-TH": "Asia/Bangkok",
"id-ID": "Asia/Jakarta",
"ar-TN": "Africa/Tunis",
"nl-NL": "Europe/Amsterdam",
}

languages = list(language_timezone_mapping.keys())

@@ -5,6 +5,8 @@ from enum import StrEnum
from flask_restx import Namespace
from pydantic import BaseModel, TypeAdapter

from controllers.console import console_ns

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

@@ -22,9 +24,6 @@ def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> No
def get_or_create_model(model_name: str, field_def):
# Import lazily to avoid circular imports between console controllers and schema helpers.
from controllers.console import console_ns

existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)

@@ -10,7 +10,6 @@ from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.model import ApiToken, App
from services.api_token_service import ApiTokenCache

from . import console_ns
from .wraps import account_initialization_required, edit_permission_required, setup_required

@@ -132,11 +131,6 @@ class BaseApiKeyResource(Resource):
if key is None:
flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found")

# Invalidate cache before deleting from database
# Type assertion: key is guaranteed to be non-None here because abort() raises
assert key is not None # nosec - for type checker only
ApiTokenCache.delete(key.token, key.type)

db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()

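Both API-key delete handlers in this compare invalidate the token cache before touching the database row. A minimal sketch of that ordering, with `cache`, `session`, and `key` as stand-ins for `ApiTokenCache`, the SQLAlchemy session, and the `ApiToken` row (assumed interfaces, not Dify's exact API):

```python
def delete_api_key(cache, session, key) -> None:
    """Drop the cached token first so a stale cache entry cannot keep
    authorizing requests after the row is gone."""
    cache.delete(key.token, key.type)  # 1. invalidate the cache entry
    session.delete(key)                # 2. then remove the database row
    session.commit()
```
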
@@ -599,12 +599,7 @@ def _get_conversation(app_model, conversation_id):
db.session.execute(
sa.update(Conversation)
.where(Conversation.id == conversation_id, Conversation.read_at.is_(None))
# Keep updated_at unchanged when only marking a conversation as read.
.values(
read_at=naive_utc_now(),
read_account_id=current_user.id,
updated_at=Conversation.updated_at,
)
.values(read_at=naive_utc_now(), read_account_id=current_user.id)
)
db.session.commit()
db.session.refresh(conversation)

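The two variants of this hunk differ only in whether `updated_at` is pinned while marking a conversation as read. A hedged sketch of the pinning technique against a toy model (not the project's `Conversation`): assigning the column to itself keeps its current value and prevents any `onupdate` default from bumping it.

```python
import datetime

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Conversation(Base):
    __tablename__ = "toy_conversation"
    id: Mapped[int] = mapped_column(primary_key=True)
    read_at: Mapped[datetime.datetime | None]
    # onupdate would normally refresh this column on every UPDATE.
    updated_at: Mapped[datetime.datetime | None] = mapped_column(onupdate=datetime.datetime.utcnow)


def mark_read(session, conversation_id: int) -> None:
    session.execute(
        sa.update(Conversation)
        .where(Conversation.id == conversation_id, Conversation.read_at.is_(None))
        # Setting updated_at to its own column expression leaves the stored
        # value untouched even though an onupdate default is configured.
        .values(read_at=datetime.datetime.utcnow(), updated_at=Conversation.updated_at)
    )
    session.commit()
```
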
@@ -16,10 +16,10 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.file import helpers as file_helpers
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.variables.segment_group import SegmentGroup
from core.workflow.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.variables.types import SegmentType
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type

@@ -463,9 +463,8 @@ class WorkflowRunNodeExecutionListApi(Resource):
class ConsoleWorkflowPauseDetailsApi(Resource):
"""Console API for getting workflow pause details."""

@setup_required
@login_required
@account_initialization_required
@login_required
def get(self, workflow_run_id: str):
"""
Get workflow pause details.

@@ -478,14 +477,10 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
# Query WorkflowRun to determine if workflow is suspended
session_maker = sessionmaker(bind=db.engine)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_maker)

workflow_run = db.session.get(WorkflowRun, workflow_run_id)
if not workflow_run:
raise NotFoundError("Workflow run not found")

if workflow_run.tenant_id != current_user.current_tenant_id:
raise NotFoundError("Workflow run not found")

# Check if workflow is suspended
is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED
if not is_paused:

@@ -55,7 +55,6 @@ from libs.login import current_account_with_tenant, login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.api_token_service import ApiTokenCache
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService

# Register models for flask_restx to avoid dict type issues in Swagger

@@ -821,11 +820,6 @@ class DatasetApiDeleteApi(Resource):
if key is None:
console_ns.abort(404, message="API key not found")

# Invalidate cache before deleting from database
# Type assertion: key is guaranteed to be non-None here because abort() raises
assert key is not None # nosec - for type checker only
ApiTokenCache.delete(key.token, key.type)

db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()

@@ -21,8 +21,8 @@ from controllers.console.app.workflow_draft_variable import (
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.variables.types import SegmentType
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type

@@ -1,7 +1,6 @@
import urllib.parse

import httpx
from flask_restx import Resource
from pydantic import BaseModel, Field

import services

@@ -11,12 +10,12 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
from controllers.console import console_ns
from controllers.fastopenapi import console_router
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from libs.login import current_account_with_tenant, login_required
from libs.login import current_account_with_tenant
from services.file_service import FileService

@@ -24,73 +23,69 @@ class RemoteFileUploadPayload(BaseModel):
url: str = Field(..., description="URL to fetch")

@console_ns.route("/remote-files/<path:url>")
class GetRemoteFileInfo(Resource):
@login_required
def get(self, url: str):
decoded_url = urllib.parse.unquote(url)
resp = ssrf_proxy.head(decoded_url)
@console_router.get(
"/remote-files/<path:url>",
response_model=RemoteFileInfo,
tags=["console"],
)
def get_remote_file_info(url: str) -> RemoteFileInfo:
decoded_url = urllib.parse.unquote(url)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
return RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
file_length=int(resp.headers.get("Content-Length", 0)),
)

@console_router.post(
"/remote-files/upload",
response_model=FileWithSignedUrl,
tags=["console"],
status_code=201,
)
def upload_remote_file(payload: RemoteFileUploadPayload) -> FileWithSignedUrl:
url = payload.url

try:
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
return RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
file_length=int(resp.headers.get("Content-Length", 0)),
).model_dump(mode="json")
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as e:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")

file_info = helpers.guess_file_info_from_response(resp)

@console_ns.route("/remote-files/upload")
class RemoteFileUpload(Resource):
@login_required
def post(self):
payload = RemoteFileUploadPayload.model_validate(console_ns.payload)
url = payload.url
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
raise FileTooLargeError

# Try to fetch remote file metadata/content first
try:
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
# Normalize into a user-friendly error message expected by tests
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as e:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content

file_info = helpers.guess_file_info_from_response(resp)

# Enforce file size limit with 400 (Bad Request) per tests' expectation
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
raise FileTooLargeError()

# Load content if needed
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content

try:
user, _ = current_account_with_tenant()
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,
user=user,
source_url=url,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()

# Success: return created resource with 201 status
return (
FileWithSignedUrl(
id=upload_file.id,
name=upload_file.name,
size=upload_file.size,
extension=upload_file.extension,
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
mime_type=upload_file.mime_type,
created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()),
).model_dump(mode="json"),
201,
try:
user, _ = current_account_with_tenant()
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,
user=user,
source_url=url,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()

return FileWithSignedUrl(
id=upload_file.id,
name=upload_file.name,
size=upload_file.size,
extension=upload_file.extension,
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
mime_type=upload_file.mime_type,
created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()),
)

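Both versions of the remote-file endpoints probe the URL with a HEAD request and fall back to a short GET when HEAD is not answered with 200. A standalone sketch of that probe using plain `httpx` (the project's code routes these calls through its SSRF proxy helper and raises `RemoteFileUploadError` instead):

```python
import httpx


def probe_remote_file(url: str) -> tuple[str, int]:
    """Return (content type, content length) for a remote file.

    Illustrative only; error handling and proxying are simplified.
    """
    resp = httpx.head(url, follow_redirects=True, timeout=3)
    if resp.status_code != httpx.codes.OK:
        # Some servers reject HEAD; retry with a short GET instead.
        resp = httpx.get(url, follow_redirects=True, timeout=3)
        resp.raise_for_status()
    return (
        resp.headers.get("Content-Type", "application/octet-stream"),
        int(resp.headers.get("Content-Length", 0)),
    )
```
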
@@ -42,15 +42,7 @@ class SetupResponse(BaseModel):
tags=["console"],
)
def get_setup_status_api() -> SetupStatusResponse:
"""Get system setup status.

NOTE: This endpoint is unauthenticated by design.

During first-time bootstrap there is no admin account yet, so frontend initialization must be
able to query setup progress before any login flow exists.

Only bootstrap-safe status information should be returned by this endpoint.
"""
"""Get system setup status."""
if dify_config.EDITION == "SELF_HOSTED":
setup_status = get_setup_status()
if setup_status and not isinstance(setup_status, bool):

@@ -69,12 +61,7 @@ def get_setup_status_api() -> SetupStatusResponse:
)
@only_edition_self_hosted
def setup_system(payload: SetupRequestPayload) -> SetupResponse:
"""Initialize system setup with admin account.

NOTE: This endpoint is unauthenticated by design for first-time bootstrap.
Access is restricted by deployment mode (`SELF_HOSTED`), one-time setup guards,
and init-password validation rather than user session authentication.
"""
"""Initialize system setup with admin account."""
if get_setup_status():
raise AlreadySetupError()

@@ -120,7 +120,7 @@ class TagUpdateDeleteApi(Resource):

TagService.delete_tag(tag_id)

return "", 204
return 204


@console_ns.route("/tag-bindings/create")

@@ -34,8 +34,6 @@ from .dataset import (
metadata,
segment,
)
from .dataset.rag_pipeline import rag_pipeline_workflow
from .end_user import end_user
from .workspace import models

__all__ = [

@@ -46,7 +44,6 @@ __all__ = [
"conversation",
"dataset",
"document",
"end_user",
"file",
"file_preview",
"hit_testing",

@@ -54,7 +51,6 @@ __all__ = [
"message",
"metadata",
"models",
"rag_pipeline_workflow",
"segment",
"site",
"workflow",

@@ -396,7 +396,7 @@ class DatasetApi(DatasetApiResource):
try:
if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
return "", 204
return 204
else:
raise NotFound("Dataset not found.")
except services.errors.dataset.DatasetInUseError:

@@ -557,7 +557,7 @@ class DatasetTagsApi(DatasetApiResource):
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag(payload.tag_id)

return "", 204
return 204


@service_api_ns.route("/datasets/tags/binding")

@@ -581,7 +581,7 @@ class DatasetTagBindingApi(DatasetApiResource):
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"})

return "", 204
return 204


@service_api_ns.route("/datasets/tags/unbinding")

@@ -605,7 +605,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"})

return "", 204
return 204


@service_api_ns.route("/datasets/<uuid:dataset_id>/tags")

@@ -746,4 +746,4 @@ class DocumentApi(DatasetApiResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")

return "", 204
return 204

@@ -128,7 +128,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user)

MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
return "", 204
return 204


@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in")

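Many delete endpoints in this compare move between `return "", 204` and `return 204`. As a point of reference (toy Flask view, not project code), the tuple form is what reliably produces an empty-body 204 in plain Flask; how a bare integer is treated depends on the framework layer on top (flask-restx typically serializes whatever value is returned as the response body), so the two spellings are not interchangeable.

```python
from flask import Flask

app = Flask(__name__)


@app.route("/things/<int:thing_id>", methods=["DELETE"])
def delete_thing(thing_id: int):
    # ... delete the record here ...
    # The (body, status) tuple sends an empty body with an explicit 204 status.
    return "", 204
```
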
@@ -1,3 +1,5 @@
import string
import uuid
from collections.abc import Generator
from typing import Any

@@ -10,7 +12,6 @@ from controllers.common.errors import FilenameNotExistsError, NoFileUploadedErro
from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import PipelineRunError
from controllers.service_api.dataset.rag_pipeline.serializers import serialize_upload_file
from controllers.service_api.wraps import DatasetApiResource
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.entities.app_invoke_entities import InvokeFrom

@@ -40,7 +41,7 @@ register_schema_model(service_api_ns, DatasourceNodeRunPayload)
register_schema_model(service_api_ns, PipelineRunApiEntity)


@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource-plugins")
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins")
class DatasourcePluginsApi(DatasetApiResource):
"""Resource for datasource plugins."""

@@ -75,7 +76,7 @@ class DatasourcePluginsApi(DatasetApiResource):
return datasource_plugins, 200


@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource/nodes/<string:node_id>/run")
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource/nodes/{string:node_id}/run")
class DatasourceNodeRunApi(DatasetApiResource):
"""Resource for datasource node run."""

@@ -130,7 +131,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
)


@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/run")
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/run")
class PipelineRunApi(DatasetApiResource):
"""Resource for datasource node run."""

@@ -231,4 +232,12 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()

return serialize_upload_file(upload_file), 201
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at,
}, 201

@@ -1,22 +0,0 @@
"""
Serialization helpers for Service API knowledge pipeline endpoints.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from models.model import UploadFile


def serialize_upload_file(upload_file: UploadFile) -> dict[str, Any]:
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at.isoformat() if upload_file.created_at else None,
}

@@ -233,7 +233,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
SegmentService.delete_segment(segment, document, dataset)
return "", 204
return 204

@service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__])
@service_api_ns.doc("update_segment")

@@ -499,7 +499,7 @@ class DatasetChildChunkApi(DatasetApiResource):
except ChildChunkDeleteIndexServiceError as e:
raise ChildChunkDeleteIndexError(str(e))

return "", 204
return 204

@service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__])
@service_api_ns.doc("update_child_chunk")

@@ -1,3 +0,0 @@
from . import end_user

__all__ = ["end_user"]

@@ -1,41 +0,0 @@
from uuid import UUID

from flask_restx import Resource

from controllers.service_api import service_api_ns
from controllers.service_api.end_user.error import EndUserNotFoundError
from controllers.service_api.wraps import validate_app_token
from fields.end_user_fields import EndUserDetail
from models.model import App
from services.end_user_service import EndUserService


@service_api_ns.route("/end-users/<uuid:end_user_id>")
class EndUserApi(Resource):
"""Resource for retrieving end user details by ID."""

@service_api_ns.doc("get_end_user")
@service_api_ns.doc(description="Get an end user by ID")
@service_api_ns.doc(
params={"end_user_id": "End user ID"},
responses={
200: "End user retrieved successfully",
401: "Unauthorized - invalid API token",
404: "End user not found",
},
)
@validate_app_token
def get(self, app_model: App, end_user_id: UUID):
"""Get end user detail.

This endpoint is scoped to the current app token's tenant/app to prevent
cross-tenant/app access when an end-user ID is known.
"""

end_user = EndUserService.get_end_user_by_id(
tenant_id=app_model.tenant_id, app_id=app_model.id, end_user_id=str(end_user_id)
)
if end_user is None:
raise EndUserNotFoundError()

return EndUserDetail.model_validate(end_user).model_dump(mode="json")

@@ -1,7 +0,0 @@
from libs.exception import BaseHTTPException


class EndUserNotFoundError(BaseHTTPException):
error_code = "end_user_not_found"
description = "End user not found."
code = 404

@@ -1,24 +1,27 @@
import logging
import time
from collections.abc import Callable
from datetime import timedelta
from enum import StrEnum, auto
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar, cast
from typing import Concatenate, ParamSpec, TypeVar

from flask import current_app, request
from flask_login import user_logged_in
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized

from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account, Tenant, TenantAccountJoin, TenantStatus
from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App
from services.api_token_service import ApiTokenCache, fetch_token_with_single_flight, record_token_usage
from services.end_user_service import EndUserService
from services.feature_service import FeatureService

@@ -217,8 +220,6 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
def decorator(view: Callable[Concatenate[T, P], R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token("dataset")

# get url path dataset_id from positional args or kwargs
# Flask passes URL path parameters as positional arguments
dataset_id = None

@@ -255,18 +256,12 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
# Validate dataset if dataset_id is provided
if dataset_id:
dataset_id = str(dataset_id)
dataset = (
db.session.query(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == api_token.tenant_id,
)
.first()
)
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
if not dataset.enable_api:
raise Forbidden("Dataset api access is not enabled.")
api_token = validate_and_get_api_token("dataset")
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
.where(Tenant.id == api_token.tenant_id)

@@ -301,14 +296,7 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):

def validate_and_get_api_token(scope: str | None = None):
"""
Validate and get API token with Redis caching.

This function uses a two-tier approach:
1. First checks Redis cache for the token
2. If not cached, queries database and caches the result

The last_used_at field is updated asynchronously via Celery task
to avoid blocking the request.
Validate and get API token.
"""
auth_header = request.headers.get("Authorization")
if auth_header is None or " " not in auth_header:

@@ -320,18 +308,29 @@ def validate_and_get_api_token(scope: str | None = None):
if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'")

# Try to get token from cache first
# Returns a CachedApiToken (plain Python object), not a SQLAlchemy model
cached_token = ApiTokenCache.get(auth_token, scope)
if cached_token is not None:
logger.debug("Token validation served from cache for scope: %s", scope)
# Record usage in Redis for later batch update (no Celery task per request)
record_token_usage(auth_token, scope)
return cast(ApiToken, cached_token)
current_time = naive_utc_now()
cutoff_time = current_time - timedelta(minutes=1)
with Session(db.engine, expire_on_commit=False) as session:
update_stmt = (
update(ApiToken)
.where(
ApiToken.token == auth_token,
(ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)),
ApiToken.type == scope,
)
.values(last_used_at=current_time)
)
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
result = session.execute(update_stmt)
api_token = session.scalar(stmt)

# Cache miss - use Redis lock for single-flight mode
# This ensures only one request queries DB for the same token concurrently
return fetch_token_with_single_flight(auth_token, scope)
if hasattr(result, "rowcount") and result.rowcount > 0:
session.commit()

if not api_token:
raise Unauthorized("Access token is invalid")

return api_token


class DatasetApiResource(Resource):

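One side of the `validate_and_get_api_token` hunk describes a two-tier scheme: check a Redis cache first, fall back to the database behind a single-flight lock on a miss, and record cache hits so `last_used_at` can be updated later in a batch. The sketch below shows the general shape of such a scheme with a plain `redis-py` client; `load_token_from_db` and the key names are placeholders, not Dify APIs.

```python
import json

import redis

r = redis.Redis()

CACHE_TTL = 600  # seconds a validated token stays cached
LOCK_TTL = 5     # single-flight lock lifetime


def load_token_from_db(token: str, scope: str | None) -> dict | None:
    """Placeholder for the real database lookup."""
    raise NotImplementedError


def validate_token(token: str, scope: str | None) -> dict | None:
    cache_key = f"api_token:{scope}:{token}"

    cached = r.get(cache_key)
    if cached is not None:
        # Served from cache; remember the hit so last_used_at can be
        # refreshed later in one batch instead of once per request.
        r.sadd("api_token:recently_used", cache_key)
        return json.loads(cached)

    # Cache miss: let only one worker query the database for this token.
    with r.lock(f"{cache_key}:lock", timeout=LOCK_TTL, blocking_timeout=LOCK_TTL):
        cached = r.get(cache_key)  # another worker may have filled it meanwhile
        if cached is not None:
            return json.loads(cached)
        record = load_token_from_db(token, scope)
        if record is not None:
            r.setex(cache_key, CACHE_TTL, json.dumps(record))
        return record
```
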
@@ -65,12 +65,15 @@ def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Re
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")


# TODO(QuantumGhost): disable authorization for web app
# form api temporarily


@web_ns.route("/form/human_input/<string:form_token>")
# class HumanInputFormApi(WebApiResource):
class HumanInputFormApi(Resource):
"""API for getting and submitting human input forms via the web app."""

# NOTE(QuantumGhost): this endpoint is unauthenticated on purpose for now.

# def get(self, _app_model: App, _end_user: EndUser, form_token: str):
def get(self, form_token: str):
"""

@@ -25,6 +25,7 @@ from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, Workfl
from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import Variable
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer

@@ -33,7 +34,6 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.variables.variables import Variable
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from extensions.ext_redis import redis_client

@@ -50,6 +50,7 @@ from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.trigger.trigger_manager import TriggerManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import (

@@ -61,7 +62,6 @@ from core.workflow.enums import (
)
from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from core.workflow.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.workflow_entry import WorkflowEntry
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db

@@ -346,7 +346,7 @@ class WorkflowResponseConverter:
paused_nodes=list(event.paused_nodes),
outputs=encoded_outputs,
reasons=pause_reasons,
status=WorkflowExecutionStatus.PAUSED,
status=WorkflowExecutionStatus.PAUSED.value,
created_at=int(started_at.timestamp()),
elapsed_time=elapsed_time,
total_tokens=graph_runtime_state.total_tokens,

@@ -422,7 +422,7 @@ class WorkflowResponseConverter:
data=WorkflowFinishStreamResponse.Data(
id=run_id,
workflow_id=workflow_run.workflow_id,
status=workflow_run.status,
status=workflow_run.status.value,
outputs=encoded_outputs,
error=workflow_run.error,
elapsed_time=elapsed_time,

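These converter hunks switch from passing the `WorkflowExecutionStatus` enum itself to passing `.value`, which lines up with the stream-response field being typed `str` elsewhere in this compare. A toy Pydantic illustration (not the project's models) of why the explicit `.value` is the unambiguous spelling once the target field is a plain `str`:

```python
from enum import StrEnum

from pydantic import BaseModel


class Status(StrEnum):
    PAUSED = "paused"


class StreamEvent(BaseModel):
    status: str


# Passing the member relies on StrEnum being a str subclass; passing .value
# hands the model a plain string, so serialization does not depend on the
# enum type at all.
event = StreamEvent(status=Status.PAUSED.value)
print(event.model_dump())  # {'status': 'paused'}
```
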
@@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.enums import WorkflowType
from core.workflow.graph import Graph

@@ -20,7 +21,6 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.dataset import Document, Pipeline

@@ -34,7 +34,7 @@ def stream_topic_events(
on_subscribe()
while True:
try:
msg = sub.receive(timeout=1)
msg = sub.receive(timeout=0.1)
except SubscriptionClosedError:
return
if msg is None:

@@ -262,7 +262,7 @@ class WorkflowPauseStreamResponse(StreamResponse):
paused_nodes: Sequence[str] = Field(default_factory=list)
outputs: Mapping[str, Any] = Field(default_factory=dict)
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
status: WorkflowExecutionStatus
status: str
created_at: int
elapsed_time: float
total_tokens: int

@@ -1,12 +1,12 @@
import logging

from core.variables import VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.variables import VariableBase

logger = logging.getLogger(__name__)

@ -45,8 +45,6 @@ from core.app.entities.task_entities import (
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.file import helpers as file_helpers
|
||||
from core.file.enums import FileTransferMethod
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
@ -58,11 +56,10 @@ from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.tools.signature import sign_tool_file
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile
|
||||
from models.model import AppMode, Conversation, Message, MessageAgentThought
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -466,85 +463,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
metadata=metadata_dict,
)

def _record_files(self):
with Session(db.engine, expire_on_commit=False) as session:
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
if not message_files:
return None

files_list = []
upload_file_ids = [
mf.upload_file_id
for mf in message_files
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
]
upload_files_map = {}
if upload_file_ids:
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
upload_files_map = {uf.id: uf for uf in upload_files}

for message_file in message_files:
upload_file = None
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
upload_file = upload_files_map.get(message_file.upload_file_id)

url = None
filename = "file"
mime_type = "application/octet-stream"
size = 0
extension = ""

if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
url = message_file.url
if message_file.url:
filename = message_file.url.split("/")[-1].split("?")[0] # Remove query params
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
if upload_file:
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
filename = upload_file.name
mime_type = upload_file.mime_type or "application/octet-stream"
size = upload_file.size or 0
extension = f".{upload_file.extension}" if upload_file.extension else ""
elif message_file.upload_file_id:
# Fallback: generate URL even if upload_file not found
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
# For tool files, use URL directly if it's HTTP, otherwise sign it
if message_file.url.startswith("http"):
url = message_file.url
filename = message_file.url.split("/")[-1].split("?")[0]
else:
# Extract tool file id and extension from URL
url_parts = message_file.url.split("/")
if url_parts:
file_part = url_parts[-1].split("?")[0] # Remove query params first
# Use rsplit to correctly handle filenames with multiple dots
if "." in file_part:
tool_file_id, ext = file_part.rsplit(".", 1)
extension = f".{ext}"
else:
tool_file_id = file_part
extension = ".bin"
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
filename = file_part

transfer_method_value = message_file.transfer_method
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
file_dict = {
"related_id": message_file.id,
"extension": extension,
"filename": filename,
"size": size,
"mime_type": mime_type,
"transfer_method": transfer_method_value,
"type": message_file.type,
"url": url or "",
"upload_file_id": message_file.upload_file_id or message_file.id,
"remote_url": remote_url,
}
files_list.append(file_dict)
return files_list or None

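Illustrative sketch (not part of the commits above): the tool-file branch of _record_files derives a file id and extension from a URL path as shown below. The helper name parse_tool_file_url is hypothetical; it only isolates the string handling from the diff so it can run standalone.

def parse_tool_file_url(url: str) -> tuple[str, str]:
    # Keep only the last path segment and strip any query parameters first.
    file_part = url.split("/")[-1].split("?")[0]
    # rsplit keeps dots inside the name, so "abc.v2.report.pdf" -> ("abc.v2.report", ".pdf").
    if "." in file_part:
        tool_file_id, ext = file_part.rsplit(".", 1)
        return tool_file_id, f".{ext}"
    # No extension in the path: fall back to a generic binary suffix, as the pipeline does.
    return file_part, ".bin"


assert parse_tool_file_url("/files/tools/abc.v2.report.pdf?timestamp=1") == ("abc.v2.report", ".pdf")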
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
"""
Agent message to stream response.

@ -64,13 +64,7 @@ class MessageCycleManager:

# Use SQLAlchemy 2.x style session.scalar(select(...))
with session_factory.create_session() as session:
message_file = session.scalar(
select(MessageFile)
.where(
MessageFile.message_id == message_id,
)
.where(MessageFile.belongs_to == "assistant")
)
message_file = session.scalar(select(MessageFile).where(MessageFile.message_id == message_id))

if message_file:
self._message_has_file.add(message_id)

@ -8,7 +8,6 @@ from core.file.file_manager import file_manager
from core.helper.code_executor.code_executor import CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.ssrf_proxy import ssrf_proxy
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType
@ -17,7 +16,6 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from core.workflow.nodes.template_transform.template_renderer import (
@ -77,7 +75,6 @@ class DifyNodeFactory(NodeFactory):
self._http_request_http_client = http_request_http_client or ssrf_proxy
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
self._http_request_file_manager = http_request_file_manager or file_manager
self._rag_retrieval = DatasetRetrieval()

@override
def create_node(self, node_config: NodeConfigDict) -> Node:
@ -143,15 +140,6 @@ class DifyNodeFactory(NodeFactory):
file_manager=self._http_request_file_manager,
)

if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
return KnowledgeRetrievalNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
rag_retrieval=self._rag_retrieval,
)

return node_class(
id=node_id,
config=node_config,

@ -5,7 +5,7 @@ from base64 import b64encode
from collections.abc import Mapping
from typing import Any

from core.workflow.variables.utils import dumps_with_segments
from core.variables.utils import dumps_with_segments


class TemplateTransformer(ABC):

@ -1,15 +1,13 @@
import json
import logging
import math
import re
import threading
import time
from collections import Counter, defaultdict
from collections.abc import Generator, Mapping
from typing import Any, Union, cast

from flask import Flask, current_app
from sqlalchemy import and_, func, literal, or_, select
from sqlalchemy import and_, literal, or_, select
from sqlalchemy.orm import Session

from core.app.app_config.entities import (
@ -20,7 +18,6 @@ from core.app.app_config.entities import (
)
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.db.session_factory import session_factory
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.file import File, FileTransferMethod, FileType
@ -61,30 +58,12 @@ from core.rag.retrieval.template_prompts import (
)
from core.tools.signature import sign_upload_file
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.workflow.nodes.knowledge_retrieval import exc
from core.workflow.repositories.rag_retrieval_protocol import (
KnowledgeRetrievalRequest,
Source,
SourceChildChunk,
SourceMetadata,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.json_in_md_parser import parse_and_check_json_markdown
from models import UploadFile
from models.dataset import (
ChildChunk,
Dataset,
DatasetMetadata,
DatasetQuery,
DocumentSegment,
RateLimitLog,
SegmentAttachmentBinding,
)
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from models.dataset import Document as DocumentModel
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureService

default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
@ -94,8 +73,6 @@ default_retrieval_model: dict[str, Any] = {
"score_threshold_enabled": False,
}

logger = logging.getLogger(__name__)


class DatasetRetrieval:
def __init__(self, application_generate_entity=None):
@ -114,233 +91,6 @@ class DatasetRetrieval:
else:
self._llm_usage = self._llm_usage.plus(usage)

def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]:
self._check_knowledge_rate_limit(request.tenant_id)
available_datasets = self._get_available_datasets(request.tenant_id, request.dataset_ids)
available_datasets_ids = [i.id for i in available_datasets]
if not available_datasets_ids:
return []

if not request.query:
return []

metadata_filter_document_ids, metadata_condition = None, None

if request.metadata_filtering_mode != "disabled":
# Convert workflow layer types to app_config layer types
if not request.metadata_model_config:
raise ValueError("metadata_model_config is required for this method")

app_metadata_model_config = ModelConfig.model_validate(request.metadata_model_config.model_dump())

app_metadata_filtering_conditions = None
if request.metadata_filtering_conditions is not None:
app_metadata_filtering_conditions = MetadataFilteringCondition.model_validate(
request.metadata_filtering_conditions.model_dump()
)

query = request.query if request.query is not None else ""

metadata_filter_document_ids, metadata_condition = self.get_metadata_filter_condition(
dataset_ids=available_datasets_ids,
query=query,
tenant_id=request.tenant_id,
user_id=request.user_id,
metadata_filtering_mode=request.metadata_filtering_mode,
metadata_model_config=app_metadata_model_config,
metadata_filtering_conditions=app_metadata_filtering_conditions,
inputs={},
)

if request.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
planning_strategy = PlanningStrategy.REACT_ROUTER
# Ensure required fields are not None for single retrieval mode
if request.model_provider is None or request.model_name is None or request.query is None:
raise ValueError("model_provider, model_name, and query are required for single retrieval mode")

model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=request.tenant_id,
model_type=ModelType.LLM,
provider=request.model_provider,
model=request.model_name,
)

provider_model_bundle = model_instance.provider_model_bundle
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)

model_credentials = model_instance.credentials

# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=request.model_name, model_type=ModelType.LLM
)

if provider_model is None:
raise exc.ModelNotExistError(f"Model {request.model_name} not exist.")

if provider_model.status == ModelStatus.NO_CONFIGURE:
raise exc.ModelCredentialsNotInitializedError(
f"Model {request.model_name} credentials is not initialized."
)
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise exc.ModelNotSupportedError(f"Dify Hosted OpenAI {request.model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise exc.ModelQuotaExceededError(f"Model provider {request.model_provider} quota exceeded.")

stop = []
completion_params = (request.completion_params or {}).copy()
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]

model_schema = model_type_instance.get_model_schema(request.model_name, model_credentials)

if not model_schema:
raise exc.ModelNotExistError(f"Model {request.model_name} not exist.")

model_config = ModelConfigWithCredentialsEntity(
provider=request.model_provider,
model=request.model_name,
model_schema=model_schema,
mode=request.model_mode or "chat",
provider_model_bundle=provider_model_bundle,
credentials=model_credentials,
parameters=completion_params,
stop=stop,
)
all_documents = self.single_retrieve(
request.app_id,
request.tenant_id,
request.user_id,
request.user_from,
request.query,
available_datasets,
model_instance,
model_config,
planning_strategy,
None, # message_id
metadata_filter_document_ids,
metadata_condition,
)
else:
all_documents = self.multiple_retrieve(
app_id=request.app_id,
tenant_id=request.tenant_id,
user_id=request.user_id,
user_from=request.user_from,
available_datasets=available_datasets,
query=request.query,
top_k=request.top_k,
score_threshold=request.score_threshold,
reranking_mode=request.reranking_mode,
reranking_model=request.reranking_model,
weights=request.weights,
reranking_enable=request.reranking_enable,
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
attachment_ids=request.attachment_ids,
)

dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
retrieval_resource_list = []
# deal with external documents
for item in external_documents:
source = Source(
metadata=SourceMetadata(
source="knowledge",
dataset_id=item.metadata.get("dataset_id"),
dataset_name=item.metadata.get("dataset_name"),
document_id=item.metadata.get("document_id"),
document_name=item.metadata.get("title"),
data_source_type="external",
retriever_from="workflow",
score=item.metadata.get("score"),
doc_metadata=item.metadata,
),
title=item.metadata.get("title"),
content=item.page_content,
)
retrieval_resource_list.append(source)
# deal with dify documents
if dify_documents:
records = RetrievalService.format_retrieval_documents(dify_documents)
dataset_ids = [i.segment.dataset_id for i in records]
document_ids = [i.segment.document_id for i in records]

with session_factory.create_session() as session:
datasets = session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
documents = session.query(DatasetDocument).where(DatasetDocument.id.in_(document_ids)).all()

dataset_map = {i.id: i for i in datasets}
document_map = {i.id: i for i in documents}

if records:
for record in records:
segment = record.segment
dataset = dataset_map.get(segment.dataset_id)
document = document_map.get(segment.document_id)

if dataset and document:
source = Source(
metadata=SourceMetadata(
source="knowledge",
dataset_id=dataset.id,
dataset_name=dataset.name,
document_id=document.id,
document_name=document.name,
data_source_type=document.data_source_type,
segment_id=segment.id,
retriever_from="workflow",
score=record.score or 0.0,
segment_hit_count=segment.hit_count,
segment_word_count=segment.word_count,
segment_position=segment.position,
segment_index_node_hash=segment.index_node_hash,
doc_metadata=document.doc_metadata,
child_chunks=[
SourceChildChunk(
id=str(getattr(chunk, "id", "")),
content=str(getattr(chunk, "content", "")),
position=int(getattr(chunk, "position", 0)),
score=float(getattr(chunk, "score", 0.0)),
)
for chunk in (record.child_chunks or [])
],
position=None,
),
title=document.name,
files=list(record.files) if record.files else None,
content=segment.get_sign_content(),
)
if segment.answer:
source.content = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"

if record.summary:
source.summary = record.summary

retrieval_resource_list.append(source)

if retrieval_resource_list:

def _score(item: Source) -> float:
meta = item.metadata
score = meta.score
if isinstance(score, (int, float)):
return float(score)
return 0.0

retrieval_resource_list = sorted(
retrieval_resource_list,
key=_score, # type: ignore[arg-type, return-value]
reverse=True,
)
for position, item in enumerate(retrieval_resource_list, start=1):
item.metadata.position = position # type: ignore[index]
return retrieval_resource_list

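Illustrative sketch (not part of the commits above): the ordering step at the end of knowledge_retrieval, shown with plain dicts instead of the Source/SourceMetadata models so it runs standalone. Missing or non-numeric scores sort as 0.0, and positions are assigned 1-based after sorting.

sources = [
    {"title": "a", "metadata": {"score": 0.42, "position": None}},
    {"title": "b", "metadata": {"score": None, "position": None}},
    {"title": "c", "metadata": {"score": 0.91, "position": None}},
]


def _score(item: dict) -> float:
    # Mirror the _score helper in the diff: anything non-numeric counts as 0.0.
    score = item["metadata"].get("score")
    return float(score) if isinstance(score, (int, float)) else 0.0


sources = sorted(sources, key=_score, reverse=True)
# Positions are 1-based and assigned only after sorting, highest score first.
for position, item in enumerate(sources, start=1):
    item["metadata"]["position"] = position

assert [s["title"] for s in sources] == ["c", "a", "b"]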
def retrieve(
self,
app_id: str,
@ -400,7 +150,14 @@ class DatasetRetrieval:
if features:
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.ROUTER
available_datasets = self._get_available_datasets(tenant_id, dataset_ids)
available_datasets = []

dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
datasets: list[Dataset] = db.session.execute(dataset_stmt).scalars().all() # type: ignore
for dataset in datasets:
if dataset.available_document_count == 0 and dataset.provider != "external":
continue
available_datasets.append(dataset)

if inputs:
inputs = {key: str(value) for key, value in inputs.items()}
@ -1404,6 +1161,7 @@ class DatasetRetrieval:
query=query or "",
)

result_text = ""
try:
# handle invoke result
invoke_result = cast(
@ -1434,8 +1192,7 @@ class DatasetRetrieval:
"condition": item.get("comparison_operator"),
}
)
except Exception as e:
logger.warning(e, exc_info=True)
except Exception:
return None
return automatic_metadata_filters

@ -1649,12 +1406,7 @@ class DatasetRetrieval:
usage = None
for result in invoke_result:
text = result.delta.message.content
if isinstance(text, str):
full_text += text
elif isinstance(text, list):
for i in text:
if i.data:
full_text += i.data
full_text += text

if not model:
model = result.model
@ -1772,53 +1524,3 @@ class DatasetRetrieval:
cancel_event.set()
if thread_exceptions is not None:
thread_exceptions.append(e)

def _get_available_datasets(self, tenant_id: str, dataset_ids: list[str]) -> list[Dataset]:
with session_factory.create_session() as session:
subquery = (
session.query(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count"))
.where(
DocumentModel.indexing_status == "completed",
DocumentModel.enabled == True,
DocumentModel.archived == False,
DocumentModel.dataset_id.in_(dataset_ids),
)
.group_by(DocumentModel.dataset_id)
.having(func.count(DocumentModel.id) > 0)
.subquery()
)

results = (
session.query(Dataset)
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
.where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
.where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
.all()
)

available_datasets = []
for dataset in results:
if not dataset:
continue
available_datasets.append(dataset)
return available_datasets

def _check_knowledge_rate_limit(self, tenant_id: str):
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{tenant_id}"
redis_client.zadd(key, {current_time: current_time})
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
with session_factory.create_session() as session:
rate_limit_log = RateLimitLog(
tenant_id=tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
session.add(rate_limit_log)
raise exc.RateLimitExceededError(
"you have reached the knowledge base request rate limit of your subscription."
)

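Illustrative sketch (not part of the commits above): the sliding-window check inside _check_knowledge_rate_limit, isolated with a plain redis client. The function name is_over_limit is hypothetical; the real method raises RateLimitExceededError and writes a RateLimitLog row instead of returning a bool.

import time

import redis

r = redis.Redis()


def is_over_limit(tenant_id: str, limit: int, window_ms: int = 60_000) -> bool:
    now_ms = int(time.time() * 1000)
    key = f"rate_limit_{tenant_id}"
    # Record this request, drop entries older than the window, then count what is left.
    r.zadd(key, {now_ms: now_ms})
    r.zremrangebyscore(key, 0, now_ms - window_ms)
    return r.zcard(key) > limit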
@ -5,7 +5,7 @@ from collections.abc import Generator
from copy import deepcopy
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
from models.model import File

from core.tools.__base.tool_runtime import ToolRuntime
@ -171,7 +171,7 @@ class Tool(ABC):
def create_file_message(self, file: File) -> ToolInvokeMessage:
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.FILE,
message=ToolInvokeMessage.FileMessage(file_marker="file_marker"),
message=ToolInvokeMessage.FileMessage(),
meta={"file": file},
)


@ -3,8 +3,8 @@ from __future__ import annotations
import base64
import json
import logging
from collections.abc import Generator, Mapping
from typing import Any, cast
from collections.abc import Generator
from typing import Any

from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
@ -17,7 +17,6 @@ from core.mcp.types import (
TextContent,
TextResourceContents,
)
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
@ -47,7 +46,6 @@ class MCPTool(Tool):
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self._latest_usage = LLMUsage.empty_usage()

def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.MCP
@ -61,10 +59,6 @@ class MCPTool(Tool):
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
result = self.invoke_remote_mcp_tool(tool_parameters)

# Extract usage metadata from MCP protocol's _meta field
self._latest_usage = self._derive_usage_from_result(result)

# handle dify tool output
for content in result.content:
if isinstance(content, TextContent):
@ -126,99 +120,6 @@ class MCPTool(Tool):
for item in json_list:
yield self.create_json_message(item)

@property
def latest_usage(self) -> LLMUsage:
return self._latest_usage

@classmethod
def _derive_usage_from_result(cls, result: CallToolResult) -> LLMUsage:
"""
Extract usage metadata from MCP tool result's _meta field.

The MCP protocol's _meta field (aliased as 'meta' in Python) can contain
usage information such as token counts, costs, and other metadata.

Args:
result: The CallToolResult from MCP tool invocation

Returns:
LLMUsage instance with values from meta or empty_usage if not found
"""
# Extract usage from the meta field if present
if result.meta:
usage_dict = cls._extract_usage_dict(result.meta)
if usage_dict is not None:
return LLMUsage.from_metadata(cast(LLMUsageMetadata, cast(object, dict(usage_dict))))

return LLMUsage.empty_usage()

@classmethod
def _extract_usage_dict(cls, payload: Mapping[str, Any]) -> Mapping[str, Any] | None:
"""
Recursively search for usage dictionary in the payload.

The MCP protocol's _meta field can contain usage data in various formats:
- Direct usage field: {"usage": {...}}
- Nested in metadata: {"metadata": {"usage": {...}}}
- Or nested within other fields

Args:
payload: The payload to search for usage data

Returns:
The usage dictionary if found, None otherwise
"""
# Check for direct usage field
usage_candidate = payload.get("usage")
if isinstance(usage_candidate, Mapping):
return usage_candidate

# Check for metadata nested usage
metadata_candidate = payload.get("metadata")
if isinstance(metadata_candidate, Mapping):
usage_candidate = metadata_candidate.get("usage")
if isinstance(usage_candidate, Mapping):
return usage_candidate

# Check for common token counting fields directly in payload
# Some MCP servers may include token counts directly
if "total_tokens" in payload or "prompt_tokens" in payload or "completion_tokens" in payload:
usage_dict: dict[str, Any] = {}
for key in (
"prompt_tokens",
"completion_tokens",
"total_tokens",
"prompt_unit_price",
"completion_unit_price",
"total_price",
"currency",
"prompt_price_unit",
"completion_price_unit",
"prompt_price",
"completion_price",
"latency",
"time_to_first_token",
"time_to_generate",
):
if key in payload:
usage_dict[key] = payload[key]
if usage_dict:
return usage_dict

# Recursively search through nested structures
for value in payload.values():
if isinstance(value, Mapping):
found = cls._extract_usage_dict(value)
if found is not None:
return found
elif isinstance(value, list) and not isinstance(value, (str, bytes, bytearray)):
for item in value:
if isinstance(item, Mapping):
found = cls._extract_usage_dict(item)
if found is not None:
return found
return None

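Illustrative sketch (not part of the commits above): a simplified, standalone version of the _meta lookup that _extract_usage_dict performs. It keeps the lookup order from the diff (top-level "usage", then "metadata" -> "usage", then bare token fields, then a recursive walk) but omits the price/latency keys and the list handling for brevity.

from collections.abc import Mapping
from typing import Any


def extract_usage_dict(payload: Mapping[str, Any]) -> Mapping[str, Any] | None:
    # 1. Direct "usage" field.
    usage = payload.get("usage")
    if isinstance(usage, Mapping):
        return usage
    # 2. "metadata" -> "usage".
    metadata = payload.get("metadata")
    if isinstance(metadata, Mapping) and isinstance(metadata.get("usage"), Mapping):
        return metadata["usage"]
    # 3. Bare token fields at this level.
    token_keys = ("prompt_tokens", "completion_tokens", "total_tokens")
    if any(k in payload for k in token_keys):
        return {k: payload[k] for k in token_keys if k in payload}
    # 4. Recurse into nested mappings (list handling omitted in this sketch).
    for value in payload.values():
        if isinstance(value, Mapping):
            found = extract_usage_dict(value)
            if found is not None:
                return found
    return None


assert extract_usage_dict({"metadata": {"usage": {"total_tokens": 46}}}) == {"total_tokens": 46}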
def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool:
return MCPTool(
entity=self.entity,

@ -1,7 +1,7 @@
import abc
from typing import Protocol

from core.workflow.variables import VariableBase
from core.variables import VariableBase


class ConversationVariableUpdater(Protocol):

@ -10,7 +10,6 @@ from pydantic import BaseModel, Field

from core.workflow.entities.pause_reason import PauseReason
from core.workflow.enums import NodeState
from core.workflow.runtime.graph_runtime_state import GraphExecutionProtocol

from .node_execution import NodeExecution

@ -237,6 +236,3 @@ class GraphExecution:
def record_node_failure(self) -> None:
"""Increment the count of node failures encountered during execution."""
self.exceptions_count += 1


_: GraphExecutionProtocol = GraphExecution(workflow_id="")

@ -11,7 +11,7 @@ from typing import Any

from pydantic import BaseModel, Field

from core.workflow.variables.variables import Variable
from core.variables.variables import Variable


class CommandType(StrEnum):

@ -2,7 +2,7 @@ from __future__ import annotations

import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, Union, cast

from packaging.version import Version
from pydantic import ValidationError
@ -11,6 +11,7 @@ from sqlalchemy.orm import Session

from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.db.session_factory import session_factory
from core.file import File, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
@ -26,6 +27,7 @@ from core.tools.entities.tool_entities import (
)
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.enums import (
NodeType,
SystemVariableKey,
@ -43,12 +45,17 @@ from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionMod
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from core.workflow.variables.segments import ArrayFileSegment, StringSegment
from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy
from models import ToolFile
from models.model import Conversation
from models.tools import (
ApiToolProvider,
BuiltinToolProvider,
MCPToolProvider,
WorkflowToolProvider,
)
from services.tools.builtin_tools_manage_service import BuiltinToolManageService

from .exc import (
@ -259,7 +266,7 @@ class AgentNode(Node[AgentNodeData]):
value = cast(list[dict[str, Any]], value)
tool_value = []
for tool in value:
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
provider_type = self._infer_tool_provider_type(tool, self.tenant_id)
setting_params = tool.get("settings", {})
parameters = tool.get("parameters", {})
manual_input_params = [key for key, value in parameters.items() if value is not None]
@ -748,3 +755,34 @@ class AgentNode(Node[AgentNodeData]):
llm_usage=llm_usage,
)
)

@staticmethod
def _infer_tool_provider_type(tool_config: dict[str, Any], tenant_id: str) -> ToolProviderType:
provider_type_str = tool_config.get("type")
if provider_type_str:
return ToolProviderType(provider_type_str)

provider_id = tool_config.get("provider_name")
if not provider_id:
return ToolProviderType.BUILT_IN

with session_factory.create_session() as session:
provider_map: dict[
type[Union[WorkflowToolProvider, MCPToolProvider, ApiToolProvider, BuiltinToolProvider]],
ToolProviderType,
] = {
WorkflowToolProvider: ToolProviderType.WORKFLOW,
MCPToolProvider: ToolProviderType.MCP,
ApiToolProvider: ToolProviderType.API,
BuiltinToolProvider: ToolProviderType.BUILT_IN,
}

for provider_model, provider_type in provider_map.items():
stmt = select(provider_model).where(
provider_model.id == provider_id,
provider_model.tenant_id == tenant_id,
)
if session.scalar(stmt):
return provider_type

raise AgentNodeError(f"Tool provider with ID '{provider_id}' not found.")

@ -1,13 +1,13 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.variables import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.answer.entities import AnswerNodeData
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.template import Template
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.variables import ArrayFileSegment, FileSegment, Segment
|
||||
|
||||
|
||||
class AnswerNode(Node[AnswerNodeData]):
|
||||
|
||||
@ -6,13 +6,13 @@ from core.helper.code_executor.code_executor import CodeExecutionError, CodeExec
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.entities import CodeNodeData
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from core.workflow.variables.segments import ArrayFileSegment
|
||||
from core.workflow.variables.types import SegmentType
|
||||
|
||||
from .exc import (
|
||||
CodeNodeError,
|
||||
|
||||
@ -3,9 +3,9 @@ from typing import Annotated, Literal
|
||||
from pydantic import AfterValidator, BaseModel
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
from core.workflow.variables.types import SegmentType
|
||||
|
||||
_ALLOWED_OUTPUT_FROM_CODE = frozenset(
|
||||
[
|
||||
|
||||
@ -17,6 +17,8 @@ from core.datasource.utils.message_transformer import DatasourceFileMessageTrans
|
||||
from core.file import File
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey
|
||||
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
@ -24,8 +26,6 @@ from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.nodes.tool.exc import ToolFileError
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.variables.segments import ArrayAnySegment
|
||||
from core.workflow.variables.variables import ArrayAnyVariable
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.model import UploadFile
|
||||
|
||||
@ -23,11 +23,11 @@ from docx.text.paragraph import Paragraph
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod, file_manager
|
||||
from core.helper import ssrf_proxy
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.variables.segments import ArrayStringSegment, FileSegment
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.variables import ArrayFileSegment
|
||||
from core.workflow.variables.segments import ArrayStringSegment, FileSegment
|
||||
|
||||
from .entities import DocumentExtractorNodeData
|
||||
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
|
||||
|
||||
@ -14,8 +14,8 @@ from configs import dify_config
|
||||
from core.file.enums import FileTransferMethod
|
||||
from core.file.file_manager import file_manager as default_file_manager
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.variables.segments import ArrayFileSegment, FileSegment
|
||||
|
||||
from ..protocols import FileManagerProtocol, HttpClientProtocol
|
||||
from .entities import (
|
||||
|
||||
@ -8,6 +8,7 @@ from core.file import File, FileTransferMethod
|
||||
from core.file.file_manager import file_manager as default_file_manager
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base import variable_template_parser
|
||||
@ -15,7 +16,6 @@ from core.workflow.nodes.base.entities import VariableSelector
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.http_request.executor import Executor
|
||||
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
|
||||
from core.workflow.variables.segments import ArrayFileSegment
|
||||
from factories import file_factory
|
||||
|
||||
from .entities import (
|
||||
|
||||
@ -10,10 +10,10 @@ from typing import Annotated, Any, ClassVar, Literal, Self
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.variables.consts import SELECTORS_LENGTH
|
||||
|
||||
from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit
|
||||
|
||||
|
||||
@ -7,6 +7,9 @@ from typing import TYPE_CHECKING, Any, NewType, cast
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.variables import IntegerVariable, NoneSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.variables.variables import Variable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import (
|
||||
NodeExecutionType,
|
||||
@ -33,9 +36,6 @@ from core.workflow.nodes.base import LLMUsageTrackingMixin
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.variables import IntegerVariable, NoneSegment
|
||||
from core.workflow.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.workflow.variables.variables import Variable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .exc import (
|
||||
|
||||
@ -20,7 +20,3 @@ class ModelQuotaExceededError(KnowledgeRetrievalNodeError):
|
||||
|
||||
class InvalidModelTypeError(KnowledgeRetrievalNodeError):
|
||||
"""Raised when the model is not a Large Language Model."""
|
||||
|
||||
|
||||
class RateLimitExceededError(KnowledgeRetrievalNodeError):
|
||||
"""Raised when the rate limit is exceeded."""
|
||||
|
||||
@ -1,32 +1,70 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from sqlalchemy import and_, func, or_, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.variables import (
|
||||
ArrayFileSegment,
|
||||
FileSegment,
|
||||
StringSegment,
|
||||
)
|
||||
from core.variables.segments import ArrayObjectSegment
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import (
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
|
||||
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
|
||||
from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source
|
||||
from core.workflow.variables import (
|
||||
ArrayFileSegment,
|
||||
FileSegment,
|
||||
StringSegment,
|
||||
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
|
||||
METADATA_FILTER_ASSISTANT_PROMPT_1,
|
||||
METADATA_FILTER_ASSISTANT_PROMPT_2,
|
||||
METADATA_FILTER_COMPLETION_PROMPT,
|
||||
METADATA_FILTER_SYSTEM_PROMPT,
|
||||
METADATA_FILTER_USER_PROMPT_1,
|
||||
METADATA_FILTER_USER_PROMPT_2,
|
||||
METADATA_FILTER_USER_PROMPT_3,
|
||||
)
|
||||
from core.workflow.variables.segments import ArrayObjectSegment
|
||||
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, ModelConfig
|
||||
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from .entities import KnowledgeRetrievalNodeData
|
||||
from .exc import (
|
||||
InvalidModelTypeError,
|
||||
KnowledgeRetrievalNodeError,
|
||||
RateLimitExceededError,
|
||||
ModelCredentialsNotInitializedError,
|
||||
ModelNotExistError,
|
||||
ModelNotSupportedError,
|
||||
ModelQuotaExceededError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -35,6 +73,14 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 4,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
||||
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
|
||||
node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
@ -51,7 +97,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
rag_retrieval: RAGRetrievalProtocol,
|
||||
*,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
@ -63,7 +108,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
)
|
||||
# LLM file outputs, used for MultiModal outputs.
|
||||
self._file_outputs = []
|
||||
self._rag_retrieval = rag_retrieval
|
||||
|
||||
if llm_file_saver is None:
|
||||
llm_file_saver = FileSaverImpl(
|
||||
@ -77,7 +121,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
usage = LLMUsage.empty_usage()
|
||||
if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
@ -85,7 +128,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
process_data={},
|
||||
outputs={},
|
||||
metadata={},
|
||||
llm_usage=usage,
|
||||
llm_usage=LLMUsage.empty_usage(),
|
||||
)
|
||||
variables: dict[str, Any] = {}
|
||||
# extract variables
|
||||
@ -113,9 +156,36 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
else:
|
||||
variables["attachments"] = [variable.value]
|
||||
|
||||
# TODO(-LAN-): Move this check outside.
|
||||
# check rate limit
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
|
||||
if knowledge_rate_limit.enabled:
|
||||
current_time = int(time.time() * 1000)
|
||||
key = f"rate_limit_{self.tenant_id}"
|
||||
redis_client.zadd(key, {current_time: current_time})
|
||||
redis_client.zremrangebyscore(key, 0, current_time - 60000)
|
||||
request_count = redis_client.zcard(key)
|
||||
if request_count > knowledge_rate_limit.limit:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
# add ratelimit record
|
||||
rate_limit_log = RateLimitLog(
|
||||
tenant_id=self.tenant_id,
|
||||
subscription_plan=knowledge_rate_limit.subscription_plan,
|
||||
operation="knowledge",
|
||||
)
|
||||
session.add(rate_limit_log)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
|
||||
error_type="RateLimitExceeded",
|
||||
)
|
||||
|
||||
# retrieve knowledge
|
||||
usage = LLMUsage.empty_usage()
|
||||
try:
|
||||
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables)
|
||||
outputs = {"result": ArrayObjectSegment(value=[item.model_dump() for item in results])}
|
||||
outputs = {"result": ArrayObjectSegment(value=results)}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
@ -128,17 +198,9 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
except RateLimitExceededError as e:
|
||||
logger.warning(e, exc_info=True)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
llm_usage=usage,
|
||||
)
|
||||
|
||||
except KnowledgeRetrievalNodeError as e:
|
||||
logger.warning("Error when running knowledge retrieval node", exc_info=True)
|
||||
logger.warning("Error when running knowledge retrieval node")
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
@ -148,7 +210,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
)
|
||||
# Temporary handle all exceptions from DatasetRetrieval class here.
|
||||
except Exception as e:
|
||||
logger.warning(e, exc_info=True)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
@ -156,47 +217,92 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
error_type=type(e).__name__,
|
||||
llm_usage=usage,
|
||||
)
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _fetch_dataset_retriever(
|
||||
self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
|
||||
) -> tuple[list[Source], LLMUsage]:
|
||||
) -> tuple[list[dict[str, Any]], LLMUsage]:
|
||||
usage = LLMUsage.empty_usage()
|
||||
available_datasets = []
|
||||
dataset_ids = node_data.dataset_ids
|
||||
query = variables.get("query")
|
||||
attachments = variables.get("attachments")
|
||||
retrieval_resource_list = []
|
||||
metadata_filter_document_ids = None
|
||||
metadata_condition = None
|
||||
metadata_usage = LLMUsage.empty_usage()
|
||||
# Subquery: Count the number of available documents for each dataset
|
||||
subquery = (
|
||||
db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
|
||||
.where(
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
Document.dataset_id.in_(dataset_ids),
|
||||
)
|
||||
.group_by(Document.dataset_id)
|
||||
.having(func.count(Document.id) > 0)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] = "disabled"
|
||||
if node_data.metadata_filtering_mode is not None:
|
||||
metadata_filtering_mode = node_data.metadata_filtering_mode
|
||||
results = (
|
||||
db.session.query(Dataset)
|
||||
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
|
||||
.where(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
|
||||
.where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
|
||||
.all()
|
||||
)
|
||||
|
||||
# avoid blocking at retrieval
|
||||
db.session.close()
|
||||
|
||||
for dataset in results:
|
||||
# pass if dataset is not available
|
||||
if not dataset:
|
||||
continue
|
||||
available_datasets.append(dataset)
|
||||
if query:
|
||||
metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
|
||||
[dataset.id for dataset in available_datasets], query, node_data
|
||||
)
|
||||
usage = self._merge_usage(usage, metadata_usage)
|
||||
all_documents = []
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
|
||||
# fetch model config
|
||||
if node_data.single_retrieval_config is None:
|
||||
raise ValueError("single_retrieval_config is required for single retrieval mode")
|
||||
model = node_data.single_retrieval_config.model
|
||||
retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
|
||||
request=KnowledgeRetrievalRequest(
|
||||
raise ValueError("single_retrieval_config is required")
|
||||
model_instance, model_config = self.get_model_config(node_data.single_retrieval_config.model)
|
||||
# check model is support tool calling
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
# get model schema
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
model=model_config.model, credentials=model_config.credentials
|
||||
)
|
||||
|
||||
if model_schema:
|
||||
planning_strategy = PlanningStrategy.REACT_ROUTER
|
||||
features = model_schema.features
|
||||
if features:
|
||||
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
|
||||
planning_strategy = PlanningStrategy.ROUTER
|
||||
all_documents = dataset_retrieval.single_retrieve(
|
||||
available_datasets=available_datasets,
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
app_id=self.app_id,
|
||||
user_from=self.user_from.value,
|
||||
dataset_ids=dataset_ids,
|
||||
retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value,
|
||||
completion_params=model.completion_params,
|
||||
model_provider=model.provider,
|
||||
model_mode=model.mode,
|
||||
model_name=model.name,
|
||||
metadata_model_config=node_data.metadata_model_config,
|
||||
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
|
||||
metadata_filtering_mode=metadata_filtering_mode,
|
||||
query=query,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
planning_strategy=planning_strategy,
|
||||
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||
metadata_condition=metadata_condition,
|
||||
)
|
||||
)
|
||||
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
if node_data.multiple_retrieval_config is None:
|
||||
raise ValueError("multiple_retrieval_config is required")
|
||||
reranking_model = None
|
||||
weights = None
|
||||
match node_data.multiple_retrieval_config.reranking_mode:
|
||||
case "reranking_model":
|
||||
if node_data.multiple_retrieval_config.reranking_model:
|
||||
@ -223,36 +329,284 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
},
|
||||
}
|
||||
case _:
|
||||
# Handle any other reranking_mode values
|
||||
reranking_model = None
|
||||
weights = None
|
||||
all_documents = dataset_retrieval.multiple_retrieve(
|
||||
app_id=self.app_id,
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from.value,
|
||||
available_datasets=available_datasets,
|
||||
query=query,
|
||||
top_k=node_data.multiple_retrieval_config.top_k,
|
||||
score_threshold=node_data.multiple_retrieval_config.score_threshold
|
||||
if node_data.multiple_retrieval_config.score_threshold is not None
else 0.0,
reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
reranking_model=reranking_model,
weights=weights,
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
)
usage = self._merge_usage(usage, dataset_retrieval.llm_usage)

retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
request=KnowledgeRetrievalRequest(
app_id=self.app_id,
tenant_id=self.tenant_id,
user_id=self.user_id,
user_from=self.user_from.value,
dataset_ids=dataset_ids,
query=query,
retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value,
top_k=node_data.multiple_retrieval_config.top_k,
score_threshold=node_data.multiple_retrieval_config.score_threshold
if node_data.multiple_retrieval_config.score_threshold is not None
else 0.0,
reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
reranking_model=reranking_model,
weights=weights,
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
metadata_model_config=node_data.metadata_model_config,
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
metadata_filtering_mode=metadata_filtering_mode,
attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
retrieval_resource_list = []
# deal with external documents
for item in external_documents:
source: dict[str, dict[str, str | Any | dict[Any, Any] | None] | Any | str | None] = {
"metadata": {
"_source": "knowledge",
"dataset_id": item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"),
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
"document_name": item.metadata.get("title"),
"data_source_type": "external",
"retriever_from": "workflow",
"score": item.metadata.get("score"),
"doc_metadata": item.metadata,
},
"title": item.metadata.get("title"),
"content": item.page_content,
}
retrieval_resource_list.append(source)
# deal with dify documents
if dify_documents:
records = RetrievalService.format_retrieval_documents(dify_documents)
if records:
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
stmt = select(Document).where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
document = db.session.scalar(stmt)
if dataset and document:
source = {
"metadata": {
"_source": "knowledge",
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": "workflow",
"score": record.score or 0.0,
"child_chunks": [
{
"id": str(getattr(chunk, "id", "")),
"content": str(getattr(chunk, "content", "")),
"position": int(getattr(chunk, "position", 0)),
"score": float(getattr(chunk, "score", 0.0)),
}
for chunk in (record.child_chunks or [])
],
"segment_hit_count": segment.hit_count,
"segment_word_count": segment.word_count,
"segment_position": segment.position,
"segment_index_node_hash": segment.index_node_hash,
"doc_metadata": document.doc_metadata,
},
"title": document.name,
"files": list(record.files) if record.files else None,
}
if segment.answer:
source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
else:
source["content"] = segment.get_sign_content()
# Add summary if available
if record.summary:
source["summary"] = record.summary
retrieval_resource_list.append(source)
if retrieval_resource_list:
retrieval_resource_list = sorted(
retrieval_resource_list,
key=self._score, # type: ignore[arg-type, return-value]
reverse=True,
)
for position, item in enumerate(retrieval_resource_list, start=1):
item["metadata"]["position"] = position # type: ignore[index]
return retrieval_resource_list, usage

def _score(self, item: dict[str, Any]) -> float:
meta = item.get("metadata")
if isinstance(meta, dict):
s = meta.get("score")
if isinstance(s, (int, float)):
return float(s)
return 0.0

def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
usage = LLMUsage.empty_usage()
document_query = db.session.query(Document).where(
Document.dataset_id.in_(dataset_ids),
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
)
filters: list[Any] = []
metadata_condition = None
match node_data.metadata_filtering_mode:
case "disabled":
return None, None, usage
case "automatic":
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
dataset_ids, query, node_data
)
usage = self._merge_usage(usage, automatic_usage)
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
DatasetRetrieval.process_metadata_filter_func(
sequence,
filter.get("condition", ""),
filter.get("metadata_name", ""),
filter.get("value"),
filters,
)
conditions.append(
Condition(
name=filter.get("metadata_name"), # type: ignore
comparison_operator=filter.get("condition"), # type: ignore
value=filter.get("value"),
)
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator
if node_data.metadata_filtering_conditions
else "or",
conditions=conditions,
)
case "manual":
if node_data.metadata_filtering_conditions:
conditions = []
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
metadata_name = condition.name
expected_value = condition.value
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).value[0]
if expected_value.value_type in {"number", "integer", "float"}:
expected_value = expected_value.value
elif expected_value.value_type == "string":
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
else:
raise ValueError("Invalid expected metadata value type")
conditions.append(
Condition(
name=metadata_name,
comparison_operator=condition.comparison_operator,
value=expected_value,
)
)
filters = DatasetRetrieval.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
expected_value,
filters,
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
conditions=conditions,
)
case _:
raise ValueError("Invalid metadata filtering mode")
if filters:
if (
node_data.metadata_filtering_conditions
and node_data.metadata_filtering_conditions.logical_operator == "and"
):
document_query = document_query.where(and_(*filters))
else:
document_query = document_query.where(or_(*filters))
documents = document_query.all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
for document in documents:
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
return metadata_filter_document_ids, metadata_condition, usage

def _automatic_metadata_filter_func(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[list[dict[str, Any]], LLMUsage]:
usage = LLMUsage.empty_usage()
# get all metadata field
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
metadata_fields = db.session.scalars(stmt).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
if node_data.metadata_model_config is None:
raise ValueError("metadata_model_config is required")
# get metadata model instance and fetch model config
model_instance, model_config = self.get_model_config(node_data.metadata_model_config)
# fetch prompt messages
prompt_template = self._get_prompt_template(
node_data=node_data,
metadata_fields=all_metadata_fields,
query=query or "",
)
prompt_messages, stop = LLMNode.fetch_prompt_messages(
prompt_template=prompt_template,
sys_query=query,
memory=None,
model_config=model_config,
sys_files=[],
vision_enabled=node_data.vision.enabled,
vision_detail=node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=[],
tenant_id=self.tenant_id,
)

result_text = ""
try:
# handle invoke result
generator = LLMNode.invoke_llm(
node_data_model=node_data.metadata_model_config,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=self.node_data.structured_output_enabled,
structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self._node_id,
node_type=self.node_type,
)

usage = self._rag_retrieval.llm_usage
return retrieval_resource_list, usage
for event in generator:
if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
usage = self._merge_usage(usage, event.usage)
break

result_text_json = parse_and_check_json_markdown(result_text, [])
automatic_metadata_filters = []
if "metadata_map" in result_text_json:
metadata_map = result_text_json["metadata_map"]
for item in metadata_map:
if item.get("metadata_field_name") in all_metadata_fields:
automatic_metadata_filters.append(
{
"metadata_name": item.get("metadata_field_name"),
"value": item.get("metadata_field_value"),
"condition": item.get("comparison_operator"),
}
)
except Exception:
return [], usage
return automatic_metadata_filters, usage

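For reference, the automatic metadata filter above expects the model to answer with a JSON object carrying a "metadata_map" list whose keys are exactly the ones the node reads back. A minimal, self-contained sketch of that shape (the field values are made up, and the plain json module stands in for parse_and_check_json_markdown):

import json

# Hypothetical model reply; the key names mirror what _automatic_metadata_filter_func
# reads above, while the concrete values are illustrative only.
result_text = """
{
  "metadata_map": [
    {"metadata_field_name": "author", "metadata_field_value": "Alice", "comparison_operator": "is"},
    {"metadata_field_name": "year", "metadata_field_value": 2024, "comparison_operator": "="}
  ]
}
"""

result_text_json = json.loads(result_text)
automatic_metadata_filters = [
    {
        "metadata_name": item.get("metadata_field_name"),
        "value": item.get("metadata_field_value"),
        "condition": item.get("comparison_operator"),
    }
    for item in result_text_json.get("metadata_map", [])
]
print(automatic_metadata_filters)
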
@classmethod
def _extract_variable_selector_to_variable_mapping(
@@ -272,3 +626,107 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
if typed_node_data.query_attachment_selector:
variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector
return variable_mapping

def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
model_name = model.name
provider_name = model.provider

model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
)

provider_model_bundle = model_instance.provider_model_bundle
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)

model_credentials = model_instance.credentials

# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_name, model_type=ModelType.LLM
)

if provider_model is None:
raise ModelNotExistError(f"Model {model_name} not exist.")

if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ModelCredentialsNotInitializedError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelNotSupportedError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.")

# model config
completion_params = model.completion_params
stop = []
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]

# get model mode
model_mode = model.mode
if not model_mode:
raise ModelNotExistError("LLM mode is required.")

model_schema = model_type_instance.get_model_schema(model_name, model_credentials)

if not model_schema:
raise ModelNotExistError(f"Model {model_name} not exist.")

return model_instance, ModelConfigWithCredentialsEntity(
provider=provider_name,
model=model_name,
model_schema=model_schema,
mode=model_mode,
provider_model_bundle=provider_model_bundle,
credentials=model_credentials,
parameters=completion_params,
stop=stop,
)

def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
model_mode = ModelMode(node_data.metadata_model_config.mode) # type: ignore
input_text = query

prompt_messages: list[LLMNodeChatModelMessage] = []
if model_mode == ModelMode.CHAT:
system_prompt_messages = LLMNodeChatModelMessage(
role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT
)
prompt_messages.append(system_prompt_messages)
user_prompt_message_1 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1
)
prompt_messages.append(user_prompt_message_1)
assistant_prompt_message_1 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
)
prompt_messages.append(assistant_prompt_message_1)
user_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2
)
prompt_messages.append(user_prompt_message_2)
assistant_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
)
prompt_messages.append(assistant_prompt_message_2)
user_prompt_message_3 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER,
text=METADATA_FILTER_USER_PROMPT_3.format(
input_text=input_text,
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
),
)
prompt_messages.append(user_prompt_message_3)
return prompt_messages
elif model_mode == ModelMode.COMPLETION:
return LLMNodeCompletionModelPromptTemplate(
text=METADATA_FILTER_COMPLETION_PROMPT.format(
input_text=input_text,
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
)
)

else:
raise InvalidModelTypeError(f"Model mode {model_mode} not support.")

@@ -2,11 +2,11 @@ from collections.abc import Callable, Sequence
from typing import Any, TypeAlias, TypeVar

from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.workflow.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment

from .entities import FilterOperator, ListOperatorNodeData, Order
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError

@@ -14,10 +14,10 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.runtime import VariablePool
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import Conversation

@@ -49,6 +49,14 @@ from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptT
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.signature import sign_upload_file
from core.variables import (
ArrayFileSegment,
ArraySegment,
FileSegment,
NoneSegment,
ObjectSegment,
StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
@@ -69,14 +77,6 @@ from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from core.workflow.variables import (
ArrayFileSegment,
ArraySegment,
FileSegment,
NoneSegment,
ObjectSegment,
StringSegment,
)
from extensions.ext_database import db
from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile

@@ -3,9 +3,9 @@ from typing import Annotated, Any, Literal

from pydantic import AfterValidator, BaseModel, Field, field_validator

from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
from core.workflow.utils.condition.entities import Condition
from core.workflow.variables.types import SegmentType

_VALID_VAR_TYPE = frozenset(
[

@@ -6,6 +6,7 @@ from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast

from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import Segment, SegmentType
from core.workflow.enums import (
NodeExecutionType,
NodeType,
@@ -30,7 +31,6 @@ from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
from core.workflow.utils.condition.processor import ConditionProcessor
from core.workflow.variables import Segment, SegmentType
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
from libs.datetime_utils import naive_utc_now


@@ -8,9 +8,9 @@ from pydantic import (
)

from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
from core.workflow.variables.types import SegmentType

_OLD_BOOL_TYPE_NAME = "bool"
_OLD_SELECT_TYPE_NAME = "select"

@@ -1,6 +1,6 @@
from typing import Any

from core.workflow.variables.types import SegmentType
from core.variables.types import SegmentType


class ParameterExtractorNodeError(ValueError):

@@ -26,13 +26,13 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import ArrayValidation, SegmentType
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.runtime import VariablePool
from core.workflow.variables.types import ArrayValidation, SegmentType
from factories.variable_factory import build_segment_with_type

from .entities import ParameterExtractorNodeData

@@ -12,6 +12,8 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.enums import (
NodeType,
SystemVariableKey,
@@ -21,8 +23,6 @@ from core.workflow.enums import (
from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.workflow.variables.variables import ArrayAnyVariable
from extensions.ext_database import db
from factories import file_factory
from models import ToolFile

@@ -3,13 +3,13 @@ from collections.abc import Mapping
from typing import Any

from core.file import FileTransferMethod
from core.variables.types import SegmentType
from core.variables.variables import FileVariable
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.variables.types import SegmentType
from core.workflow.variables.variables import FileVariable
from factories import file_factory
from factories.variable_factory import build_segment_with_type


@@ -1,7 +1,7 @@
from pydantic import BaseModel

from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.variables.types import SegmentType


class AdvancedSettings(BaseModel):

@@ -1,10 +1,10 @@
from collections.abc import Mapping

from core.variables.segments import Segment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_aggregator.entities import VariableAggregatorNodeData
from core.workflow.variables.segments import Segment


class VariableAggregatorNode(Node[VariableAggregatorNodeData]):

@@ -3,9 +3,9 @@ from typing import Any, TypeVar

from pydantic import BaseModel

from core.workflow.variables import Segment
from core.workflow.variables.consts import SELECTORS_LENGTH
from core.workflow.variables.types import SegmentType
from core.variables import Segment
from core.variables.consts import SELECTORS_LENGTH
from core.variables.types import SegmentType

# Use double underscore (`__`) prefix for internal variables
# to minimize risk of collision with user-defined variable names.

@@ -1,6 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any

from core.variables import SegmentType, VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@@ -8,7 +9,6 @@ from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.variables import SegmentType, VariableBase

from .node_data import VariableAssignerData, WriteMode


@@ -1,6 +1,6 @@
from typing import Any

from core.workflow.variables import SegmentType
from core.variables import SegmentType

from .enums import Operation


@@ -2,14 +2,14 @@ import json
from collections.abc import Mapping, MutableMapping, Sequence
from typing import TYPE_CHECKING, Any

from core.variables import SegmentType, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.variables import SegmentType, VariableBase
from core.workflow.variables.consts import SELECTORS_LENGTH

from . import helpers
from .entities import VariableAssignerNodeData, VariableOperationItem

@@ -1,108 +0,0 @@
from typing import Any, Literal, Protocol

from pydantic import BaseModel, Field

from core.model_runtime.entities import LLMUsage
from core.workflow.nodes.knowledge_retrieval.entities import MetadataFilteringCondition
from core.workflow.nodes.llm.entities import ModelConfig


class SourceChildChunk(BaseModel):
id: str = Field(default="", description="Child chunk ID")
content: str = Field(default="", description="Child chunk content")
position: int = Field(default=0, description="Child chunk position")
score: float = Field(default=0.0, description="Child chunk relevance score")


class SourceMetadata(BaseModel):
source: str = Field(
default="knowledge",
serialization_alias="_source",
description="Data source identifier",
)
dataset_id: str = Field(description="Dataset unique identifier")
dataset_name: str = Field(description="Dataset display name")
document_id: str = Field(description="Document unique identifier")
document_name: str = Field(description="Document display name")
data_source_type: str = Field(description="Type of data source")
segment_id: str | None = Field(default=None, description="Segment unique identifier")
retriever_from: str = Field(default="workflow", description="Retriever source context")
score: float = Field(default=0.0, description="Retrieval relevance score")
child_chunks: list[SourceChildChunk] = Field(default=[], description="List of child chunks")
segment_hit_count: int | None = Field(default=0, description="Number of times segment was retrieved")
segment_word_count: int | None = Field(default=0, description="Word count of the segment")
segment_position: int | None = Field(default=0, description="Position of segment in document")
segment_index_node_hash: str | None = Field(default=None, description="Hash of index node for the segment")
doc_metadata: dict[str, Any] | None = Field(default=None, description="Additional document metadata")
position: int | None = Field(default=0, description="Position of the document in the dataset")

class Config:
populate_by_name = True


class Source(BaseModel):
metadata: SourceMetadata = Field(description="Source metadata information")
title: str = Field(description="Document title")
files: list[Any] | None = Field(default=None, description="Associated file references")
content: str | None = Field(description="Segment content text")
summary: str | None = Field(default=None, description="Content summary if available")


class KnowledgeRetrievalRequest(BaseModel):
tenant_id: str = Field(description="Tenant unique identifier")
user_id: str = Field(description="User unique identifier")
app_id: str = Field(description="Application unique identifier")
user_from: str = Field(description="Source of the user request (e.g., 'workflow', 'api')")
dataset_ids: list[str] = Field(description="List of dataset IDs to retrieve from")
query: str | None = Field(default=None, description="Query text for knowledge retrieval")
retrieval_mode: str = Field(description="Retrieval strategy: 'single' or 'multiple'")
model_provider: str | None = Field(default=None, description="Model provider name (e.g., 'openai', 'anthropic')")
completion_params: dict[str, Any] | None = Field(
default=None, description="Model completion parameters (e.g., temperature, max_tokens)"
)
model_mode: str | None = Field(default=None, description="Model mode (e.g., 'chat', 'completion')")
model_name: str | None = Field(default=None, description="Model name (e.g., 'gpt-4', 'claude-3-opus')")
metadata_model_config: ModelConfig | None = Field(
default=None, description="Model config for metadata-based filtering"
)
metadata_filtering_conditions: MetadataFilteringCondition | None = Field(
default=None, description="Conditions for filtering by metadata"
)
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] = Field(
default="disabled", description="Metadata filtering mode: 'disabled', 'automatic', or 'manual'"
)
top_k: int = Field(default=0, description="Number of top results to return")
score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold")
reranking_mode: str = Field(default="reranking_model", description="Reranking strategy")
reranking_model: dict | None = Field(default=None, description="Reranking model configuration")
weights: dict[str, Any] | None = Field(default=None, description="Weights for weighted score reranking")
reranking_enable: bool = Field(default=True, description="Whether reranking is enabled")
attachment_ids: list[str] | None = Field(default=None, description="List of attachment file IDs for retrieval")


class RAGRetrievalProtocol(Protocol):
"""Protocol for RAG-based knowledge retrieval implementations.

Implementations of this protocol handle knowledge retrieval from datasets
including rate limiting, dataset filtering, and document retrieval.
"""

@property
def llm_usage(self) -> LLMUsage:
"""Return accumulated LLM usage for retrieval operations."""
...

def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]:
"""Retrieve knowledge from datasets based on the provided request.

Args:
request: Knowledge retrieval request with search parameters

Returns:
List of sources matching the search criteria

Raises:
RateLimitExceededError: If rate limit is exceeded
ModelNotExistError: If specified model doesn't exist
"""
...
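The hunk above shows the request model and retrieval protocol that the knowledge retrieval node's self._rag_retrieval relies on. A minimal, self-contained sketch of the same pattern follows; it uses trimmed stand-ins rather than the real Dify classes, keeps only a few of the fields shown above, and all concrete values are illustrative:

from typing import Protocol

from pydantic import BaseModel


class Request(BaseModel):
    # Trimmed stand-in for KnowledgeRetrievalRequest; only a handful of the fields above.
    tenant_id: str
    dataset_ids: list[str]
    query: str | None = None
    retrieval_mode: str = "multiple"
    top_k: int = 0
    score_threshold: float = 0.0


class Retrieval(Protocol):
    # Mirrors RAGRetrievalProtocol.knowledge_retrieval, returning plain dicts here.
    def knowledge_retrieval(self, request: Request) -> list[dict]: ...


class DummyRetrieval:
    # Any object with a matching knowledge_retrieval method satisfies the Protocol,
    # which is what keeps the node decoupled from the concrete retrieval backend.
    def knowledge_retrieval(self, request: Request) -> list[dict]:
        return [{"title": "doc", "metadata": {"score": 0.42, "dataset_id": request.dataset_ids[0]}}]


def run(retrieval: Retrieval) -> list[dict]:
    request = Request(tenant_id="t-1", dataset_ids=["ds-1"], query="example query")
    return retrieval.knowledge_retrieval(request)


print(run(DummyRetrieval()))
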
@@ -64,7 +64,7 @@ class GraphExecutionProtocol(Protocol):
aborted: bool
error: Exception | None
exceptions_count: int
pause_reasons: list[PauseReason]
pause_reasons: Sequence[PauseReason]

def start(self) -> None:
"""Transition execution into the running state."""
@@ -446,7 +446,7 @@ class GraphRuntimeState:
graph_execution_cls = module.GraphExecution
workflow_id = self._pending_graph_execution_workflow_id or ""
self._pending_graph_execution_workflow_id = None
return graph_execution_cls(workflow_id=workflow_id) # type: ignore[invalid-return-type]
return graph_execution_cls(workflow_id=workflow_id)

def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol:
# Lazily import to keep the runtime domain decoupled from graph_engine modules.

@@ -2,8 +2,8 @@ from collections.abc import Mapping, Sequence
from typing import Any, Protocol

from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.system_variable import SystemVariableReadOnlyView
from core.workflow.variables.segments import Segment


class ReadOnlyVariablePool(Protocol):

@@ -5,8 +5,8 @@ from copy import deepcopy
from typing import Any

from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.system_variable import SystemVariableReadOnlyView
from core.workflow.variables.segments import Segment

from .graph_runtime_state import GraphRuntimeState
from .variable_pool import VariablePool

@@ -9,6 +9,10 @@ from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field

from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import FileSegment, ObjectSegment
from core.variables.variables import RAGPipelineVariableInput, Variable
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
@@ -16,10 +20,6 @@ from core.workflow.constants import (
SYSTEM_VARIABLE_NODE_ID,
)
from core.workflow.system_variable import SystemVariable
from core.workflow.variables import Segment, SegmentGroup, VariableBase
from core.workflow.variables.consts import SELECTORS_LENGTH
from core.workflow.variables.segments import FileSegment, ObjectSegment
from core.workflow.variables.variables import RAGPipelineVariableInput, Variable
from factories import variable_factory

VariableValue = Union[str, int, float, dict[str, object], list[object], File]

@@ -3,9 +3,9 @@ from collections.abc import Mapping, Sequence
from typing import Literal, NamedTuple

from core.file import FileAttribute, file_manager
from core.variables import ArrayFileSegment
from core.variables.segments import ArrayBooleanSegment, BooleanSegment
from core.workflow.runtime import VariablePool
from core.workflow.variables import ArrayFileSegment
from core.workflow.variables.segments import ArrayBooleanSegment, BooleanSegment

from .entities import Condition, SubCondition, SupportedComparisonOperator


@@ -2,9 +2,9 @@ import abc
from collections.abc import Mapping, Sequence
from typing import Any, Protocol

from core.variables import VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.runtime import VariablePool
from core.workflow.variables import VariableBase
from core.workflow.variables.consts import SELECTORS_LENGTH


class VariableLoader(Protocol):

@@ -5,7 +5,7 @@ from typing import Any, overload
from pydantic import BaseModel

from core.file.models import File
from core.workflow.variables import Segment
from core.variables import Segment


class WorkflowRuntimeTypeConverter:

@@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then
if [[ -z "${CELERY_QUEUES}" ]]; then
if [[ "${EDITION}" == "CLOUD" ]]; then
# Cloud edition: separate queues for dataset and trigger tasks
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
else
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
fi
else
DEFAULT_QUEUES="${CELERY_QUEUES}"

@@ -80,14 +80,8 @@ def init_app(app: DifyApp) -> Celery:
worker_hijack_root_logger=False,
timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"),
task_ignore_result=True,
task_annotations=dify_config.CELERY_TASK_ANNOTATIONS,
)

if dify_config.CELERY_BACKEND == "redis":
celery_app.conf.update(
result_backend_transport_options=broker_transport_options,
)

# Apply SSL configuration if enabled
ssl_options = _get_celery_ssl_options()
if ssl_options:
@@ -196,14 +190,6 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.trigger_provider_refresh_task.trigger_provider_refresh",
"schedule": timedelta(minutes=dify_config.TRIGGER_PROVIDER_REFRESH_INTERVAL),
}

if dify_config.ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK:
imports.append("schedule.update_api_token_last_used_task")
beat_schedule["batch_update_api_token_last_used"] = {
"task": "schedule.update_api_token_last_used_task.batch_update_api_token_last_used",
"schedule": timedelta(minutes=dify_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL),
}

celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)

return celery_app
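For context, each beat entry registered above follows the standard Celery beat_schedule shape: a dotted task path plus a timedelta interval. A small standalone sketch of that structure (the task path and interval here are placeholders, not real Dify tasks):

from datetime import timedelta

# Placeholder task path and interval; the real entries above are gated by dify_config flags.
beat_schedule = {
    "example_periodic_job": {
        "task": "schedule.example_task.run",
        "schedule": timedelta(minutes=5),
    }
}

# In the application this dict is applied with:
# celery_app.conf.update(beat_schedule=beat_schedule, imports=["schedule.example_task"])
print(beat_schedule)
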
@@ -119,7 +119,7 @@ class RedisClientWrapper:


redis_client: RedisClientWrapper = RedisClientWrapper()
_pubsub_redis_client: redis.Redis | RedisCluster | None = None
pubsub_redis_client: RedisClientWrapper = RedisClientWrapper()


def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
@@ -232,7 +232,7 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
return client


def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | RedisCluster:
def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> Union[redis.Redis, RedisCluster]:
if use_clusters:
return RedisCluster.from_url(pubsub_url)
return redis.Redis.from_url(pubsub_url)
@@ -256,19 +256,23 @@ def init_app(app: DifyApp):
redis_client.initialize(client)
app.extensions["redis"] = redis_client

global _pubsub_redis_client
_pubsub_redis_client = client
pubsub_client = client
if dify_config.normalized_pubsub_redis_url:
_pubsub_redis_client = _create_pubsub_client(
pubsub_client = _create_pubsub_client(
dify_config.normalized_pubsub_redis_url, dify_config.PUBSUB_REDIS_USE_CLUSTERS
)
pubsub_redis_client.initialize(pubsub_client)


def get_pubsub_redis_client() -> RedisClientWrapper:
return pubsub_redis_client


def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here."
redis_conn = get_pubsub_redis_client()
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
return ShardedRedisBroadcastChannel(_pubsub_redis_client)
return RedisBroadcastChannel(_pubsub_redis_client)
return ShardedRedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType]
return RedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType]


P = ParamSpec("P")

@@ -1,6 +1,6 @@
import functools
from collections.abc import Callable
from typing import ParamSpec, TypeVar, cast
from typing import Any, TypeVar, cast

from opentelemetry.trace import get_tracer

@@ -8,8 +8,7 @@ from configs import dify_config
from extensions.otel.decorators.handler import SpanHandler
from extensions.otel.runtime import is_instrument_flag_enabled

P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T", bound=Callable[..., Any])

_HANDLER_INSTANCES: dict[type[SpanHandler], SpanHandler] = {SpanHandler: SpanHandler()}

@@ -21,7 +20,7 @@ def _get_handler_instance(handler_class: type[SpanHandler]) -> SpanHandler:
return _HANDLER_INSTANCES[handler_class]


def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[Callable[P, R]], Callable[P, R]]:
def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[T], T]:
"""
Decorator that traces a function with an OpenTelemetry span.

@@ -31,9 +30,9 @@ def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[Call
:param handler_class: Optional handler class to use for this span. If None, uses the default SpanHandler.
"""

def decorator(func: Callable[P, R]) -> Callable[P, R]:
def decorator(func: T) -> T:
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
def wrapper(*args: Any, **kwargs: Any) -> Any:
if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()):
return func(*args, **kwargs)

@@ -47,6 +46,6 @@ def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[Call
kwargs=kwargs,
)

return cast(Callable[P, R], wrapper)
return cast(T, wrapper)

return decorator
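The switch above from ParamSpec to a TypeVar bound to Callable[..., Any] keeps the decorated function's public type intact while the wrapper body itself stays untyped. A standalone sketch of that same pattern (a toy decorator, not the real trace_span):

import functools
from collections.abc import Callable
from typing import Any, TypeVar, cast

T = TypeVar("T", bound=Callable[..., Any])


def traced(func: T) -> T:
    # The wrapper drops parameter typing internally, but cast(T, ...) preserves the
    # original callable's signature for callers and type checkers.
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        print(f"enter {func.__name__}")
        try:
            return func(*args, **kwargs)
        finally:
            print(f"exit {func.__name__}")

    return cast(T, wrapper)


@traced
def add(a: int, b: int) -> int:
    return a + b


print(add(2, 3))
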
@@ -1,11 +1,9 @@
import inspect
from collections.abc import Callable, Mapping
from typing import Any, TypeVar
from typing import Any

from opentelemetry.trace import SpanKind, Status, StatusCode

R = TypeVar("R")


class SpanHandler:
"""
@@ -33,9 +31,9 @@ class SpanHandler:

def _extract_arguments(
self,
wrapped: Callable[..., R],
args: tuple[object, ...],
kwargs: Mapping[str, object],
wrapped: Callable[..., Any],
args: tuple[Any, ...],
kwargs: Mapping[str, Any],
) -> dict[str, Any] | None:
"""
Extract function arguments using inspect.signature.
@@ -64,10 +62,10 @@ class SpanHandler:
def wrapper(
self,
tracer: Any,
wrapped: Callable[..., R],
args: tuple[object, ...],
kwargs: Mapping[str, object],
) -> R:
wrapped: Callable[..., Any],
args: tuple[Any, ...],
kwargs: Mapping[str, Any],
) -> Any:
"""
Fully control the wrapper behavior.


@@ -1,6 +1,6 @@
import logging
from collections.abc import Callable, Mapping
from typing import Any, TypeVar
from typing import Any

from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.util.types import AttributeValue
@@ -12,19 +12,16 @@ from models.model import Account
logger = logging.getLogger(__name__)


R = TypeVar("R")


class AppGenerateHandler(SpanHandler):
"""Span handler for ``AppGenerateService.generate``."""

def wrapper(
self,
tracer: Any,
wrapped: Callable[..., R],
args: tuple[object, ...],
kwargs: Mapping[str, object],
) -> R:
wrapped: Callable[..., Any],
args: tuple[Any, ...],
kwargs: Mapping[str, Any],
) -> Any:
try:
arguments = self._extract_arguments(wrapped, args, kwargs)
if not arguments:

@@ -10,10 +10,10 @@ from opentelemetry.trace.status import Status, StatusCode
from pydantic import BaseModel

from core.file.models import File
from core.variables import Segment
from core.workflow.enums import NodeType
from core.workflow.graph_events import GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from core.workflow.variables import Segment
from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes


@@ -8,9 +8,9 @@ from typing import Any

from opentelemetry.trace import Span

from core.variables import Segment
from core.workflow.graph_events import GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from core.workflow.variables import Segment
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
from extensions.otel.semconv.gen_ai import RetrieverAttributes


@@ -4,12 +4,8 @@ from uuid import uuid4

from configs import dify_config
from core.file import File
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
)
from core.workflow.variables.exc import VariableError
from core.workflow.variables.segments import (
from core.variables.exc import VariableError
from core.variables.segments import (
ArrayAnySegment,
ArrayBooleanSegment,
ArrayFileSegment,
@@ -26,8 +22,8 @@ from core.workflow.variables.segments import (
Segment,
StringSegment,
)
from core.workflow.variables.types import SegmentType
from core.workflow.variables.variables import (
from core.variables.types import SegmentType
from core.variables.variables import (
ArrayAnyVariable,
ArrayBooleanVariable,
ArrayFileVariable,
@@ -44,6 +40,10 @@ from core.workflow.variables.variables import (
StringVariable,
VariableBase,
)
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
)


class UnsupportedSegmentTypeError(Exception):

Some files were not shown because too many files have changed in this diff.