test: improve unit tests for controllers.service_api (#32073)

Co-authored-by: Rajat Agarwal <rajat.agarwal@infocusp.com>
This commit is contained in:
Dev Sharma
2026-02-25 12:15:50 +05:30
committed by GitHub
parent 212756c315
commit d773096146
24 changed files with 11279 additions and 2 deletions

View File

@ -3,7 +3,8 @@ from typing import Any
from flask import request
from pydantic import BaseModel
from werkzeug.exceptions import Forbidden
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
@ -17,7 +18,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from libs import helper
from libs.login import current_user
from models import Account
from models.dataset import Pipeline
from models.dataset import Dataset, Pipeline
from models.engine import db
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
from services.file_service import FileService
@ -65,6 +66,12 @@ class DatasourcePluginsApi(DatasetApiResource):
)
def get(self, tenant_id: str, dataset_id: str):
"""Resource for getting datasource plugins."""
# Verify dataset ownership
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
raise NotFound("Dataset not found.")
# Get query parameter to determine published or draft
is_published: bool = request.args.get("is_published", default=True, type=bool)
@ -104,6 +111,12 @@ class DatasourceNodeRunApi(DatasetApiResource):
@service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__])
def post(self, tenant_id: str, dataset_id: str, node_id: str):
"""Resource for getting datasource plugins."""
# Verify dataset ownership
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
raise NotFound("Dataset not found.")
payload = DatasourceNodeRunPayload.model_validate(service_api_ns.payload or {})
assert isinstance(current_user, Account)
rag_pipeline_service: RagPipelineService = RagPipelineService()
@ -161,6 +174,12 @@ class PipelineRunApi(DatasetApiResource):
@service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__])
def post(self, tenant_id: str, dataset_id: str):
"""Resource for running a rag pipeline."""
# Verify dataset ownership
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
raise NotFound("Dataset not found.")
payload = PipelineRunApiEntity.model_validate(service_api_ns.payload or {})
if not isinstance(current_user, Account):