diff --git a/api/.env.example b/api/.env.example index 40fed7403c..34be400e87 100644 --- a/api/.env.example +++ b/api/.env.example @@ -557,7 +557,7 @@ MAX_VARIABLE_SIZE=204800 # GraphEngine Worker Pool Configuration # Minimum number of workers per GraphEngine instance (default: 1) -GRAPH_ENGINE_MIN_WORKERS=1 +GRAPH_ENGINE_MIN_WORKERS=3 # Maximum number of workers per GraphEngine instance (default: 10) GRAPH_ENGINE_MAX_WORKERS=10 # Queue depth threshold that triggers worker scale up (default: 3) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index ccb97d96ef..a752d9d103 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -761,7 +761,7 @@ class WorkflowConfig(BaseSettings): # GraphEngine Worker Pool Configuration GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field( description="Minimum number of workers per GraphEngine instance", - default=1, + default=3, ) GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field( diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index f315d053cb..69ed4ae43b 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -100,6 +100,13 @@ class CheckDependenciesPendingData(BaseModel): class RagPipelineDslService: + """Import, export, and inspect RAG pipeline DSL using the caller-owned session. + + Controllers wrap this service in a SQLAlchemy transaction context, so methods must only flush interim changes when + generated IDs are needed. Committing inside the service would close the caller's transaction and break later work in + the same context manager. + """ + def __init__(self, session: Session): self._session = session @@ -325,7 +332,7 @@ class RagPipelineDslService: type=CollectionBindingType.DATASET, ) self._session.add(dataset_collection_binding) - self._session.commit() + self._session.flush() dataset_collection_binding_id = dataset_collection_binding.id dataset.collection_binding_id = dataset_collection_binding_id dataset.embedding_model = knowledge_configuration.embedding_model @@ -337,7 +344,7 @@ class RagPipelineDslService: dataset.summary_index_setting = knowledge_configuration.summary_index_setting dataset.pipeline_id = pipeline.id self._session.add(dataset) - self._session.commit() + self._session.flush() dataset_id = dataset.id if not dataset_id: raise ValueError("DSL is not valid, please check the Knowledge Index node.") @@ -462,7 +469,7 @@ class RagPipelineDslService: type=CollectionBindingType.DATASET, ) self._session.add(dataset_collection_binding) - self._session.commit() + self._session.flush() dataset_collection_binding_id = dataset_collection_binding.id dataset.collection_binding_id = dataset_collection_binding_id dataset.embedding_model = knowledge_configuration.embedding_model @@ -474,7 +481,7 @@ class RagPipelineDslService: dataset.summary_index_setting = knowledge_configuration.summary_index_setting dataset.pipeline_id = pipeline.id self._session.add(dataset) - self._session.commit() + self._session.flush() dataset_id = dataset.id if not dataset_id: raise ValueError("DSL is not valid, please check the Knowledge Index node.") @@ -585,7 +592,7 @@ class RagPipelineDslService: pipeline.id = str(uuid4()) self._session.add(pipeline) - self._session.commit() + self._session.flush() # save dependencies if dependencies: redis_client.setex( @@ -627,8 +634,8 @@ class RagPipelineDslService: workflow.environment_variables = environment_variables workflow.conversation_variables = conversation_variables workflow.rag_pipeline_variables = rag_pipeline_variables_list - # commit db session changes - self._session.commit() + # Keep transaction ownership with the caller while materializing IDs and constraint checks before returning. + self._session.flush() return pipeline diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index 7d23b63049..100b294f52 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -297,7 +297,7 @@ class TableTestRunner: max_workers: int = 4, enable_logging: bool = False, log_level: str = "INFO", - graph_engine_min_workers: int = 1, + graph_engine_min_workers: int = 3, graph_engine_max_workers: int = 1, graph_engine_scale_up_threshold: int = 5, graph_engine_scale_down_idle_time: float = 30.0, @@ -310,7 +310,7 @@ class TableTestRunner: max_workers: Maximum number of parallel workers for test execution enable_logging: Enable detailed logging log_level: Logging level (DEBUG, INFO, WARNING, ERROR) - graph_engine_min_workers: Minimum workers for GraphEngine (default: 1) + graph_engine_min_workers: Minimum workers for GraphEngine (default: 3) graph_engine_max_workers: Maximum workers for GraphEngine (default: 1) graph_engine_scale_up_threshold: Queue depth to trigger scale up graph_engine_scale_down_idle_time: Idle time before scaling down diff --git a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py index 337659b15f..2aea1285aa 100644 --- a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py +++ b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py @@ -259,6 +259,60 @@ workflow: if result.status == ImportStatus.FAILED: print(f"DEBUG: {result.error}") assert result.status == ImportStatus.COMPLETED + session.commit.assert_not_called() + session.flush.assert_called() + + +def test_import_rag_pipeline_flushes_new_collection_binding_without_commit(mocker) -> None: + yaml_content = """ +version: 0.1.0 +kind: rag_pipeline +rag_pipeline: + name: Test Pipeline +workflow: + graph: + nodes: + - data: + type: knowledge-index +""" + pipeline = Mock(id="p1", description="desc", is_published=False) + pipeline.name = "Test Pipeline" + mocker.patch.object(RagPipelineDslService, "_create_or_update_pipeline", return_value=pipeline) + + config_mock = Mock() + config_mock.indexing_technique = "high_quality" + config_mock.embedding_model = "m" + config_mock.embedding_model_provider = "p" + config_mock.chunk_structure = "text_model" + config_mock.retrieval_model.model_dump.return_value = {} + config_mock.summary_index_setting = None + mocker.patch( + "services.rag_pipeline.rag_pipeline_dsl_service.KnowledgeConfiguration.model_validate", + return_value=config_mock, + ) + + dataset_mock = Mock(id="d1") + binding_mock = Mock(id="b1") + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Dataset", return_value=dataset_mock) + binding_cls = mocker.patch( + "services.rag_pipeline.rag_pipeline_dsl_service.DatasetCollectionBinding", + return_value=binding_mock, + ) + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock()) + + session = cast(MagicMock, Mock()) + session.scalar.return_value = None + session.scalars.return_value.all.return_value = [] + service = RagPipelineDslService(session=cast(Session, session)) + account = Mock(current_tenant_id="t1") + + result = service.import_rag_pipeline(account=account, import_mode="yaml-content", yaml_content=yaml_content) + + assert result.status == ImportStatus.COMPLETED + binding_cls.assert_called_once() + assert dataset_mock.collection_binding_id == "b1" + session.commit.assert_not_called() + assert session.flush.call_count >= 2 def test_import_rag_pipeline_pending_version(mocker) -> None: @@ -338,6 +392,67 @@ workflow: assert result.dataset_id == "d1" +def test_confirm_import_flushes_new_collection_binding_without_commit(mocker) -> None: + from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelinePendingData + + yaml_content = """ +version: 0.1.0 +kind: rag_pipeline +rag_pipeline: + name: Test Pipeline +workflow: + graph: + nodes: + - data: + type: knowledge-index +""" + pending = RagPipelinePendingData(import_mode="yaml-content", yaml_content=yaml_content, pipeline_id="p1") + mocker.patch( + "services.rag_pipeline.rag_pipeline_dsl_service.redis_client.get", + return_value=pending.model_dump_json(), + ) + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.redis_client.delete") + + pipeline = Mock(id="p1", description="desc") + pipeline.name = "Test Pipeline" + pipeline.retrieve_dataset.return_value = None + mocker.patch.object(RagPipelineDslService, "_create_or_update_pipeline", return_value=pipeline) + + config_mock = Mock() + config_mock.indexing_technique = "high_quality" + config_mock.embedding_model = "m" + config_mock.embedding_model_provider = "p" + config_mock.chunk_structure = "text_model" + config_mock.retrieval_model.model_dump.return_value = {} + config_mock.summary_index_setting = None + mocker.patch( + "services.rag_pipeline.rag_pipeline_dsl_service.KnowledgeConfiguration.model_validate", + return_value=config_mock, + ) + + dataset_mock = Mock(id="d1") + binding_mock = Mock(id="b1") + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Dataset", return_value=dataset_mock) + binding_cls = mocker.patch( + "services.rag_pipeline.rag_pipeline_dsl_service.DatasetCollectionBinding", + return_value=binding_mock, + ) + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock()) + + session = cast(MagicMock, Mock()) + session.scalar.side_effect = [pipeline, None] + service = RagPipelineDslService(session=cast(Session, session)) + account = Mock(id="u1", current_tenant_id="t1") + + result = service.confirm_import(account=account, import_id="imp-1") + + assert result.status == ImportStatus.COMPLETED + binding_cls.assert_called_once() + assert dataset_mock.collection_binding_id == "b1" + session.commit.assert_not_called() + assert session.flush.call_count >= 2 + + # --- _extract_dependencies_from_workflow_graph all types --- @@ -421,6 +536,8 @@ def test_create_or_update_pipeline_create_new(mocker) -> None: assert result == pipeline_instance session.add.assert_called() + session.commit.assert_not_called() + session.flush.assert_called() # --- export_rag_pipeline_dsl comprehensive --- diff --git a/docker/envs/core-services/shared.env.example b/docker/envs/core-services/shared.env.example index 80cfe42c38..fca0b57d0c 100644 --- a/docker/envs/core-services/shared.env.example +++ b/docker/envs/core-services/shared.env.example @@ -177,7 +177,7 @@ WORKFLOW_MAX_EXECUTION_TIME=1200 WORKFLOW_CALL_MAX_DEPTH=5 MAX_VARIABLE_SIZE=204800 WORKFLOW_FILE_UPLOAD_LIMIT=10 -GRAPH_ENGINE_MIN_WORKERS=1 +GRAPH_ENGINE_MIN_WORKERS=3 GRAPH_ENGINE_MAX_WORKERS=10 GRAPH_ENGINE_SCALE_UP_THRESHOLD=3 GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=5.0