Compare commits

..

2 Commits

Author SHA1 Message Date
41b52bbefb [autofix.ci] apply automated fixes 2026-06-29 20:05:12 +00:00
1f72e9799b fix(api): scope nested resource lookups by owner refs 2026-06-30 04:01:12 +08:00
36 changed files with 1070 additions and 427 deletions

View File

@ -3,6 +3,7 @@ from typing import Any, Literal
from flask_restx import Resource
from pydantic import BaseModel, Field, RootModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.common.fields import SimpleDataResponse
@ -216,7 +217,9 @@ class InstructionGenerateApi(Resource):
try:
# Generate from nothing for a workflow node
if (args.current in (code_template, "")) and args.node_id != "":
app = session.get(App, args.flow_id)
app = session.scalar(
select(App).where(App.id == args.flow_id, App.tenant_id == current_tenant_id).limit(1)
)
if not app:
return {"error": f"app {args.flow_id} not found"}, 400
workflow = WorkflowService().get_draft_workflow(app_model=app, session=session)

View File

@ -26,6 +26,7 @@ from libs.helper import to_timestamp
from libs.login import login_required
from models.enums import AppMCPServerStatus
from models.model import App, AppMCPServer
from services.app_ref_service import AppRefService
class MCPServerCreatePayload(BaseModel):
@ -146,7 +147,17 @@ class AppMCPServerController(Resource):
@get_app_model
def put(self, app_model: App):
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
server = db.session.get(AppMCPServer, payload.id)
app_ref = AppRefService.create_app_ref(app_model)
server_ref = AppRefService.create_mcp_server_ref(app_ref, payload.id)
server = db.session.scalar(
select(AppMCPServer)
.where(
AppMCPServer.id == server_ref.server_id,
AppMCPServer.tenant_id == server_ref.tenant_id,
AppMCPServer.app_id == server_ref.app_id,
)
.limit(1)
)
if not server:
raise NotFound()

View File

