mirror of
https://github.com/langgenius/dify.git
synced 2026-05-23 10:29:07 +08:00
Merge remote-tracking branch 'origin/main' into codex/initialize-user-timezone
This commit is contained in:
@ -557,7 +557,7 @@ MAX_VARIABLE_SIZE=204800
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
# Minimum number of workers per GraphEngine instance (default: 1)
|
||||
GRAPH_ENGINE_MIN_WORKERS=1
|
||||
GRAPH_ENGINE_MIN_WORKERS=3
|
||||
# Maximum number of workers per GraphEngine instance (default: 10)
|
||||
GRAPH_ENGINE_MAX_WORKERS=10
|
||||
# Queue depth threshold that triggers worker scale up (default: 3)
|
||||
|
||||
@ -761,7 +761,7 @@ class WorkflowConfig(BaseSettings):
|
||||
# GraphEngine Worker Pool Configuration
|
||||
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
|
||||
description="Minimum number of workers per GraphEngine instance",
|
||||
default=1,
|
||||
default=3,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field(
|
||||
|
||||
@ -100,6 +100,13 @@ class CheckDependenciesPendingData(BaseModel):
|
||||
|
||||
|
||||
class RagPipelineDslService:
|
||||
"""Import, export, and inspect RAG pipeline DSL using the caller-owned session.
|
||||
|
||||
Controllers wrap this service in a SQLAlchemy transaction context, so methods must only flush interim changes when
|
||||
generated IDs are needed. Committing inside the service would close the caller's transaction and break later work in
|
||||
the same context manager.
|
||||
"""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
self._session = session
|
||||
|
||||
@ -325,7 +332,7 @@ class RagPipelineDslService:
|
||||
type=CollectionBindingType.DATASET,
|
||||
)
|
||||
self._session.add(dataset_collection_binding)
|
||||
self._session.commit()
|
||||
self._session.flush()
|
||||
dataset_collection_binding_id = dataset_collection_binding.id
|
||||
dataset.collection_binding_id = dataset_collection_binding_id
|
||||
dataset.embedding_model = knowledge_configuration.embedding_model
|
||||
@ -337,7 +344,7 @@ class RagPipelineDslService:
|
||||
dataset.summary_index_setting = knowledge_configuration.summary_index_setting
|
||||
dataset.pipeline_id = pipeline.id
|
||||
self._session.add(dataset)
|
||||
self._session.commit()
|
||||
self._session.flush()
|
||||
dataset_id = dataset.id
|
||||
if not dataset_id:
|
||||
raise ValueError("DSL is not valid, please check the Knowledge Index node.")
|
||||
@ -462,7 +469,7 @@ class RagPipelineDslService:
|
||||
type=CollectionBindingType.DATASET,
|
||||
)
|
||||
self._session.add(dataset_collection_binding)
|
||||
self._session.commit()
|
||||
self._session.flush()
|
||||
dataset_collection_binding_id = dataset_collection_binding.id
|
||||
dataset.collection_binding_id = dataset_collection_binding_id
|
||||
dataset.embedding_model = knowledge_configuration.embedding_model
|
||||
@ -474,7 +481,7 @@ class RagPipelineDslService:
|
||||
dataset.summary_index_setting = knowledge_configuration.summary_index_setting
|
||||
dataset.pipeline_id = pipeline.id
|
||||
self._session.add(dataset)
|
||||
self._session.commit()
|
||||
self._session.flush()
|
||||
dataset_id = dataset.id
|
||||
if not dataset_id:
|
||||
raise ValueError("DSL is not valid, please check the Knowledge Index node.")
|
||||
@ -585,7 +592,7 @@ class RagPipelineDslService:
|
||||
pipeline.id = str(uuid4())
|
||||
|
||||
self._session.add(pipeline)
|
||||
self._session.commit()
|
||||
self._session.flush()
|
||||
# save dependencies
|
||||
if dependencies:
|
||||
redis_client.setex(
|
||||
@ -627,8 +634,8 @@ class RagPipelineDslService:
|
||||
workflow.environment_variables = environment_variables
|
||||
workflow.conversation_variables = conversation_variables
|
||||
workflow.rag_pipeline_variables = rag_pipeline_variables_list
|
||||
# commit db session changes
|
||||
self._session.commit()
|
||||
# Keep transaction ownership with the caller while materializing IDs and constraint checks before returning.
|
||||
self._session.flush()
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
@ -297,7 +297,7 @@ class TableTestRunner:
|
||||
max_workers: int = 4,
|
||||
enable_logging: bool = False,
|
||||
log_level: str = "INFO",
|
||||
graph_engine_min_workers: int = 1,
|
||||
graph_engine_min_workers: int = 3,
|
||||
graph_engine_max_workers: int = 1,
|
||||
graph_engine_scale_up_threshold: int = 5,
|
||||
graph_engine_scale_down_idle_time: float = 30.0,
|
||||
@ -310,7 +310,7 @@ class TableTestRunner:
|
||||
max_workers: Maximum number of parallel workers for test execution
|
||||
enable_logging: Enable detailed logging
|
||||
log_level: Logging level (DEBUG, INFO, WARNING, ERROR)
|
||||
graph_engine_min_workers: Minimum workers for GraphEngine (default: 1)
|
||||
graph_engine_min_workers: Minimum workers for GraphEngine (default: 3)
|
||||
graph_engine_max_workers: Maximum workers for GraphEngine (default: 1)
|
||||
graph_engine_scale_up_threshold: Queue depth to trigger scale up
|
||||
graph_engine_scale_down_idle_time: Idle time before scaling down
|
||||
|
||||
@ -259,6 +259,60 @@ workflow:
|
||||
if result.status == ImportStatus.FAILED:
|
||||
print(f"DEBUG: {result.error}")
|
||||
assert result.status == ImportStatus.COMPLETED
|
||||
session.commit.assert_not_called()
|
||||
session.flush.assert_called()
|
||||
|
||||
|
||||
def test_import_rag_pipeline_flushes_new_collection_binding_without_commit(mocker) -> None:
|
||||
yaml_content = """
|
||||
version: 0.1.0
|
||||
kind: rag_pipeline
|
||||
rag_pipeline:
|
||||
name: Test Pipeline
|
||||
workflow:
|
||||
graph:
|
||||
nodes:
|
||||
- data:
|
||||
type: knowledge-index
|
||||
"""
|
||||
pipeline = Mock(id="p1", description="desc", is_published=False)
|
||||
pipeline.name = "Test Pipeline"
|
||||
mocker.patch.object(RagPipelineDslService, "_create_or_update_pipeline", return_value=pipeline)
|
||||
|
||||
config_mock = Mock()
|
||||
config_mock.indexing_technique = "high_quality"
|
||||
config_mock.embedding_model = "m"
|
||||
config_mock.embedding_model_provider = "p"
|
||||
config_mock.chunk_structure = "text_model"
|
||||
config_mock.retrieval_model.model_dump.return_value = {}
|
||||
config_mock.summary_index_setting = None
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_dsl_service.KnowledgeConfiguration.model_validate",
|
||||
return_value=config_mock,
|
||||
)
|
||||
|
||||
dataset_mock = Mock(id="d1")
|
||||
binding_mock = Mock(id="b1")
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Dataset", return_value=dataset_mock)
|
||||
binding_cls = mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_dsl_service.DatasetCollectionBinding",
|
||||
return_value=binding_mock,
|
||||
)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock())
|
||||
|
||||
session = cast(MagicMock, Mock())
|
||||
session.scalar.return_value = None
|
||||
session.scalars.return_value.all.return_value = []
|
||||
service = RagPipelineDslService(session=cast(Session, session))
|
||||
account = Mock(current_tenant_id="t1")
|
||||
|
||||
result = service.import_rag_pipeline(account=account, import_mode="yaml-content", yaml_content=yaml_content)
|
||||
|
||||
assert result.status == ImportStatus.COMPLETED
|
||||
binding_cls.assert_called_once()
|
||||
assert dataset_mock.collection_binding_id == "b1"
|
||||
session.commit.assert_not_called()
|
||||
assert session.flush.call_count >= 2
|
||||
|
||||
|
||||
def test_import_rag_pipeline_pending_version(mocker) -> None:
|
||||
@ -338,6 +392,67 @@ workflow:
|
||||
assert result.dataset_id == "d1"
|
||||
|
||||
|
||||
def test_confirm_import_flushes_new_collection_binding_without_commit(mocker) -> None:
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelinePendingData
|
||||
|
||||
yaml_content = """
|
||||
version: 0.1.0
|
||||
kind: rag_pipeline
|
||||
rag_pipeline:
|
||||
name: Test Pipeline
|
||||
workflow:
|
||||
graph:
|
||||
nodes:
|
||||
- data:
|
||||
type: knowledge-index
|
||||
"""
|
||||
pending = RagPipelinePendingData(import_mode="yaml-content", yaml_content=yaml_content, pipeline_id="p1")
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_dsl_service.redis_client.get",
|
||||
return_value=pending.model_dump_json(),
|
||||
)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.redis_client.delete")
|
||||
|
||||
pipeline = Mock(id="p1", description="desc")
|
||||
pipeline.name = "Test Pipeline"
|
||||
pipeline.retrieve_dataset.return_value = None
|
||||
mocker.patch.object(RagPipelineDslService, "_create_or_update_pipeline", return_value=pipeline)
|
||||
|
||||
config_mock = Mock()
|
||||
config_mock.indexing_technique = "high_quality"
|
||||
config_mock.embedding_model = "m"
|
||||
config_mock.embedding_model_provider = "p"
|
||||
config_mock.chunk_structure = "text_model"
|
||||
config_mock.retrieval_model.model_dump.return_value = {}
|
||||
config_mock.summary_index_setting = None
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_dsl_service.KnowledgeConfiguration.model_validate",
|
||||
return_value=config_mock,
|
||||
)
|
||||
|
||||
dataset_mock = Mock(id="d1")
|
||||
binding_mock = Mock(id="b1")
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Dataset", return_value=dataset_mock)
|
||||
binding_cls = mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline_dsl_service.DatasetCollectionBinding",
|
||||
return_value=binding_mock,
|
||||
)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock())
|
||||
|
||||
session = cast(MagicMock, Mock())
|
||||
session.scalar.side_effect = [pipeline, None]
|
||||
service = RagPipelineDslService(session=cast(Session, session))
|
||||
account = Mock(id="u1", current_tenant_id="t1")
|
||||
|
||||
result = service.confirm_import(account=account, import_id="imp-1")
|
||||
|
||||
assert result.status == ImportStatus.COMPLETED
|
||||
binding_cls.assert_called_once()
|
||||
assert dataset_mock.collection_binding_id == "b1"
|
||||
session.commit.assert_not_called()
|
||||
assert session.flush.call_count >= 2
|
||||
|
||||
|
||||
# --- _extract_dependencies_from_workflow_graph all types ---
|
||||
|
||||
|
||||
@ -421,6 +536,8 @@ def test_create_or_update_pipeline_create_new(mocker) -> None:
|
||||
|
||||
assert result == pipeline_instance
|
||||
session.add.assert_called()
|
||||
session.commit.assert_not_called()
|
||||
session.flush.assert_called()
|
||||
|
||||
|
||||
# --- export_rag_pipeline_dsl comprehensive ---
|
||||
|
||||
@ -177,7 +177,7 @@ WORKFLOW_MAX_EXECUTION_TIME=1200
|
||||
WORKFLOW_CALL_MAX_DEPTH=5
|
||||
MAX_VARIABLE_SIZE=204800
|
||||
WORKFLOW_FILE_UPLOAD_LIMIT=10
|
||||
GRAPH_ENGINE_MIN_WORKERS=1
|
||||
GRAPH_ENGINE_MIN_WORKERS=3
|
||||
GRAPH_ENGINE_MAX_WORKERS=10
|
||||
GRAPH_ENGINE_SCALE_UP_THRESHOLD=3
|
||||
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=5.0
|
||||
|
||||
Reference in New Issue
Block a user