@ -78,6 +78,7 @@ from repositories.workflow_collaboration_repository import WORKFLOW_ONLINE_USERS
from services.app_generate_service import AppGenerateService
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.workflow_ref_service import WorkflowRefService
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
logger = logging.getLogger(__name__)
@ -1406,6 +1407,7 @@ class WorkflowByIdApi(Resource):
return {"message": "No valid fields to update"}, 400
workflow_service = WorkflowService()
workflow_ref = WorkflowRefService.create_app_workflow_ref(app_model, workflow_id)
# Create a session and manage the transaction
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
@ -1415,6 +1417,7 @@ class WorkflowByIdApi(Resource):
tenant_id=app_model.tenant_id,
account_id=current_user.id,
data=update_data,
workflow_ref=workflow_ref,
)
if not workflow:
@ -1434,12 +1437,16 @@ class WorkflowByIdApi(Resource):
Delete workflow
"""
workflow_service = WorkflowService()
workflow_ref = WorkflowRefService.create_app_workflow_ref(app_model, workflow_id)
# Create a session and manage the transaction
with sessionmaker(db.engine).begin() as session:
try:
workflow_service.delete_workflow(
session=session, workflow_id=workflow_id, tenant_id=app_model.tenant_id
session=session,
workflow_id=workflow_id,
tenant_id=app_model.tenant_id,
workflow_ref=workflow_ref,
)
except WorkflowInUseError as e:
abort(400, description=str(e))

View File

@ -58,8 +58,9 @@ from graphon.model_runtime.entities.model_entities import ModelType
from libs.helper import dump_response, escape_like_pattern
from libs.login import login_required
from models import Account
from models.dataset import ChildChunk, DocumentSegment
from models.dataset import Dataset, Document, DocumentSegment
from models.model import UploadFile
from services.dataset_ref_service import DatasetRefService, SegmentRef
from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
@ -162,6 +163,21 @@ register_response_schema_models(
)
def _get_segment_for_document(
dataset: Dataset, document: Document, segment_id: str
) -> tuple[SegmentRef, DocumentSegment]:
dataset_ref = DatasetRefService.create_dataset_ref(dataset)
document_ref = DatasetRefService.create_document_ref(dataset_ref, document)
if document_ref is None:
raise NotFound("Document not found.")
segment_ref = DatasetRefService.create_segment_ref(document_ref, segment_id)
segment = SegmentService.get_segment_by_ref(segment_ref)
if not segment:
raise NotFound("Segment not found.")
return segment_ref, segment
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
class DatasetDocumentSegmentListApi(Resource):
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
@ -465,6 +481,13 @@ class DatasetDocumentSegmentUpdateApi(Resource):
document = DocumentService.get_document(dataset_id_str, document_id_str)
if not document:
raise NotFound("Document not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user, db.session)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# check embedding model setting
try:
@ -481,22 +504,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment
segment_id_str = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user, db.session)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
_, segment = _get_segment_for_document(dataset, document, segment_id_str)
# validate args
payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
payload_dict = payload.model_dump(exclude_none=True)
@ -537,15 +546,6 @@ class DatasetDocumentSegmentUpdateApi(Resource):
document = DocumentService.get_document(dataset_id_str, document_id_str)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id_str = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
@ -553,6 +553,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user, db.session)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
segment_id_str = str(segment_id)
_, segment = _get_segment_for_document(dataset, document, segment_id_str)
SegmentService.delete_segment(segment, document, dataset)
return "", 204
@ -659,17 +661,12 @@ class ChildChunkAddApi(Resource):
document = DocumentService.get_document(dataset_id_str, document_id_str)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id_str = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user, db.session)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# check embedding model setting
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
try:
@ -686,10 +683,8 @@ class ChildChunkAddApi(Resource):
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
try:
DatasetService.check_dataset_permission(dataset, current_user, db.session)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
segment_id_str = str(segment_id)
_, segment = _get_segment_for_document(dataset, document, segment_id_str)
# validate args
try:
payload = ChildChunkCreatePayload.model_validate(console_ns.payload or {})
@ -719,15 +714,8 @@ class ChildChunkAddApi(Resource):
document = DocumentService.get_document(dataset_id_str, document_id_str)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id_str = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
_get_segment_for_document(dataset, document, segment_id_str)
args = query_params_from_request(ChildChunkListQuery, use_defaults_for_malformed_ints=True)
page = args.page
@ -776,15 +764,6 @@ class ChildChunkAddApi(Resource):
document = DocumentService.get_document(dataset_id_str, document_id_str)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id_str = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
@ -792,6 +771,8 @@ class ChildChunkAddApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user, db.session)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
segment_id_str = str(segment_id)
_, segment = _get_segment_for_document(dataset, document, segment_id_str)
# validate args
payload = ChildChunkBatchUpdatePayload.model_validate(console_ns.payload or {})
try:
@ -835,29 +816,6 @@ class ChildChunkUpdateApi(Resource):
document = DocumentService.get_document(dataset_id_str, document_id_str)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id_str = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id_str = str(child_chunk_id)
child_chunk = db.session.scalar(
select(ChildChunk)
.where(
ChildChunk.id == child_chunk_id_str,
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id_str,
)
.limit(1)
)
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
@ -865,6 +823,12 @@ class ChildChunkUpdateApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user, db.session)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
segment_id_str = str(segment_id)
segment_ref, _ = _get_segment_for_document(dataset, document, segment_id_str)
child_chunk_id_str = str(child_chunk_id)
child_chunk = SegmentService.get_child_chunk_by_segment_ref(child_chunk_id_str, segment_ref)
if not child_chunk:
raise NotFound("Child chunk not found.")
try:
SegmentService.delete_child_chunk(child_chunk, dataset)
except ChildChunkDeleteIndexServiceError as e:
@ -903,29 +867,6 @@ class ChildChunkUpdateApi(Resource):
document = DocumentService.get_document(dataset_id_str, document_id_str)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id_str = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id_str = str(child_chunk_id)
child_chunk = db.session.scalar(
select(ChildChunk)
.where(
ChildChunk.id == child_chunk_id_str,
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id_str,
)
.limit(1)
)
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
@ -933,6 +874,12 @@ class ChildChunkUpdateApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user, db.session)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
segment_id_str = str(segment_id)
segment_ref, segment = _get_segment_for_document(dataset, document, segment_id_str)
child_chunk_id_str = str(child_chunk_id)
child_chunk = SegmentService.get_child_chunk_by_segment_ref(child_chunk_id_str, segment_ref)
if not child_chunk:
raise NotFound("Child chunk not found.")
# validate args
try:
payload = ChildChunkUpdatePayload.model_validate(console_ns.payload or {})

View File

@ -64,6 +64,7 @@ from services.rag_pipeline.pipeline_generate_service import PipelineGenerateServ
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService
from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTransformService
from services.workflow_ref_service import WorkflowRefService
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
logger = logging.getLogger(__name__)
@ -738,6 +739,7 @@ class RagPipelineByIdApi(Resource):
return {"message": "No valid fields to update"}, 400
rag_pipeline_service = RagPipelineService()
workflow_ref = WorkflowRefService.create_pipeline_workflow_ref(pipeline, workflow_id)
# Create a session and manage the transaction
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
@ -747,6 +749,7 @@ class RagPipelineByIdApi(Resource):
tenant_id=pipeline.tenant_id,
account_id=current_user.id,
data=update_data,
workflow_ref=workflow_ref,
)
if not workflow:
@ -769,6 +772,7 @@ class RagPipelineByIdApi(Resource):
abort(400, description=f"Cannot delete workflow that is currently in use by pipeline '{pipeline.id}'")
workflow_service = WorkflowService()
workflow_ref = WorkflowRefService.create_pipeline_workflow_ref(pipeline, workflow_id)
with sessionmaker(db.engine).begin() as session:
try:
@ -776,6 +780,7 @@ class RagPipelineByIdApi(Resource):
session=session,
workflow_id=workflow_id,
tenant_id=pipeline.tenant_id,
workflow_ref=workflow_ref,
)
except WorkflowInUseError as e:
abort(400, description=str(e))

View File

@ -425,6 +425,7 @@ class TrialChatTextApi(TrialAppResource):
text=text,
voice=voice,
message_id=message_id,
message_account_id=current_user.id,
)
RecommendedAppService.add_trial_app_record(db.session, app_id, user_id)
return response

View File

@ -117,7 +117,8 @@ def _enforce_snippet_tag_rbac_by_tag_id(tag_id: str) -> None:
if not dify_config.RBAC_ENABLED:
return
tag_type = db.session.scalar(select(Tag.type).where(Tag.id == tag_id).limit(1))
_, current_tenant_id = current_account_with_tenant()
tag_type = db.session.scalar(select(Tag.type).where(Tag.id == tag_id, Tag.tenant_id == current_tenant_id).limit(1))
_enforce_snippet_tag_rbac_if_needed(tag_type)

View File

@ -184,6 +184,7 @@ class TextApi(Resource):
voice=voice,
end_user=end_user.external_user_id,
message_id=message_id,
message_end_user_id=end_user.id,
)
return response

View File

@ -939,9 +939,11 @@ class DatasetTagsApi(DatasetApiResource):
payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
tag_id = payload.tag_id
tag = TagService.update_tags(UpdateTagServicePayload(name=payload.name), tag_id, db.session)
tag = TagService.update_tags(
UpdateTagServicePayload(name=payload.name), tag_id, db.session, tag_type=TagType.KNOWLEDGE
)
binding_count = TagService.get_tag_binding_count(tag_id, db.session)
binding_count = TagService.get_tag_binding_count(tag_id, db.session, tag_type=TagType.KNOWLEDGE)
response = dump_response(
KnowledgeTagResponse,
@ -971,7 +973,7 @@ class DatasetTagsApi(DatasetApiResource):
def delete(self, _):
"""Delete a knowledge type tag."""
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag(payload.tag_id, db.session)
TagService.delete_tag(payload.tag_id, db.session, tag_type=TagType.KNOWLEDGE)
return "", 204

View File

@ -37,7 +37,8 @@ from fields.segment_fields import (
from graphon.model_runtime.entities.model_entities import ModelType
from libs.helper import dump_response
from libs.login import current_account_with_tenant
from models.dataset import Dataset, DocumentSegment
from models.dataset import Dataset, Document, DocumentSegment
from services.dataset_ref_service import DatasetRefService, SegmentRef
from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
@ -127,6 +128,21 @@ register_response_schema_models(
)
def _get_segment_for_document(
dataset: Dataset, document: Document, segment_id: str
) -> tuple[SegmentRef, DocumentSegment]:
dataset_ref = DatasetRefService.create_dataset_ref(dataset)
document_ref = DatasetRefService.create_document_ref(dataset_ref, document)
if document_ref is None:
raise NotFound("Document not found.")
segment_ref = DatasetRefService.create_segment_ref(document_ref, segment_id)
segment = SegmentService.get_segment_by_ref(segment_ref)
if not segment:
raise NotFound("Segment not found.")
return segment_ref, segment
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
class SegmentApi(DatasetApiResource):
"""Resource for segments."""
@ -337,7 +353,7 @@ class DatasetSegmentApi(DatasetApiResource):
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
_, current_tenant_id = current_account_with_tenant()
current_account_with_tenant()
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
@ -353,10 +369,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not document:
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
_, segment = _get_segment_for_document(dataset, document, segment_id_str)
SegmentService.delete_segment(segment, document, dataset)
return "", 204
@ -415,10 +428,7 @@ class DatasetSegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
_, segment = _get_segment_for_document(dataset, document, segment_id_str)
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
@ -457,7 +467,7 @@ class DatasetSegmentApi(DatasetApiResource):
service_api_ns.models[SegmentDetailResponse.__name__],
)
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
_, current_tenant_id = current_account_with_tenant()
current_account_with_tenant()
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
@ -473,10 +483,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not document:
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
_, segment = _get_segment_for_document(dataset, document, segment_id_str)
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id_str)
response = {
@ -538,10 +545,7 @@ class ChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
_, segment = _get_segment_for_document(dataset, document, segment_id_str)
# check embedding model setting
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
@ -595,7 +599,7 @@ class ChildChunkApi(DatasetApiResource):
service_api_ns.models[ChildChunkListResponse.__name__],
)
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
_, current_tenant_id = current_account_with_tenant()
current_account_with_tenant()
"""Get child chunks."""
dataset_id_str = str(dataset_id)
# check dataset
@ -612,10 +616,7 @@ class ChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
_get_segment_for_document(dataset, document, segment_id_str)
args = query_params_from_request(ChildChunkListQuery, use_defaults_for_malformed_ints=True)
@ -665,7 +666,7 @@ class DatasetChildChunkApi(DatasetApiResource):
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
_, current_tenant_id = current_account_with_tenant()
current_account_with_tenant()
"""Delete child chunk."""
dataset_id_str = str(dataset_id)
# check dataset
@ -682,27 +683,14 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
# validate segment belongs to the specified document
if segment.document_id != document_id_str:
raise NotFound("Document not found.")
segment_ref, _ = _get_segment_for_document(dataset, document, segment_id_str)
child_chunk_id_str = str(child_chunk_id)
# check child chunk
child_chunk = SegmentService.get_child_chunk_by_id(
child_chunk_id=child_chunk_id_str, tenant_id=current_tenant_id
)
child_chunk = SegmentService.get_child_chunk_by_segment_ref(child_chunk_id_str, segment_ref)
if not child_chunk:
raise NotFound("Child chunk not found.")
# validate child chunk belongs to the specified segment
if child_chunk.segment_id != segment.id:
raise NotFound("Child chunk not found.")
try:
SegmentService.delete_child_chunk(child_chunk, dataset)
except ChildChunkDeleteIndexServiceError as e:
@ -739,7 +727,7 @@ class DatasetChildChunkApi(DatasetApiResource):
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
_, current_tenant_id = current_account_with_tenant()
current_account_with_tenant()
"""Update child chunk."""
dataset_id_str = str(dataset_id)
# check dataset
@ -756,27 +744,14 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# get segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
# validate segment belongs to the specified document
if segment.document_id != document_id_str:
raise NotFound("Segment not found.")
segment_ref, segment = _get_segment_for_document(dataset, document, segment_id_str)
child_chunk_id_str = str(child_chunk_id)
# get child chunk
child_chunk = SegmentService.get_child_chunk_by_id(
child_chunk_id=child_chunk_id_str, tenant_id=current_tenant_id
)
child_chunk = SegmentService.get_child_chunk_by_segment_ref(child_chunk_id_str, segment_ref)
if not child_chunk:
raise NotFound("Child chunk not found.")
# validate child chunk belongs to the specified segment
if child_chunk.segment_id != segment.id:
raise NotFound("Child chunk not found.")
# validate args
payload = ChildChunkUpdatePayload.model_validate(service_api_ns.payload or {})

View File

@ -137,6 +137,7 @@ class TextApi(WebApiResource):
voice=voice,
end_user=end_user.external_user_id,
message_id=message_id,
message_end_user_id=end_user.id,
)
return response

View File

@ -498,7 +498,11 @@ class LLMGenerator:
ideal_output: str | None,
):
last_run: Message | None = db.session.scalar(
select(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).limit(1)
select(Message)
.join(App, App.id == Message.app_id)
.where(Message.app_id == flow_id, App.tenant_id == tenant_id)
.order_by(Message.created_at.desc())
.limit(1)
)
if not last_run:
return LLMGenerator.__instruction_modify_common(
@ -540,7 +544,7 @@ class LLMGenerator:
):
session = db.session()
app: App | None = session.scalar(select(App).where(App.id == flow_id).limit(1))
app: App | None = session.scalar(select(App).where(App.id == flow_id, App.tenant_id == tenant_id).limit(1))
if not app:
raise ValueError("App not found.")
workflow = workflow_service.get_draft_workflow(app_model=app)

View File

@ -45,52 +45,34 @@ def upgrade():
# PostgreSQL 18's `uuidv7` function. This capability is rarely needed in practice, as IDs can be
# generated and controlled within the application layer.
conn = op.get_bind()
if _is_pg(conn):
# PostgreSQL: Create uuidv7 functions.
# PostgreSQL 18 ships a native pg_catalog.uuidv7(), so only create our own
# implementation when the server does not already provide one. Otherwise the
# CREATE FUNCTION below and the unqualified COMMENT statement collide with the
# built-in and the migration fails.
#
# The existence check is done server-side via a DO block rather than
# conn.execute().scalar() because the latter returns None in offline
# migration mode (no real database connection), causing an AttributeError.
# PostgreSQL: Create uuidv7 functions
op.execute(sa.text(r"""
DO $do$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM pg_proc p
JOIN pg_namespace n ON p.pronamespace = n.oid
WHERE p.proname = 'uuidv7' AND n.nspname = 'pg_catalog'
) THEN
/* Main function to generate a uuidv7 value with millisecond precision */
CREATE FUNCTION public.uuidv7() RETURNS uuid
AS
$func$
-- Replace the first 48 bits of a uuidv4 with the current
-- number of milliseconds since 1970-01-01 UTC
-- and set the "ver" field to 7 by setting additional bits
SELECT encode(
/* Main function to generate a uuidv7 value with millisecond precision */
CREATE FUNCTION uuidv7() RETURNS uuid
AS
$$
-- Replace the first 48 bits of a uuidv4 with the current
-- number of milliseconds since 1970-01-01 UTC
-- and set the "ver" field to 7 by setting additional bits
SELECT encode(
set_bit(
set_bit(
set_bit(
overlay(uuid_send(gen_random_uuid()) placing
substring(int8send((extract(epoch from clock_timestamp()) * 1000)::bigint) from
3)
from 1 for 6),
52, 1),
53, 1), 'hex')::uuid;
$func$ LANGUAGE SQL VOLATILE PARALLEL SAFE;
overlay(uuid_send(gen_random_uuid()) placing
substring(int8send((extract(epoch from clock_timestamp()) * 1000)::bigint) from
3)
from 1 for 6),
52, 1),
53, 1), 'hex')::uuid;
$$ LANGUAGE SQL VOLATILE PARALLEL SAFE;
COMMENT ON FUNCTION public.uuidv7 IS
'Generate a uuid-v7 value with a 48-bit timestamp (millisecond precision) and 74 bits of randomness';
END IF;
END
$do$;
COMMENT ON FUNCTION uuidv7 IS
'Generate a uuid-v7 value with a 48-bit timestamp (millisecond precision) and 74 bits of randomness';
"""))
op.execute(sa.text(r"""
CREATE FUNCTION public.uuidv7_boundary(timestamptz) RETURNS uuid
CREATE FUNCTION uuidv7_boundary(timestamptz) RETURNS uuid
AS
$$
/* uuid fields: version=0b0111, variant=0b10 */
@ -101,7 +83,7 @@ SELECT encode(
'hex')::uuid;
$$ LANGUAGE SQL STABLE STRICT PARALLEL SAFE;
COMMENT ON FUNCTION public.uuidv7_boundary(timestamptz) IS
COMMENT ON FUNCTION uuidv7_boundary(timestamptz) IS
'Generate a non-random uuidv7 with the given timestamp (first 48 bits) and all random bits to 0. As the smallest possible uuidv7 for that timestamp, it may be used as a boundary for partitions.';
"""
))
@ -113,10 +95,7 @@ def downgrade():
conn = op.get_bind()
if _is_pg(conn):
# IF EXISTS keeps the downgrade a no-op on PostgreSQL 18, where the native
# pg_catalog.uuidv7() was kept and no public.uuidv7() was created. Scoping the
# drop to the public schema avoids touching the built-in.
op.execute(sa.text("DROP FUNCTION IF EXISTS public.uuidv7()"))
op.execute(sa.text("DROP FUNCTION IF EXISTS public.uuidv7_boundary(timestamptz)"))
op.execute(sa.text("DROP FUNCTION uuidv7"))
op.execute(sa.text("DROP FUNCTION uuidv7_boundary"))
else:
pass

View File

@ -14,6 +14,7 @@ from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant
from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation
from services.app_ref_service import AnnotationRef, AppRefService
from services.feature_service import FeatureService
from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
@ -88,6 +89,17 @@ class UpdateAnnotationSettingArgs(TypedDict):
class AppAnnotationService:
@staticmethod
def _get_annotation_by_ref(annotation_ref: AnnotationRef, session: scoped_session) -> MessageAnnotation | None:
return session.scalar(
select(MessageAnnotation)
.where(
MessageAnnotation.id == annotation_ref.annotation_id,
MessageAnnotation.app_id == annotation_ref.app_id,
)
.limit(1)
)
@classmethod
def up_insert_app_annotation_from_message(cls, args: UpsertAnnotationArgs, app_id: str) -> MessageAnnotation:
# get app info
@ -313,7 +325,9 @@ class AppAnnotationService:
if not app:
raise NotFound("App not found")
annotation = session.get(MessageAnnotation, annotation_id)
app_ref = AppRefService.create_app_ref(app)
annotation_ref = AppRefService.create_annotation_ref(app_ref, annotation_id)
annotation = cls._get_annotation_by_ref(annotation_ref, session)
if not annotation:
raise NotFound("Annotation not found")
@ -357,7 +371,9 @@ class AppAnnotationService:
if not app:
raise NotFound("App not found")
annotation = session.get(MessageAnnotation, annotation_id)
app_ref = AppRefService.create_app_ref(app)
annotation_ref = AppRefService.create_annotation_ref(app_ref, annotation_id)
annotation = cls._get_annotation_by_ref(annotation_ref, session)
if not annotation:
raise NotFound("Annotation not found")
@ -393,11 +409,13 @@ class AppAnnotationService:
if not app:
raise NotFound("App not found")
app_ref = AppRefService.create_app_ref(app)
# Fetch annotations and their settings in a single query
annotations_to_delete = db.session.execute(
select(MessageAnnotation, AppAnnotationSetting)
.outerjoin(AppAnnotationSetting, MessageAnnotation.app_id == AppAnnotationSetting.app_id)
.where(MessageAnnotation.id.in_(annotation_ids))
.where(MessageAnnotation.id.in_(annotation_ids), MessageAnnotation.app_id == app_ref.app_id)
).all()
if not annotations_to_delete:
@ -420,7 +438,10 @@ class AppAnnotationService:
# Step 4: Bulk delete annotations in a single query
delete_result = db.session.execute(
delete(MessageAnnotation).where(MessageAnnotation.id.in_(annotation_ids_to_delete))
delete(MessageAnnotation).where(
MessageAnnotation.id.in_(annotation_ids_to_delete),
MessageAnnotation.app_id == app_ref.app_id,
)
)
deleted_count = getattr(delete_result, "rowcount", 0)
@ -572,7 +593,9 @@ class AppAnnotationService:
if not app:
raise NotFound("App not found")
annotation = db.session.get(MessageAnnotation, annotation_id)
app_ref = AppRefService.create_app_ref(app)
annotation_ref = AppRefService.create_annotation_ref(app_ref, annotation_id)
annotation = cls._get_annotation_by_ref(annotation_ref, db.session)
if not annotation:
raise NotFound("Annotation not found")

View File

@ -0,0 +1,100 @@
"""Typed resource references for app ownership chains."""
from typing import NamedTuple
from models.model import App
_APP_REF_CTOR_TOKEN = object()
class _AppRefBase(NamedTuple):
tenant_id: str
app_id: str
ctor_token: object
class AppRef(_AppRefBase):
"""Tenant-scoped app reference with token-gated construction."""
__slots__ = ()
def __new__(cls, tenant_id: str, app_id: str, ctor_token: object) -> "AppRef":
if ctor_token is not _APP_REF_CTOR_TOKEN:
raise ValueError("AppRef must be created by AppRefService.")
return super().__new__(cls, tenant_id, app_id, ctor_token)
def __repr__(self) -> str:
return f"AppRef(tenant_id={self.tenant_id!r}, app_id={self.app_id!r})"
class MessageRef(NamedTuple):
"""Message reference bound to a trusted app reference."""
app: AppRef
message_id: str
end_user_id: str | None = None
account_id: str | None = None
@property
def tenant_id(self) -> str:
return self.app.tenant_id
@property
def app_id(self) -> str:
return self.app.app_id
class AnnotationRef(NamedTuple):
"""Annotation reference bound to a trusted app reference."""
app: AppRef
annotation_id: str
@property
def tenant_id(self) -> str:
return self.app.tenant_id
@property
def app_id(self) -> str:
return self.app.app_id
class AppMCPServerRef(NamedTuple):
"""MCP server reference bound to a trusted app reference."""
app: AppRef
server_id: str
@property
def tenant_id(self) -> str:
return self.app.tenant_id
@property
def app_id(self) -> str:
return self.app.app_id
class AppRefService:
"""Factory for trusted app and child resource refs."""
@staticmethod
def create_app_ref(app: App) -> AppRef:
return AppRef(app.tenant_id, app.id, _APP_REF_CTOR_TOKEN)
@staticmethod
def create_message_ref(
app_ref: AppRef,
message_id: str,
*,
end_user_id: str | None = None,
account_id: str | None = None,
) -> MessageRef:
return MessageRef(app=app_ref, message_id=message_id, end_user_id=end_user_id, account_id=account_id)
@staticmethod
def create_annotation_ref(app_ref: AppRef, annotation_id: str) -> AnnotationRef:
return AnnotationRef(app=app_ref, annotation_id=annotation_id)
@staticmethod
def create_mcp_server_ref(app_ref: AppRef, server_id: str) -> AppMCPServerRef:
return AppMCPServerRef(app=app_ref, server_id=server_id)

View File

@ -5,6 +5,7 @@ from collections.abc import Generator
from typing import cast
from flask import Response, stream_with_context
from sqlalchemy import select
from sqlalchemy.orm import Session, scoped_session
from werkzeug.datastructures import FileStorage
@ -13,6 +14,7 @@ from core.model_manager import ModelManager
from graphon.model_runtime.entities.model_entities import ModelType
from models.enums import MessageStatus
from models.model import App, AppMode, Message
from services.app_ref_service import AppRefService, MessageRef
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
@ -29,6 +31,15 @@ logger = logging.getLogger(__name__)
class AudioService:
@staticmethod
def _get_message_by_ref(session: Session | scoped_session, message_ref: MessageRef) -> Message | None:
stmt = select(Message).where(Message.id == message_ref.message_id, Message.app_id == message_ref.app_id)
if message_ref.end_user_id is not None:
stmt = stmt.where(Message.from_end_user_id == message_ref.end_user_id)
if message_ref.account_id is not None:
stmt = stmt.where(Message.from_account_id == message_ref.account_id)
return session.scalar(stmt.limit(1))
@classmethod
def transcript_asr(cls, app_model: App, file: FileStorage | None, end_user: str | None = None):
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
@ -83,6 +94,8 @@ class AudioService:
voice: str | None = None,
end_user: str | None = None,
message_id: str | None = None,
message_end_user_id: str | None = None,
message_account_id: str | None = None,
is_draft: bool = False,
):
def invoke_tts(text_content: str, app_model: App, voice: str | None = None, is_draft: bool = False):
@ -134,7 +147,14 @@ class AudioService:
uuid.UUID(message_id)
except ValueError:
return None
message = session.get(Message, message_id)
app_ref = AppRefService.create_app_ref(app_model)
message_ref = AppRefService.create_message_ref(
app_ref,
message_id,
end_user_id=message_end_user_id,
account_id=message_account_id,
)
message = cls._get_message_by_ref(session, message_ref)
if message is None:
return None
if message.answer == "" and message.status in {MessageStatus.NORMAL, MessageStatus.PAUSED}:

View File

@ -0,0 +1,85 @@
"""Typed resource references for dataset ownership chains.
Controllers and other trust-boundary code should build these refs after
resolving and authorizing the outer resource. Downstream service queries can
then require the full tenant -> dataset -> document -> segment chain instead of
receiving loosely related raw ids.
"""
from typing import NamedTuple
from models.dataset import Dataset, Document
_DATASET_REF_CTOR_TOKEN = object()
class _DatasetRefBase(NamedTuple):
tenant_id: str
dataset_id: str
ctor_token: object
class DatasetRef(_DatasetRefBase):
"""Tenant-scoped dataset reference with token-gated construction."""
__slots__ = ()
def __new__(cls, tenant_id: str, dataset_id: str, ctor_token: object) -> "DatasetRef":
if ctor_token is not _DATASET_REF_CTOR_TOKEN:
raise ValueError("DatasetRef must be created by DatasetRefService.")
return super().__new__(cls, tenant_id, dataset_id, ctor_token)
def __repr__(self) -> str:
return f"DatasetRef(tenant_id={self.tenant_id!r}, dataset_id={self.dataset_id!r})"
class DocumentRef(NamedTuple):
"""Document reference bound to a trusted dataset reference."""
dataset: DatasetRef
document_id: str
@property
def tenant_id(self) -> str:
return self.dataset.tenant_id
@property
def dataset_id(self) -> str:
return self.dataset.dataset_id
class SegmentRef(NamedTuple):
"""Segment reference bound to a trusted document reference."""
document: DocumentRef
segment_id: str
@property
def tenant_id(self) -> str:
return self.document.tenant_id
@property
def dataset_id(self) -> str:
return self.document.dataset_id
@property
def document_id(self) -> str:
return self.document.document_id
class DatasetRefService:
"""Factory for trusted dataset, document, and segment refs."""
@staticmethod
def create_dataset_ref(dataset: Dataset) -> DatasetRef:
return DatasetRef(dataset.tenant_id, dataset.id, _DATASET_REF_CTOR_TOKEN)
@staticmethod
def create_document_ref(dataset_ref: DatasetRef, document: Document) -> DocumentRef | None:
if document.tenant_id != dataset_ref.tenant_id or document.dataset_id != dataset_ref.dataset_id:
return None
return DocumentRef(dataset=dataset_ref, document_id=document.id)
@staticmethod
def create_segment_ref(document_ref: DocumentRef, segment_id: str) -> SegmentRef:
return SegmentRef(document=document_ref, segment_id=segment_id)

View File

@ -65,6 +65,7 @@ from models.model import UploadFile
from models.provider_ids import ModelProviderID
from models.source import DataSourceOauthBinding
from models.workflow import Workflow
from services.dataset_ref_service import DatasetRefService, SegmentRef
from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy
from services.document_indexing_proxy.duplicate_document_indexing_task_proxy import DuplicateDocumentIndexingTaskProxy
from services.enterprise import rbac_service as enterprise_rbac_service
@ -1926,7 +1927,15 @@ class DocumentService:
# Check if document_ids is not empty to avoid WHERE false condition
if not document_ids or len(document_ids) == 0:
return
documents = db.session.scalars(select(Document).where(Document.id.in_(document_ids))).all()
dataset_ref = DatasetRefService.create_dataset_ref(dataset)
documents = db.session.scalars(
select(Document).where(
Document.id.in_(document_ids),
Document.tenant_id == dataset_ref.tenant_id,
Document.dataset_id == dataset_ref.dataset_id,
)
).all()
deleted_document_ids = [document.id for document in documents]
file_ids = [
document.data_source_info_dict.get("upload_file_id", "")
for document in documents
@ -1941,8 +1950,8 @@ class DocumentService:
# Dispatch cleanup task after commit to avoid lock contention
# Task cleans up segments, files, and vector indexes
if dataset.doc_form is not None:
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
if deleted_document_ids and dataset.doc_form is not None:
batch_clean_document_task.delay(deleted_document_ids, dataset.id, dataset.doc_form, file_ids)
@staticmethod
def rename_document(dataset_id: str, document_id: str, name: str) -> Document:
@ -4028,6 +4037,22 @@ class SegmentService:
)
return result if isinstance(result, ChildChunk) else None
@classmethod
def get_child_chunk_by_segment_ref(cls, child_chunk_id: str, segment_ref: SegmentRef) -> ChildChunk | None:
"""Get a child chunk through the full tenant/dataset/document/segment chain."""
result = db.session.scalar(
select(ChildChunk)
.where(
ChildChunk.id == child_chunk_id,
ChildChunk.tenant_id == segment_ref.tenant_id,
ChildChunk.dataset_id == segment_ref.dataset_id,
ChildChunk.document_id == segment_ref.document_id,
ChildChunk.segment_id == segment_ref.segment_id,
)
.limit(1)
)
return result if isinstance(result, ChildChunk) else None
@classmethod
def get_segments(
cls,
@ -4066,6 +4091,21 @@ class SegmentService:
)
return result if isinstance(result, DocumentSegment) else None
@classmethod
def get_segment_by_ref(cls, segment_ref: SegmentRef) -> DocumentSegment | None:
"""Get a segment through the full tenant/dataset/document ownership chain."""
result = db.session.scalar(
select(DocumentSegment)
.where(
DocumentSegment.id == segment_ref.segment_id,
DocumentSegment.tenant_id == segment_ref.tenant_id,
DocumentSegment.dataset_id == segment_ref.dataset_id,
DocumentSegment.document_id == segment_ref.document_id,
)
.limit(1)
)
return result if isinstance(result, DocumentSegment) else None
@classmethod
def get_segments_by_document_and_dataset(
cls,

View File

@ -82,6 +82,7 @@ from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError,
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader
from services.workflow_ref_service import WorkflowRef
from services.workflow_restore import apply_published_workflow_snapshot_to_draft
logger = logging.getLogger(__name__)
@ -984,8 +985,21 @@ class RagPipelineService:
if invoke_from:
if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE:
document_id = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID)
if document_id:
document = db.session.get(Document, document_id.value)
dataset_id = get_system_segment(variable_pool, SystemVariableKey.DATASET_ID)
pipeline_id = get_system_segment(variable_pool, SystemVariableKey.APP_ID)
if document_id and dataset_id and pipeline_id:
document = db.session.scalar(
select(Document)
.join(Dataset, Dataset.id == Document.dataset_id)
.where(
Document.id == document_id.value,
Document.tenant_id == tenant_id,
Document.dataset_id == dataset_id.value,
Dataset.tenant_id == tenant_id,
Dataset.pipeline_id == pipeline_id.value,
)
.limit(1)
)
if document:
document.indexing_status = IndexingStatus.ERROR
document.error = error
@ -995,7 +1009,14 @@ class RagPipelineService:
return workflow_node_execution
def update_workflow(
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict[str, Any]
self,
*,
session: Session,
workflow_id: str,
tenant_id: str,
account_id: str,
data: dict[str, Any],
workflow_ref: WorkflowRef | None = None,
) -> Workflow | None:
"""
Update workflow attributes
@ -1005,9 +1026,14 @@ class RagPipelineService:
:param tenant_id: Tenant ID
:param account_id: Account ID (for permission check)
:param data: Dictionary containing fields to update
:param workflow_ref: Optional trusted owner-bound workflow reference
:return: Updated workflow or None if not found
"""
stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
lookup_workflow_id = workflow_ref.workflow_id if workflow_ref else workflow_id
lookup_tenant_id = workflow_ref.tenant_id if workflow_ref else tenant_id
stmt = select(Workflow).where(Workflow.id == lookup_workflow_id, Workflow.tenant_id == lookup_tenant_id)
if workflow_ref is not None:
stmt = stmt.where(Workflow.app_id == workflow_ref.app_id)
workflow = session.scalar(stmt)
if not workflow:

View File

@ -147,8 +147,14 @@ class TagService:
return tag
@staticmethod
def update_tags(payload: UpdateTagPayload, tag_id: str, session: scoped_session) -> Tag:
tag = session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
def update_tags(
payload: UpdateTagPayload, tag_id: str, session: scoped_session, *, tag_type: TagType | None = None
) -> Tag:
current_tenant_id = current_user.current_tenant_id
stmt = select(Tag).where(Tag.id == tag_id, Tag.tenant_id == current_tenant_id)
if tag_type is not None:
stmt = stmt.where(Tag.type == tag_type)
tag = session.scalar(stmt.limit(1))
if not tag:
raise NotFound("Tag not found")
if payload.name != tag.name:
@ -169,18 +175,32 @@ class TagService:
return tag
@staticmethod
def get_tag_binding_count(tag_id: str, session: scoped_session) -> int:
count = session.scalar(select(func.count(TagBinding.id)).where(TagBinding.tag_id == tag_id)) or 0
def get_tag_binding_count(tag_id: str, session: scoped_session, *, tag_type: TagType | None = None) -> int:
current_tenant_id = current_user.current_tenant_id
stmt = (
select(func.count(TagBinding.id))
.join(Tag, Tag.id == TagBinding.tag_id)
.where(TagBinding.tag_id == tag_id, Tag.tenant_id == current_tenant_id)
)
if tag_type is not None:
stmt = stmt.where(Tag.type == tag_type)
count = session.scalar(stmt) or 0
return count
@staticmethod
def delete_tag(tag_id: str, session: scoped_session):
tag = session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
def delete_tag(tag_id: str, session: scoped_session, *, tag_type: TagType | None = None):
current_tenant_id = current_user.current_tenant_id
stmt = select(Tag).where(Tag.id == tag_id, Tag.tenant_id == current_tenant_id)
if tag_type is not None:
stmt = stmt.where(Tag.type == tag_type)
tag = session.scalar(stmt.limit(1))
if not tag:
raise NotFound("Tag not found")
session.delete(tag)
# delete tag binding
tag_bindings = session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all()
tag_bindings = session.scalars(
select(TagBinding).where(TagBinding.tag_id == tag_id, TagBinding.tenant_id == current_tenant_id)
).all()
if tag_bindings:
for tag_binding in tag_bindings:
session.delete(tag_binding)

View File

@ -0,0 +1,57 @@
"""Typed resource references for workflow ownership chains."""
from typing import NamedTuple
from models.dataset import Pipeline
from models.model import App
_WORKFLOW_OWNER_REF_CTOR_TOKEN = object()
class _WorkflowOwnerRefBase(NamedTuple):
tenant_id: str
owner_id: str
ctor_token: object
class WorkflowOwnerRef(_WorkflowOwnerRefBase):
"""Tenant-scoped workflow owner reference with token-gated construction."""
__slots__ = ()
def __new__(cls, tenant_id: str, owner_id: str, ctor_token: object) -> "WorkflowOwnerRef":
if ctor_token is not _WORKFLOW_OWNER_REF_CTOR_TOKEN:
raise ValueError("WorkflowOwnerRef must be created by WorkflowRefService.")
return super().__new__(cls, tenant_id, owner_id, ctor_token)
def __repr__(self) -> str:
return f"WorkflowOwnerRef(tenant_id={self.tenant_id!r}, owner_id={self.owner_id!r})"
class WorkflowRef(NamedTuple):
"""Workflow reference bound to a trusted owner reference."""
owner: WorkflowOwnerRef
workflow_id: str
@property
def tenant_id(self) -> str:
return self.owner.tenant_id
@property
def app_id(self) -> str:
return self.owner.owner_id
class WorkflowRefService:
"""Factory for trusted app and RAG pipeline workflow refs."""
@staticmethod
def create_app_workflow_ref(app: App, workflow_id: str) -> WorkflowRef:
owner = WorkflowOwnerRef(app.tenant_id, app.id, _WORKFLOW_OWNER_REF_CTOR_TOKEN)
return WorkflowRef(owner=owner, workflow_id=workflow_id)
@staticmethod
def create_pipeline_workflow_ref(pipeline: Pipeline, workflow_id: str) -> WorkflowRef:
owner = WorkflowOwnerRef(pipeline.tenant_id, pipeline.id, _WORKFLOW_OWNER_REF_CTOR_TOKEN)
return WorkflowRef(owner=owner, workflow_id=workflow_id)

View File

@ -84,6 +84,7 @@ from services.errors.app import (
)
from services.human_input_service import HumanInputService
from services.workflow.workflow_converter import WorkflowConverter
from services.workflow_ref_service import WorkflowRef
from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
from .human_input_delivery_test_service import (
@ -1593,7 +1594,14 @@ class WorkflowService:
raise ValueError(f"Invalid HumanInput node data: {str(e)}")
def update_workflow(
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict[str, Any]
self,
*,
session: Session,
workflow_id: str,
tenant_id: str,
account_id: str,
data: dict[str, Any],
workflow_ref: WorkflowRef | None = None,
) -> Workflow | None:
"""
Update workflow attributes
@ -1603,9 +1611,14 @@ class WorkflowService:
:param tenant_id: Tenant ID
:param account_id: Account ID (for permission check)
:param data: Dictionary containing fields to update
:param workflow_ref: Optional trusted owner-bound workflow reference
:return: Updated workflow or None if not found
"""
stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
lookup_workflow_id = workflow_ref.workflow_id if workflow_ref else workflow_id
lookup_tenant_id = workflow_ref.tenant_id if workflow_ref else tenant_id
stmt = select(Workflow).where(Workflow.id == lookup_workflow_id, Workflow.tenant_id == lookup_tenant_id)
if workflow_ref is not None:
stmt = stmt.where(Workflow.app_id == workflow_ref.app_id)
workflow = session.scalar(stmt)
if not workflow:
@ -1622,30 +1635,37 @@ class WorkflowService:
return workflow
def delete_workflow(self, *, session: Session, workflow_id: str, tenant_id: str) -> bool:
def delete_workflow(
self, *, session: Session, workflow_id: str, tenant_id: str, workflow_ref: WorkflowRef | None = None
) -> bool:
"""
Delete a workflow
:param session: SQLAlchemy database session
:param workflow_id: Workflow ID
:param tenant_id: Tenant ID
:param workflow_ref: Optional trusted owner-bound workflow reference
:return: True if successful
:raises: ValueError if workflow not found
:raises: WorkflowInUseError if workflow is in use
:raises: DraftWorkflowDeletionError if workflow is a draft version
"""
stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
lookup_workflow_id = workflow_ref.workflow_id if workflow_ref else workflow_id
lookup_tenant_id = workflow_ref.tenant_id if workflow_ref else tenant_id
stmt = select(Workflow).where(Workflow.id == lookup_workflow_id, Workflow.tenant_id == lookup_tenant_id)
if workflow_ref is not None:
stmt = stmt.where(Workflow.app_id == workflow_ref.app_id)
workflow = session.scalar(stmt)
if not workflow:
raise ValueError(f"Workflow with ID {workflow_id} not found")
raise ValueError(f"Workflow with ID {lookup_workflow_id} not found")
# Check if workflow is a draft version
if workflow.version == Workflow.VERSION_DRAFT:
raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")
# Check if this workflow is currently referenced by an app
app_stmt = select(App).where(App.workflow_id == workflow_id)
app_stmt = select(App).where(App.workflow_id == lookup_workflow_id)
app = session.scalar(app_stmt)
if app:
# Cannot delete a workflow that's currently in use by an app

View File

@ -70,7 +70,7 @@ def test_instruction_generate_app_not_found(app: Flask, monkeypatch: pytest.Monk
method = unwrap(api.post)
session = MagicMock()
session.get.return_value = None
session.scalar.return_value = None
with app.test_request_context(
"/console/api/instruction-generate",
@ -86,7 +86,13 @@ def test_instruction_generate_app_not_found(app: Flask, monkeypatch: pytest.Monk
assert status == 400
assert response["error"] == "app app-1 not found"
session.get.assert_called_once_with(generator_module.App, "app-1")
stmt = session.scalar.call_args.args[0]
compiled = stmt.compile()
statement = str(compiled)
assert "apps.id" in statement
assert "apps.tenant_id" in statement
assert "app-1" in compiled.params.values()
assert "t1" in compiled.params.values()
def test_instruction_generate_workflow_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
@ -94,7 +100,7 @@ def test_instruction_generate_workflow_not_found(app: Flask, monkeypatch: pytest
method = unwrap(api.post)
app_model = SimpleNamespace(id="app-1")
session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model)
session = SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model)
_install_workflow_service(monkeypatch, workflow=None)
with app.test_request_context(
@ -118,7 +124,7 @@ def test_instruction_generate_node_missing(app: Flask, monkeypatch: pytest.Monke
method = unwrap(api.post)
app_model = SimpleNamespace(id="app-1")
session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model)
session = SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model)
workflow = SimpleNamespace(graph_dict={"nodes": []})
_install_workflow_service(monkeypatch, workflow=workflow)
@ -144,7 +150,7 @@ def test_instruction_generate_code_node(app: Flask, monkeypatch: pytest.MonkeyPa
method = unwrap(api.post)
app_model = SimpleNamespace(id="app-1")
session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model)
session = SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model)
workflow = SimpleNamespace(
graph_dict={

View File

@ -130,3 +130,50 @@ class TestAppMCPServerController:
assert response == {"id": "server-1"}
assert status_code == 201
def test_put_binds_server_lookup_to_app_ref(self):
api = AppMCPServerController()
method = unwrap(api.put)
payload = {"id": "server-1", "description": "Updated", "parameters": {"timeout": 30}, "status": "active"}
app = Flask(__name__)
app.config["TESTING"] = True
server = SimpleNamespace(
id="server-1",
tenant_id="tenant-1",
app_id="app-1",
name="Old",
description="Old",
parameters="{}",
status="active",
)
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch("controllers.console.app.mcp_server.db.session.scalar", return_value=server) as scalar,
patch("controllers.console.app.mcp_server.db.session.get") as get_mock,
patch("controllers.console.app.mcp_server.db.session.commit") as commit,
patch(
"controllers.console.app.mcp_server.AppMCPServerResponse.model_validate",
return_value=_ValidatedResponse({"id": "server-1"}),
),
):
response = method(
api,
app_model=SimpleNamespace(
id="app-1", tenant_id="tenant-1", name="Demo App", description="App description"
),
)
stmt = scalar.call_args.args[0]
compiled = stmt.compile()
statement = str(compiled)
assert "app_mcp_servers.id" in statement
assert "app_mcp_servers.tenant_id" in statement
assert "app_mcp_servers.app_id" in statement
assert payload["id"] in compiled.params.values()
assert "tenant-1" in compiled.params.values()
assert "app-1" in compiled.params.values()
get_mock.assert_not_called()
commit.assert_called_once()
assert response == {"id": "server-1"}

View File

@ -108,6 +108,15 @@ def _segment_response_dict():
}
def _bind_dataset_document(dataset, document, dataset_id: str = "ds-1", document_id: str = "doc-1"):
dataset.id = dataset_id
dataset.tenant_id = "tenant-1"
document.id = document_id
document.dataset_id = dataset_id
document.tenant_id = "tenant-1"
return document
def test_segment_response_with_summary():
segment = _segment()
@ -380,6 +389,7 @@ class TestDatasetDocumentSegmentAddApi:
document = MagicMock()
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
_bind_dataset_document(dataset, document)
segment = _segment()
@ -504,6 +514,7 @@ class TestDatasetDocumentSegmentUpdateApi:
document = MagicMock()
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
_bind_dataset_document(dataset, document)
segment = _segment()
@ -519,8 +530,8 @@ class TestDatasetDocumentSegmentUpdateApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.scalar",
side_effect=[segment, None],
"controllers.console.datasets.datasets_segments.SegmentService.get_segment_by_ref",
return_value=segment,
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
@ -538,6 +549,7 @@ class TestDatasetDocumentSegmentUpdateApi:
"controllers.console.datasets.datasets_segments.SummaryIndexService.get_segment_summary",
return_value=None,
),
patch("models.dataset.db.session.scalar", return_value=None),
patch("models.dataset.db.session.execute", return_value=MagicMock(all=MagicMock(return_value=[]))),
):
response, status = method(api, "tenant-1", user, "ds-1", "doc-1", "seg-1")
@ -576,6 +588,10 @@ class TestDatasetDocumentSegmentUpdateApi:
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_model_setting",
return_value=None,
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
return_value=None,
),
patch(
"controllers.console.datasets.datasets_segments.ModelManager.get_model_instance",
side_effect=LLMBadRequestError(),
@ -781,13 +797,15 @@ class TestChildChunkAddApi:
api = ChildChunkAddApi()
method = inspect.unwrap(api.get)
dataset = MagicMock()
document = _bind_dataset_document(dataset, MagicMock())
pagination = MagicMock(items=[], total=0, pages=0)
with (
app.test_request_context("/?page=bad&limit="),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=MagicMock(),
return_value=dataset,
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_model_setting",
@ -795,10 +813,10 @@ class TestChildChunkAddApi:
),
patch(
"controllers.console.datasets.datasets_segments.DocumentService.get_document",
return_value=MagicMock(),
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.scalar",
"controllers.console.datasets.datasets_segments.SegmentService.get_segment_by_ref",
return_value=MagicMock(),
),
patch(
@ -826,6 +844,7 @@ class TestChildChunkAddApi:
dataset.indexing_technique = "economy"
document = MagicMock()
_bind_dataset_document(dataset, document)
segment = MagicMock()
child_chunk = _child_chunk()
@ -841,7 +860,7 @@ class TestChildChunkAddApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.scalar",
"controllers.console.datasets.datasets_segments.SegmentService.get_segment_by_ref",
return_value=segment,
),
patch(
@ -868,6 +887,7 @@ class TestChildChunkAddApi:
dataset = MagicMock(indexing_technique="economy")
document = MagicMock()
_bind_dataset_document(dataset, document)
segment = MagicMock()
with (
@ -882,7 +902,7 @@ class TestChildChunkAddApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.scalar",
"controllers.console.datasets.datasets_segments.SegmentService.get_segment_by_ref",
return_value=segment,
),
patch(
@ -908,6 +928,7 @@ class TestChildChunkUpdateApi:
dataset = MagicMock()
document = MagicMock()
_bind_dataset_document(dataset, document)
segment = MagicMock()
child_chunk = MagicMock()
@ -922,8 +943,12 @@ class TestChildChunkUpdateApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.scalar",
side_effect=[segment, child_chunk],
"controllers.console.datasets.datasets_segments.SegmentService.get_segment_by_ref",
return_value=segment,
),
patch(
"controllers.console.datasets.datasets_segments.SegmentService.get_child_chunk_by_segment_ref",
return_value=child_chunk,
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
@ -947,6 +972,7 @@ class TestChildChunkUpdateApi:
dataset = MagicMock()
document = MagicMock()
_bind_dataset_document(dataset, document)
segment = MagicMock()
child_chunk = MagicMock()
@ -961,8 +987,12 @@ class TestChildChunkUpdateApi:
return_value=document,
),
patch(
"controllers.console.datasets.datasets_segments.db.session.scalar",
side_effect=[segment, child_chunk],
"controllers.console.datasets.datasets_segments.SegmentService.get_segment_by_ref",
return_value=segment,
),
patch(
"controllers.console.datasets.datasets_segments.SegmentService.get_child_chunk_by_segment_ref",
return_value=child_chunk,
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",

View File

@ -245,7 +245,7 @@ class TestTextApi:
api = TextApi()
handler = unwrap(api.post)
app_model = SimpleNamespace(id="a1")
end_user = SimpleNamespace(external_user_id="ext")
end_user = SimpleNamespace(id="end-user-1", external_user_id="ext")
with app.test_request_context(
"/text-to-audio",
@ -264,7 +264,7 @@ class TestTextApi:
api = TextApi()
handler = unwrap(api.post)
app_model = SimpleNamespace(id="a1")
end_user = SimpleNamespace(external_user_id="ext")
end_user = SimpleNamespace(id="end-user-1", external_user_id="ext")
with app.test_request_context("/text-to-audio", method="POST", json={"text": "hello"}):
with pytest.raises(ProviderQuotaExceededError):

View File

@ -90,6 +90,19 @@ def _child_chunk() -> ChildChunk:
return child_chunk
def _document_for_dataset(
dataset: Dataset, document_id: str = "doc-id", doc_form: str = IndexStructureType.PARAGRAPH_INDEX
):
document = Mock()
document.id = document_id
document.dataset_id = dataset.id
document.tenant_id = dataset.tenant_id
document.indexing_status = "completed"
document.enabled = True
document.doc_form = doc_form
return document
class TestSegmentCreatePayload:
"""Test suite for SegmentCreatePayload Pydantic model."""
@ -868,7 +881,9 @@ class TestSegmentApiGet:
# Arrange
mock_account_fn.return_value = (Mock(), mock_tenant.id)
mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
mock_doc_svc.get_document.return_value = _document_for_dataset(
mock_dataset, doc_form=IndexStructureType.PARAGRAPH_INDEX
)
mock_seg_svc.get_segments.return_value = ([mock_segment], 1)
mock_get_summaries.return_value = {}
mock_dump_segments.return_value = [_segment_response_dict()]
@ -988,7 +1003,7 @@ class TestSegmentApiPost:
mock_dataset.indexing_technique = "economy"
mock_db.session.scalar.return_value = mock_dataset
mock_doc = Mock()
mock_doc = _document_for_dataset(mock_dataset)
mock_doc.indexing_status = "completed"
mock_doc.enabled = True
mock_doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
@ -1040,7 +1055,7 @@ class TestSegmentApiPost:
mock_dataset.indexing_technique = "economy"
mock_db.session.scalar.return_value = mock_dataset
mock_doc = Mock()
mock_doc = _document_for_dataset(mock_dataset)
mock_doc.indexing_status = "completed"
mock_doc.enabled = True
mock_doc_svc.get_document.return_value = mock_doc
@ -1082,7 +1097,7 @@ class TestSegmentApiPost:
mock_db.session.scalar.return_value = mock_dataset
mock_doc = Mock()
mock_doc = _document_for_dataset(mock_dataset)
mock_doc.indexing_status = "indexing" # Not completed
mock_doc_svc.get_document.return_value = mock_doc
@ -1134,10 +1149,10 @@ class TestDatasetSegmentApiDelete:
mock_db.session.scalar.return_value = mock_dataset
mock_dataset_svc.check_dataset_model_setting.return_value = None
mock_doc = Mock()
mock_doc = _document_for_dataset(mock_dataset)
mock_doc_svc.get_document.return_value = mock_doc
mock_seg_svc.get_segment_by_id.return_value = mock_segment
mock_seg_svc.get_segment_by_ref.return_value = mock_segment
mock_seg_svc.delete_segment.return_value = None
# Act
@ -1177,13 +1192,13 @@ class TestDatasetSegmentApiDelete:
mock_account_fn.return_value = (Mock(), mock_tenant.id)
mock_db.session.scalar.return_value = mock_dataset
mock_doc = Mock()
mock_doc = _document_for_dataset(mock_dataset)
mock_doc.indexing_status = "completed"
mock_doc.enabled = True
mock_doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
mock_doc_svc.get_document.return_value = mock_doc
mock_seg_svc.get_segment_by_id.return_value = None # Segment not found
mock_seg_svc.get_segment_by_ref.return_value = None # Segment not found
# Act & Assert
with app.test_request_context(
@ -1329,8 +1344,10 @@ class TestDatasetSegmentApiUpdate:
mock_dataset.indexing_technique = "economy"
mock_db.session.scalar.return_value = mock_dataset
mock_dataset_svc.check_dataset_model_setting.return_value = None
mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
mock_seg_svc.get_segment_by_id.return_value = mock_segment
mock_doc_svc.get_document.return_value = _document_for_dataset(
mock_dataset, doc_form=IndexStructureType.PARAGRAPH_INDEX
)
mock_seg_svc.get_segment_by_ref.return_value = mock_segment
updated = Mock()
updated.id = "updated-seg"
mock_seg_svc.update_segment.return_value = updated
@ -1419,8 +1436,8 @@ class TestDatasetSegmentApiUpdate:
mock_dataset.indexing_technique = "economy"
mock_db.session.scalar.return_value = mock_dataset
mock_dataset_svc.check_dataset_model_setting.return_value = None
mock_doc_svc.get_document.return_value = Mock()
mock_seg_svc.get_segment_by_id.return_value = None
mock_doc_svc.get_document.return_value = _document_for_dataset(mock_dataset)
mock_seg_svc.get_segment_by_ref.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id",
@ -1470,9 +1487,9 @@ class TestDatasetSegmentApiGetSingle:
mock_account_fn.return_value = (Mock(), mock_tenant.id)
mock_db.session.scalar.return_value = mock_dataset
mock_dataset_svc.check_dataset_model_setting.return_value = None
mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
mock_doc = _document_for_dataset(mock_dataset, doc_form=IndexStructureType.PARAGRAPH_INDEX)
mock_doc_svc.get_document.return_value = mock_doc
mock_seg_svc.get_segment_by_id.return_value = mock_segment
mock_seg_svc.get_segment_by_ref.return_value = mock_segment
mock_get_summary.return_value = None
mock_dump_segment.return_value = _segment_response_dict()
@ -1517,9 +1534,9 @@ class TestDatasetSegmentApiGetSingle:
mock_account_fn.return_value = (Mock(), mock_tenant.id)
mock_db.session.scalar.return_value = mock_dataset
mock_dataset_svc.check_dataset_model_setting.return_value = None
mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
mock_doc = _document_for_dataset(mock_dataset, doc_form=IndexStructureType.PARAGRAPH_INDEX)
mock_doc_svc.get_document.return_value = mock_doc
mock_seg_svc.get_segment_by_id.return_value = mock_segment
mock_seg_svc.get_segment_by_ref.return_value = mock_segment
mock_summary_record = Mock(summary_content="This is the segment summary")
mock_get_summary.return_value = mock_summary_record
mock_dump_segment.return_value = _segment_response_dict("This is the segment summary")
@ -1619,8 +1636,8 @@ class TestDatasetSegmentApiGetSingle:
mock_account_fn.return_value = (Mock(), mock_tenant.id)
mock_db.session.scalar.return_value = mock_dataset
mock_dataset_svc.check_dataset_model_setting.return_value = None
mock_doc_svc.get_document.return_value = Mock()
mock_seg_svc.get_segment_by_id.return_value = None
mock_doc_svc.get_document.return_value = _document_for_dataset(mock_dataset)
mock_seg_svc.get_segment_by_ref.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id",
@ -1660,8 +1677,8 @@ class TestChildChunkApiGet:
"""Test successful child chunk list retrieval."""
mock_account_fn.return_value = (Mock(), mock_tenant.id)
mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = Mock()
mock_seg_svc.get_segment_by_id.return_value = Mock()
mock_doc_svc.get_document.return_value = _document_for_dataset(mock_dataset)
mock_seg_svc.get_segment_by_ref.return_value = Mock()
mock_pagination = Mock()
mock_pagination.items = [_child_chunk(), _child_chunk()]
@ -1759,8 +1776,8 @@ class TestChildChunkApiGet:
"""Test 404 when segment not found."""
mock_account_fn.return_value = (Mock(), mock_tenant.id)
mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = Mock()
mock_seg_svc.get_segment_by_id.return_value = None
mock_doc_svc.get_document.return_value = _document_for_dataset(mock_dataset)
mock_seg_svc.get_segment_by_ref.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id/child_chunks",
@ -1822,8 +1839,8 @@ class TestChildChunkApiPost:
mock_account_fn.return_value = (Mock(), mock_tenant.id)
mock_dataset.indexing_technique = "economy"
mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = Mock()
mock_seg_svc.get_segment_by_id.return_value = Mock()
mock_doc_svc.get_document.return_value = _document_for_dataset(mock_dataset)
mock_seg_svc.get_segment_by_ref.return_value = Mock()
mock_child = _child_chunk()
mock_seg_svc.create_child_chunk.return_value = mock_child
@ -1900,8 +1917,8 @@ class TestChildChunkApiPost:
self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
mock_account_fn.return_value = (Mock(), mock_tenant.id)
mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = Mock()
mock_seg_svc.get_segment_by_id.return_value = None
mock_doc_svc.get_document.return_value = _document_for_dataset(mock_dataset)
mock_seg_svc.get_segment_by_ref.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id/child_chunks",
@ -1954,19 +1971,19 @@ class TestDatasetChildChunkApiDelete:
mock_account_fn.return_value = (Mock(), mock_tenant.id)
mock_db.session.scalar.return_value = mock_dataset
mock_doc = Mock()
mock_doc = _document_for_dataset(mock_dataset)
mock_doc_svc.get_document.return_value = mock_doc
segment_id = str(uuid.uuid4())
mock_segment = Mock()
mock_segment.id = segment_id
mock_segment.document_id = "doc-id"
mock_seg_svc.get_segment_by_id.return_value = mock_segment
mock_seg_svc.get_segment_by_ref.return_value = mock_segment
child_chunk_id = str(uuid.uuid4())
mock_child = Mock()
mock_child.segment_id = segment_id
mock_seg_svc.get_child_chunk_by_id.return_value = mock_child
mock_seg_svc.get_child_chunk_by_segment_ref.return_value = mock_child
mock_seg_svc.delete_child_chunk.return_value = None
with app.test_request_context(
@ -2003,14 +2020,14 @@ class TestDatasetChildChunkApiDelete:
"""Test 404 when child chunk not found."""
mock_account_fn.return_value = (Mock(), mock_tenant.id)
mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = Mock()
mock_doc_svc.get_document.return_value = _document_for_dataset(mock_dataset)
segment_id = str(uuid.uuid4())
mock_segment = Mock()
mock_segment.id = segment_id
mock_segment.document_id = "doc-id"
mock_seg_svc.get_segment_by_id.return_value = mock_segment
mock_seg_svc.get_child_chunk_by_id.return_value = None
mock_seg_svc.get_segment_by_ref.return_value = mock_segment
mock_seg_svc.get_child_chunk_by_segment_ref.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{segment_id}/child_chunks/cc-id",
@ -2044,13 +2061,10 @@ class TestDatasetChildChunkApiDelete:
"""Test 404 when segment does not belong to the document."""
mock_account_fn.return_value = (Mock(), mock_tenant.id)
mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = Mock()
mock_doc_svc.get_document.return_value = _document_for_dataset(mock_dataset)
segment_id = str(uuid.uuid4())
mock_segment = Mock()
mock_segment.id = segment_id
mock_segment.document_id = "different-doc-id"
mock_seg_svc.get_segment_by_id.return_value = mock_segment
mock_seg_svc.get_segment_by_ref.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{segment_id}/child_chunks/cc-id",
@ -2084,17 +2098,15 @@ class TestDatasetChildChunkApiDelete:
"""Test 404 when child chunk does not belong to the segment."""
mock_account_fn.return_value = (Mock(), mock_tenant.id)
mock_db.session.scalar.return_value = mock_dataset
mock_doc_svc.get_document.return_value = Mock()
mock_doc_svc.get_document.return_value = _document_for_dataset(mock_dataset)
segment_id = str(uuid.uuid4())
mock_segment = Mock()
mock_segment.id = segment_id
mock_segment.document_id = "doc-id"
mock_seg_svc.get_segment_by_id.return_value = mock_segment
mock_seg_svc.get_segment_by_ref.return_value = mock_segment
mock_child = Mock()
mock_child.segment_id = "different-segment-id"
mock_seg_svc.get_child_chunk_by_id.return_value = mock_child
mock_seg_svc.get_child_chunk_by_segment_ref.return_value = None
with app.test_request_context(
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{segment_id}/child_chunks/cc-id",

View File

@ -422,6 +422,13 @@ class TestLLMGenerator:
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
)
assert result == {"modified": "prompt"}
stmt = mock_scalar.call_args.args[0]
compiled = stmt.compile()
statement = str(compiled)
assert "messages.app_id" in statement
assert "apps.tenant_id" in statement
assert "flow_id" in compiled.params.values()
assert "tenant_id" in compiled.params.values()
def test_instruction_modify_legacy_with_last_run(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
@ -439,12 +446,26 @@ class TestLLMGenerator:
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
)
assert result == {"modified": "prompt"}
stmt = mock_scalar.call_args.args[0]
compiled = stmt.compile()
statement = str(compiled)
assert "messages.app_id" in statement
assert "apps.tenant_id" in statement
assert "flow_id" in compiled.params.values()
assert "tenant_id" in compiled.params.values()
def test_instruction_modify_workflow_app_not_found(self):
with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.scalar.return_value = None
with pytest.raises(ValueError, match="App not found."):
LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", MagicMock())
stmt = mock_session.return_value.scalar.call_args.args[0]
compiled = stmt.compile()
statement = str(compiled)
assert "apps.id" in statement
assert "apps.tenant_id" in statement
assert "f" in compiled.params.values()
assert "t" in compiled.params.values()
def test_instruction_modify_workflow_no_workflow(self):
with patch("extensions.ext_database.db.session") as mock_session:

View File

@ -1,111 +0,0 @@
"""Tests for the uuidv7 SQL migration's PostgreSQL 18 compatibility guard.
The migration file name is not a valid Python identifier (it starts with a date and
contains hyphens), so it is loaded directly from its path. The ``models`` import at the
top of the migration is stubbed because the migration never uses it during
``upgrade()``/``downgrade()`` and pulling in the real package would require a full app
context.
"""
import importlib.util
import sys
import types
from pathlib import Path
from unittest import mock
import pytest
MIGRATION_PATH = (
Path(__file__).resolve().parents[3]
/ "migrations"
/ "versions"
/ "2025_07_02_2332-1c9ba48be8e4_add_uuidv7_function_in_sql.py"
)
def _load_migration():
# The migration does `import models as models` but never references it, so a stub is
# enough and keeps the test free of any database/app configuration.
sys.modules.setdefault("models", types.ModuleType("models"))
spec = importlib.util.spec_from_file_location("uuidv7_pg18_migration_under_test", MIGRATION_PATH)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
def _make_bind(dialect_name):
bind = mock.MagicMock()
bind.dialect.name = dialect_name
return bind
def _executed_sql(fake_op):
return [str(call.args[0]) for call in fake_op.execute.call_args_list]
@pytest.fixture
def migration():
return _load_migration()
def test_upgrade_creates_both_functions_when_native_uuidv7_absent(migration):
# PostgreSQL 13 to 17: no native pg_catalog.uuidv7(), so both functions are created.
# The DO block contains the CREATE FUNCTION guarded by IF NOT EXISTS, and
# uuidv7_boundary is created unconditionally.
bind = _make_bind("postgresql")
with mock.patch.object(migration, "op") as fake_op:
fake_op.get_bind.return_value = bind
migration.upgrade()
sql = _executed_sql(fake_op)
assert any("CREATE FUNCTION public.uuidv7()" in stmt for stmt in sql)
assert any("CREATE FUNCTION public.uuidv7_boundary(timestamptz)" in stmt for stmt in sql)
def test_upgrade_skips_uuidv7_but_keeps_boundary_when_native_present(migration):
# PostgreSQL 18: native pg_catalog.uuidv7() exists, so the DO block must guard
# the CREATE FUNCTION with an IF NOT EXISTS check against pg_catalog.
# uuidv7_boundary is still missing and has to be created unconditionally.
bind = _make_bind("postgresql")
with mock.patch.object(migration, "op") as fake_op:
fake_op.get_bind.return_value = bind
migration.upgrade()
sql = _executed_sql(fake_op)
# The DO block must contain the pg_catalog existence check.
do_block = next((stmt for stmt in sql if "DO $do$" in stmt), None)
assert do_block is not None
assert "pg_catalog" in do_block
assert "uuidv7" in do_block
assert "IF NOT EXISTS" in do_block
# uuidv7_boundary is always created (not guarded by the DO block).
assert any("CREATE FUNCTION public.uuidv7_boundary(timestamptz)" in stmt for stmt in sql)
def test_upgrade_is_noop_on_non_postgres(migration):
bind = _make_bind("sqlite")
with mock.patch.object(migration, "op") as fake_op:
fake_op.get_bind.return_value = bind
migration.upgrade()
fake_op.execute.assert_not_called()
def test_downgrade_uses_if_exists_and_public_schema(migration):
bind = _make_bind("postgresql")
with mock.patch.object(migration, "op") as fake_op:
fake_op.get_bind.return_value = bind
migration.downgrade()
sql = _executed_sql(fake_op)
assert "DROP FUNCTION IF EXISTS public.uuidv7()" in sql
assert "DROP FUNCTION IF EXISTS public.uuidv7_boundary(timestamptz)" in sql
def test_downgrade_is_noop_on_non_postgres(migration):
bind = _make_bind("sqlite")
with mock.patch.object(migration, "op") as fake_op:
fake_op.get_bind.return_value = bind
migration.downgrade()
fake_op.execute.assert_not_called()

View File

@ -12,6 +12,7 @@ from models.dataset import Dataset, Pipeline, PipelineCustomizedTemplate, Pipeli
from models.workflow import Workflow
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.workflow_ref_service import WorkflowRefService
@pytest.fixture
@ -369,6 +370,38 @@ def test_update_workflow_returns_none_when_not_found(
assert result is None
def test_update_workflow_with_ref_scopes_lookup_to_pipeline(
mocker: MockerFixture, rag_pipeline_service: RagPipelineService
) -> None:
workflow = SimpleNamespace(
id="wf-1", marked_name="", marked_comment="", updated_by=None, updated_at=None, disallowed="original"
)
pipeline = SimpleNamespace(id="pipeline-1", tenant_id="t1")
workflow_ref = WorkflowRefService.create_pipeline_workflow_ref(pipeline, "wf-1")
session = mocker.Mock()
session.scalar.return_value = workflow
result = rag_pipeline_service.update_workflow(
session=session,
workflow_id="wf-1",
tenant_id="t1",
account_id="u1",
data={"marked_name": "v1"},
workflow_ref=workflow_ref,
)
stmt = session.scalar.call_args.args[0]
compiled = stmt.compile()
statement = str(compiled)
assert "workflows.id" in statement
assert "workflows.tenant_id" in statement
assert "workflows.app_id" in statement
assert "wf-1" in compiled.params.values()
assert "t1" in compiled.params.values()
assert "pipeline-1" in compiled.params.values()
assert result is workflow
# --- get_rag_pipeline_paginate_workflow_runs ---
@ -1627,6 +1660,8 @@ def test_handle_node_run_result_marks_document_error_for_published_invoke(
def __init__(self):
self._values = {
("sys", "invoke_from"): SimpleNamespace(value=InvokeFrom.PUBLISHED_PIPELINE),
("sys", "app_id"): SimpleNamespace(value="pipeline-1"),
("sys", "dataset_id"): SimpleNamespace(value="dataset-1"),
("sys", "document_id"): SimpleNamespace(value="doc-1"),
}
@ -1660,7 +1695,8 @@ def test_handle_node_run_result_marks_document_error_for_published_invoke(
)
document = SimpleNamespace(indexing_status="waiting", error=None)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=document)
scalar_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=document)
get_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get")
add_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.add")
commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
@ -1672,6 +1708,19 @@ def test_handle_node_run_result_marks_document_error_for_published_invoke(
)
assert result.status == WorkflowNodeExecutionStatus.FAILED
stmt = scalar_mock.call_args.args[0]
compiled = stmt.compile()
statement = str(compiled)
assert "documents.id" in statement
assert "documents.tenant_id" in statement
assert "documents.dataset_id" in statement
assert "datasets.tenant_id" in statement
assert "datasets.pipeline_id" in statement
assert "doc-1" in compiled.params.values()
assert "t1" in compiled.params.values()
assert "dataset-1" in compiled.params.values()
assert "pipeline-1" in compiled.params.values()
get_mock.assert_not_called()
assert document.indexing_status == "error"
assert document.error == "boom"
add_mock.assert_called_once_with(document)

View File

@ -41,9 +41,10 @@ def _make_message(message_id: str = "msg-1", app_id: str = "app-1") -> MagicMock
return message
def _make_annotation(annotation_id: str = "ann-1") -> MagicMock:
def _make_annotation(annotation_id: str = "ann-1", app_id: str = "app-1") -> MagicMock:
annotation = MagicMock(spec=MessageAnnotation)
annotation.id = annotation_id
annotation.app_id = app_id
annotation.content = ""
annotation.question = ""
annotation.question_text = ""
@ -66,6 +67,15 @@ def _make_file(content: bytes) -> FileStorage:
return FileStorage(stream=BytesIO(content))
def _assert_statement_binds_annotation(stmt: Any, annotation_id: str, app_id: str) -> None:
compiled = stmt.compile()
statement = str(compiled)
assert "message_annotations.id" in statement
assert "message_annotations.app_id" in statement
assert annotation_id in compiled.params.values()
assert app_id in compiled.params.values()
class TestAppAnnotationServiceUpInsert:
"""Test suite for up_insert_app_annotation_from_message."""
@ -541,8 +551,7 @@ class TestAppAnnotationServiceDirectManipulation:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
mock_db.session.scalar.return_value = app
mock_db.session.get.return_value = None
mock_db.session.scalar.side_effect = [app, None]
# Act & Assert
with pytest.raises(NotFound):
@ -576,8 +585,7 @@ class TestAppAnnotationServiceDirectManipulation:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
mock_db.session.scalar.return_value = app
mock_db.session.get.return_value = annotation
mock_db.session.scalar.side_effect = [app, annotation]
# Act & Assert
with pytest.raises(ValueError):
@ -598,8 +606,7 @@ class TestAppAnnotationServiceDirectManipulation:
patch("services.annotation_service.db") as mock_db,
patch("services.annotation_service.update_annotation_to_index_task") as mock_task,
):
mock_db.session.scalar.side_effect = [app, setting]
mock_db.session.get.return_value = annotation
mock_db.session.scalar.side_effect = [app, annotation, setting]
# Act
result = AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id, mock_db.session)
@ -608,6 +615,8 @@ class TestAppAnnotationServiceDirectManipulation:
assert result == annotation
assert annotation.content == "hello"
assert annotation.question == "q1"
_assert_statement_binds_annotation(mock_db.session.scalar.call_args_list[1].args[0], annotation.id, app.id)
mock_db.session.get.assert_not_called()
mock_db.session.commit.assert_called_once()
mock_task.delay.assert_called_once_with(
annotation.id,
@ -632,8 +641,7 @@ class TestAppAnnotationServiceDirectManipulation:
patch("services.annotation_service.db") as mock_db,
patch("services.annotation_service.delete_annotation_index_task") as mock_task,
):
mock_db.session.scalar.side_effect = [app, setting]
mock_db.session.get.return_value = annotation
mock_db.session.scalar.side_effect = [app, annotation, setting]
scalars_result = MagicMock()
scalars_result.all.return_value = [history1, history2]
@ -643,6 +651,8 @@ class TestAppAnnotationServiceDirectManipulation:
AppAnnotationService.delete_app_annotation(app.id, annotation.id, mock_db.session)
# Assert
_assert_statement_binds_annotation(mock_db.session.scalar.call_args_list[1].args[0], annotation.id, app.id)
mock_db.session.get.assert_not_called()
mock_db.session.delete.assert_any_call(annotation)
mock_db.session.delete.assert_any_call(history1)
mock_db.session.delete.assert_any_call(history2)
@ -679,8 +689,7 @@ class TestAppAnnotationServiceDirectManipulation:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
mock_db.session.scalar.return_value = app
mock_db.session.get.return_value = None
mock_db.session.scalar.side_effect = [app, None]
# Act & Assert
with pytest.raises(NotFound):
@ -748,6 +757,13 @@ class TestAppAnnotationServiceDirectManipulation:
# Assert
assert result == {"deleted_count": 2}
fetch_stmt = mock_db.session.execute.call_args_list[0].args[0]
compiled = fetch_stmt.compile()
statement = str(compiled)
assert "message_annotations.id IN" in statement
assert "message_annotations.app_id" in statement
assert ["ann-1", "ann-2"] in compiled.params.values()
assert app.id in compiled.params.values()
mock_task.delay.assert_called_once_with(annotation1.id, app.id, tenant_id, setting.collection_binding_id)
mock_db.session.commit.assert_called_once()
@ -1121,8 +1137,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
mock_db.session.scalar.return_value = app
mock_db.session.get.return_value = annotation
mock_db.session.scalar.side_effect = [app, annotation]
mock_db.paginate.return_value = pagination
# Act
@ -1131,6 +1146,8 @@ class TestAppAnnotationServiceHitHistoryAndSettings:
# Assert
assert items == ["h1"]
assert total == 2
_assert_statement_binds_annotation(mock_db.session.scalar.call_args_list[1].args[0], annotation.id, app.id)
mock_db.session.get.assert_not_called()
def test_get_annotation_hit_histories_should_raise_not_found_when_annotation_missing(self) -> None:
"""Test missing annotation raises NotFound."""
@ -1142,8 +1159,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
mock_db.session.scalar.return_value = app
mock_db.session.get.return_value = None
mock_db.session.scalar.side_effect = [app, None]
# Act & Assert
with pytest.raises(NotFound):

View File

@ -523,7 +523,7 @@ class TestAudioServiceTTS:
message_id = "00000000-0000-0000-0000-000000000001"
message = factory.create_message_mock(message_id=message_id, answer="Message answer")
session = MagicMock()
session.get.return_value = message
session.scalar.return_value = message
mock_model_manager = mock_model_manager_class.return_value
mock_model_instance = MagicMock()
@ -536,11 +536,25 @@ class TestAudioServiceTTS:
session=session,
message_id=message_id,
voice="message-voice",
message_end_user_id="end-user-1",
message_account_id="account-1",
)
# Assert
assert result == b"message audio"
session.get.assert_called_once_with(Message, message_id)
session.scalar.assert_called_once()
session.get.assert_not_called()
stmt = session.scalar.call_args.args[0]
compiled = stmt.compile()
statement = str(compiled)
assert "messages.id" in statement
assert "messages.app_id" in statement
assert "messages.from_end_user_id" in statement
assert "messages.from_account_id" in statement
assert message_id in compiled.params.values()
assert app.id in compiled.params.values()
assert "end-user-1" in compiled.params.values()
assert "account-1" in compiled.params.values()
mock_model_instance.invoke_tts.assert_called_once_with(
content_text="Message answer",
voice="message-voice",

View File

@ -103,6 +103,33 @@ class TestDocumentServiceMutations:
assert DocumentService.check_archived(document) is expected
def test_delete_documents_limits_query_and_cleanup_to_dataset_ref(self):
dataset = _make_dataset(dataset_id="dataset-1", tenant_id="tenant-1")
dataset.doc_form = "paragraph_index"
document = _make_document(document_id="doc-1", dataset_id=dataset.id, tenant_id=dataset.tenant_id)
document.data_source_info_dict = {}
with (
patch("services.dataset_service.db") as mock_db,
patch("services.dataset_service.batch_clean_document_task") as clean_task,
):
mock_db.session.scalars.return_value.all.return_value = [document]
DocumentService.delete_documents(dataset, ["doc-1", "other-doc"])
stmt = mock_db.session.scalars.call_args.args[0]
compiled = stmt.compile()
statement = str(compiled)
assert "documents.id IN" in statement
assert "documents.tenant_id" in statement
assert "documents.dataset_id" in statement
assert ["doc-1", "other-doc"] in compiled.params.values()
assert dataset.tenant_id in compiled.params.values()
assert dataset.id in compiled.params.values()
mock_db.session.delete.assert_called_once_with(document)
mock_db.session.commit.assert_called_once()
clean_task.delay.assert_called_once_with(["doc-1"], dataset.id, dataset.doc_form, [])
def test_rename_document_raises_when_dataset_is_missing(self, rename_account_context):
with patch.object(DatasetService, "get_dataset", return_value=None):
with pytest.raises(ValueError, match="Dataset not found"):

View File

@ -1,5 +1,7 @@
"""Unit tests for SegmentService behaviors in dataset_service."""
from services.dataset_ref_service import DatasetRef, DatasetRefService
from .dataset_service_test_helpers import (
Account,
ChildChunk,
@ -24,6 +26,38 @@ from .dataset_service_test_helpers import (
)
def _make_segment_ref(segment_id: str = "segment-1"):
dataset = _make_dataset()
document = _make_document(dataset_id=dataset.id, tenant_id=dataset.tenant_id)
dataset_ref = DatasetRefService.create_dataset_ref(dataset)
document_ref = DatasetRefService.create_document_ref(dataset_ref, document)
assert document_ref is not None
return DatasetRefService.create_segment_ref(document_ref, segment_id)
class TestDatasetRefService:
"""Unit tests for typed dataset resource refs."""
def test_dataset_ref_requires_service_ctor_token(self):
with pytest.raises(ValueError, match="DatasetRef must be created by DatasetRefService"):
DatasetRef("tenant-1", "dataset-1", object())
def test_create_document_ref_rejects_document_outside_dataset(self):
dataset = _make_dataset(dataset_id="dataset-1", tenant_id="tenant-1")
document = _make_document(document_id="doc-1", dataset_id="other-dataset", tenant_id="tenant-1")
dataset_ref = DatasetRefService.create_dataset_ref(dataset)
assert DatasetRefService.create_document_ref(dataset_ref, document) is None
def test_create_segment_ref_carries_full_parent_chain(self):
segment_ref = _make_segment_ref()
assert segment_ref.tenant_id == "tenant-1"
assert segment_ref.dataset_id == "dataset-1"
assert segment_ref.document_id == "doc-1"
assert segment_ref.segment_id == "segment-1"
class TestSegmentServiceChildChunks:
"""Unit tests for child-chunk CRUD helpers."""
@ -257,6 +291,23 @@ class TestSegmentServiceQueries:
assert result is None
def test_get_child_chunk_by_segment_ref_uses_full_ownership_chain(self):
child_chunk = _make_child_chunk()
segment_ref = _make_segment_ref()
with patch("services.dataset_service.db") as mock_db:
mock_db.session.scalar.return_value = child_chunk
result = SegmentService.get_child_chunk_by_segment_ref("child-a", segment_ref)
assert result is child_chunk
stmt = mock_db.session.scalar.call_args.args[0]
sql = str(stmt.compile(compile_kwargs={"literal_binds": True}))
assert "child_chunks.id = 'child-a'" in sql
assert "child_chunks.tenant_id = 'tenant-1'" in sql
assert "child_chunks.dataset_id = 'dataset-1'" in sql
assert "child_chunks.document_id = 'doc-1'" in sql
assert "child_chunks.segment_id = 'segment-1'" in sql
def test_get_segments_uses_status_and_keyword_filters(self):
paginated = SimpleNamespace(items=["segment"], total=1)
@ -304,6 +355,32 @@ class TestSegmentServiceQueries:
assert result is None
def test_get_segment_by_ref_uses_full_ownership_chain(self):
segment = DocumentSegment(
tenant_id="tenant-1",
dataset_id="dataset-1",
document_id="doc-1",
position=1,
content="segment",
word_count=7,
tokens=2,
created_by="user-1",
)
segment.id = "segment-1"
segment_ref = _make_segment_ref()
with patch("services.dataset_service.db") as mock_db:
mock_db.session.scalar.return_value = segment
result = SegmentService.get_segment_by_ref(segment_ref)
assert result is segment
stmt = mock_db.session.scalar.call_args.args[0]
sql = str(stmt.compile(compile_kwargs={"literal_binds": True}))
assert "document_segments.id = 'segment-1'" in sql
assert "document_segments.tenant_id = 'tenant-1'" in sql
assert "document_segments.dataset_id = 'dataset-1'" in sql
assert "document_segments.document_id = 'doc-1'" in sql
def test_get_segments_by_document_and_dataset_returns_scalars_result(self):
segment = DocumentSegment(
tenant_id="tenant-1",

View File

@ -5,7 +5,7 @@ from pytest_mock import MockerFixture
from werkzeug.exceptions import NotFound
from models.enums import TagType
from services.tag_service import TagBindingCreatePayload, TagBindingDeletePayload, TagService
from services.tag_service import TagBindingCreatePayload, TagBindingDeletePayload, TagService, UpdateTagPayload
@pytest.fixture
@ -78,6 +78,71 @@ def test_delete_tag_binding_does_not_commit_when_no_rows_deleted(mocker: MockerF
db_session.commit.assert_not_called()
def test_update_tags_scopes_lookup_to_current_tenant_and_type(current_user, db_session):
tag = SimpleNamespace(id="tag-1", name="old", type=TagType.KNOWLEDGE)
db_session.scalar.side_effect = [tag, None]
result = TagService.update_tags(UpdateTagPayload(name="new"), "tag-1", db_session, tag_type=TagType.KNOWLEDGE)
stmt = db_session.scalar.call_args_list[0].args[0]
compiled = stmt.compile()
statement = str(compiled)
assert "tags.id" in statement
assert "tags.tenant_id" in statement
assert "tags.type" in statement
assert "tag-1" in compiled.params.values()
assert current_user.current_tenant_id in compiled.params.values()
assert TagType.KNOWLEDGE in compiled.params.values()
assert result is tag
assert tag.name == "new"
db_session.commit.assert_called_once()
def test_get_tag_binding_count_scopes_lookup_to_current_tenant_and_type(current_user, db_session):
db_session.scalar.return_value = 3
result = TagService.get_tag_binding_count("tag-1", db_session, tag_type=TagType.KNOWLEDGE)
stmt = db_session.scalar.call_args.args[0]
compiled = stmt.compile()
statement = str(compiled)
assert "tag_bindings.tag_id" in statement
assert "tags.tenant_id" in statement
assert "tags.type" in statement
assert "tag-1" in compiled.params.values()
assert current_user.current_tenant_id in compiled.params.values()
assert TagType.KNOWLEDGE in compiled.params.values()
assert result == 3
def test_delete_tag_scopes_lookup_and_bindings_to_current_tenant(current_user, db_session):
tag = SimpleNamespace(id="tag-1", name="old", type=TagType.KNOWLEDGE)
binding = SimpleNamespace(id="binding-1")
db_session.scalar.return_value = tag
db_session.scalars.return_value.all.return_value = [binding]
TagService.delete_tag("tag-1", db_session, tag_type=TagType.KNOWLEDGE)
tag_stmt = db_session.scalar.call_args.args[0]
tag_compiled = tag_stmt.compile()
assert "tags.id" in str(tag_compiled)
assert "tags.tenant_id" in str(tag_compiled)
assert "tags.type" in str(tag_compiled)
assert "tag-1" in tag_compiled.params.values()
assert current_user.current_tenant_id in tag_compiled.params.values()
assert TagType.KNOWLEDGE in tag_compiled.params.values()
binding_stmt = db_session.scalars.call_args.args[0]
binding_compiled = binding_stmt.compile()
assert "tag_bindings.tag_id" in str(binding_compiled)
assert "tag_bindings.tenant_id" in str(binding_compiled)
assert "tag-1" in binding_compiled.params.values()
assert current_user.current_tenant_id in binding_compiled.params.values()
db_session.delete.assert_any_call(tag)
db_session.delete.assert_any_call(binding)
db_session.commit.assert_called_once()
def test_get_target_ids_by_tag_ids_returns_empty_without_query_for_empty_input(db_session):
result = TagService.get_target_ids_by_tag_ids(TagType.SNIPPET, "tenant-1", [], db_session)

View File

@ -36,6 +36,7 @@ from models.model import App, AppMode
from models.workflow import Workflow, WorkflowType
from services.errors.app import IsDraftWorkflowError, TriggerNodeLimitExceededError, WorkflowHashNotEqualError
from services.errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
from services.workflow_ref_service import WorkflowRefService
from services.workflow_service import (
WorkflowService,
_rebuild_file_for_user_inputs_in_start_node,
@ -1052,6 +1053,39 @@ class TestWorkflowService:
assert result is None
def test_update_workflow_with_ref_scopes_lookup_to_app(self, workflow_service: WorkflowService):
"""Test update_workflow includes the trusted app owner in the lookup."""
workflow_id = "workflow-123"
tenant_id = "tenant-456"
app_id = "app-789"
account_id = "user-123"
workflow_ref = WorkflowRefService.create_app_workflow_ref(
SimpleNamespace(id=app_id, tenant_id=tenant_id), workflow_id
)
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id)
mock_session = MagicMock()
mock_session.scalar.return_value = mock_workflow
result = workflow_service.update_workflow(
session=mock_session,
workflow_id=workflow_id,
tenant_id=tenant_id,
account_id=account_id,
data={"marked_name": "Updated Name"},
workflow_ref=workflow_ref,
)
stmt = mock_session.scalar.call_args.args[0]
compiled = stmt.compile()
statement = str(compiled)
assert "workflows.id" in statement
assert "workflows.tenant_id" in statement
assert "workflows.app_id" in statement
assert workflow_id in compiled.params.values()
assert tenant_id in compiled.params.values()
assert app_id in compiled.params.values()
assert result == mock_workflow
# ==================== Delete Workflow Tests ====================
# These tests verify workflow deletion with safety checks
@ -1085,6 +1119,34 @@ class TestWorkflowService:
assert result is True
mock_session.delete.assert_called_once_with(mock_workflow)
def test_delete_workflow_with_ref_scopes_lookup_to_app(self, workflow_service: WorkflowService):
"""Test delete_workflow includes the trusted app owner in the lookup."""
workflow_id = "workflow-123"
tenant_id = "tenant-456"
app_id = "app-789"
workflow_ref = WorkflowRefService.create_app_workflow_ref(
SimpleNamespace(id=app_id, tenant_id=tenant_id), workflow_id
)
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1")
mock_session = MagicMock()
mock_session.scalar.side_effect = [mock_workflow, None, None]
result = workflow_service.delete_workflow(
session=mock_session, workflow_id=workflow_id, tenant_id=tenant_id, workflow_ref=workflow_ref
)
stmt = mock_session.scalar.call_args_list[0].args[0]
compiled = stmt.compile()
statement = str(compiled)
assert "workflows.id" in statement
assert "workflows.tenant_id" in statement
assert "workflows.app_id" in statement
assert workflow_id in compiled.params.values()
assert tenant_id in compiled.params.values()
assert app_id in compiled.params.values()
assert result is True
mock_session.delete.assert_called_once_with(mock_workflow)
def test_delete_workflow_draft_raises_error(self, workflow_service: WorkflowService):
"""
Test delete_workflow raises error when trying to delete draft.