mirror of
https://github.com/langgenius/dify.git
synced 2026-06-30 10:57:47 +08:00
Compare commits
2 Commits
main
...
fix/resour
| Author | SHA1 | Date | |
|---|---|---|---|
| 41b52bbefb | |||
| 1f72e9799b |
@ -3,6 +3,7 @@ from typing import Any, Literal
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, RootModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.fields import SimpleDataResponse
|
||||
@ -216,7 +217,9 @@ class InstructionGenerateApi(Resource):
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args.current in (code_template, "")) and args.node_id != "":
|
||||
app = session.get(App, args.flow_id)
|
||||
app = session.scalar(
|
||||
select(App).where(App.id == args.flow_id, App.tenant_id == current_tenant_id).limit(1)
|
||||
)
|
||||
if not app:
|
||||
return {"error": f"app {args.flow_id} not found"}, 400
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app, session=session)
|
||||
|
||||
@ -26,6 +26,7 @@ from libs.helper import to_timestamp
|
||||
from libs.login import login_required
|
||||
from models.enums import AppMCPServerStatus
|
||||
from models.model import App, AppMCPServer
|
||||
from services.app_ref_service import AppRefService
|
||||
|
||||
|
||||
class MCPServerCreatePayload(BaseModel):
|
||||
@ -146,7 +147,17 @@ class AppMCPServerController(Resource):
|
||||
@get_app_model
|
||||
def put(self, app_model: App):
|
||||
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
|
||||
server = db.session.get(AppMCPServer, payload.id)
|
||||
app_ref = AppRefService.create_app_ref(app_model)
|
||||
server_ref = AppRefService.create_mcp_server_ref(app_ref, payload.id)
|
||||
server = db.session.scalar(
|
||||
select(AppMCPServer)
|
||||
.where(
|
||||
AppMCPServer.id == server_ref.server_id,
|
||||
AppMCPServer.tenant_id == server_ref.tenant_id,
|
||||
AppMCPServer.app_id == server_ref.app_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if not server:
|
||||
raise NotFound()
|
||||
|
||||
|
||||
@ -78,6 +78,7 @@ from repositories.workflow_collaboration_repository import WORKFLOW_ONLINE_USERS
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_ref_service import WorkflowRefService
|
||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -1406,6 +1407,7 @@ class WorkflowByIdApi(Resource):
|
||||
return {"message": "No valid fields to update"}, 400
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflow_ref = WorkflowRefService.create_app_workflow_ref(app_model, workflow_id)
|
||||
|
||||
# Create a session and manage the transaction
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
@ -1415,6 +1417,7 @@ class WorkflowByIdApi(Resource):
|
||||
tenant_id=app_model.tenant_id,
|
||||
account_id=current_user.id,
|
||||
data=update_data,
|
||||
workflow_ref=workflow_ref,
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
@ -1434,12 +1437,16 @@ class WorkflowByIdApi(Resource):
|
||||
Delete workflow
|
||||
"""
|
||||
workflow_service = WorkflowService()
|
||||
workflow_ref = WorkflowRefService.create_app_workflow_ref(app_model, workflow_id)
|
||||
|
||||
# Create a session and manage the transaction
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
try:
|
||||
workflow_service.delete_workflow(
|
||||
session=session, workflow_id=workflow_id, tenant_id=app_model.tenant_id
|
||||
session=session,
|
||||
workflow_id=workflow_id,
|
||||
tenant_id=app_model.tenant_id,
|
||||
workflow_ref=workflow_ref,
|
||||
)
|
||||
except WorkflowInUseError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
@ -58,8 +58,9 @@ from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.helper import dump_response, escape_like_pattern
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import ChildChunk, DocumentSegment
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
from services.dataset_ref_service import DatasetRefService, SegmentRef
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
|
||||
@ -162,6 +163,21 @@ register_response_schema_models(
|
||||
)
|
||||
|
||||
|
||||
def _get_segment_for_document(
|
||||
dataset: Dataset, document: Document, segment_id: str
|
||||
) -> tuple[SegmentRef, DocumentSegment]:
|
||||
dataset_ref = DatasetRefService.create_dataset_ref(dataset)
|
||||
document_ref = DatasetRefService.create_document_ref(dataset_ref, document)
|
||||
if document_ref is None:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
segment_ref = DatasetRefService.create_segment_ref(document_ref, segment_id)
|
||||
segment = SegmentService.get_segment_by_ref(segment_ref)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
return segment_ref, segment
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
||||
class DatasetDocumentSegmentListApi(Resource):
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
|
||||
@ -465,6 +481,13 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user, db.session)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
# check embedding model setting
|
||||
try:
|
||||
@ -481,22 +504,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
segment_id_str = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user, db.session)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
_, segment = _get_segment_for_document(dataset, document, segment_id_str)
|
||||
# validate args
|
||||
payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
payload_dict = payload.model_dump(exclude_none=True)
|
||||
@ -537,15 +546,6 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id_str = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
@ -553,6 +553,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user, db.session)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
segment_id_str = str(segment_id)
|
||||
_, segment = _get_segment_for_document(dataset, document, segment_id_str)
|
||||
SegmentService.delete_segment(segment, document, dataset)
|
||||
return "", 204
|
||||
|
||||
@ -659,17 +661,12 @@ class ChildChunkAddApi(Resource):
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id_str = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user, db.session)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# check embedding model setting
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
try:
|
||||
@ -686,10 +683,8 @@ class ChildChunkAddApi(Resource):
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user, db.session)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
segment_id_str = str(segment_id)
|
||||
_, segment = _get_segment_for_document(dataset, document, segment_id_str)
|
||||
# validate args
|
||||
try:
|
||||
payload = ChildChunkCreatePayload.model_validate(console_ns.payload or {})
|
||||
@ -719,15 +714,8 @@ class ChildChunkAddApi(Resource):
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id_str = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
_get_segment_for_document(dataset, document, segment_id_str)
|
||||
args = query_params_from_request(ChildChunkListQuery, use_defaults_for_malformed_ints=True)
|
||||
|
||||
page = args.page
|
||||
@ -776,15 +764,6 @@ class ChildChunkAddApi(Resource):
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id_str = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
@ -792,6 +771,8 @@ class ChildChunkAddApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user, db.session)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
segment_id_str = str(segment_id)
|
||||
_, segment = _get_segment_for_document(dataset, document, segment_id_str)
|
||||
# validate args
|
||||
payload = ChildChunkBatchUpdatePayload.model_validate(console_ns.payload or {})
|
||||
try:
|
||||
@ -835,29 +816,6 @@ class ChildChunkUpdateApi(Resource):
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id_str = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
# check child chunk
|
||||
child_chunk_id_str = str(child_chunk_id)
|
||||
child_chunk = db.session.scalar(
|
||||
select(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == child_chunk_id_str,
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id_str,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if not child_chunk:
|
||||
raise NotFound("Child chunk not found.")
|
||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
@ -865,6 +823,12 @@ class ChildChunkUpdateApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user, db.session)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
segment_id_str = str(segment_id)
|
||||
segment_ref, _ = _get_segment_for_document(dataset, document, segment_id_str)
|
||||
child_chunk_id_str = str(child_chunk_id)
|
||||
child_chunk = SegmentService.get_child_chunk_by_segment_ref(child_chunk_id_str, segment_ref)
|
||||
if not child_chunk:
|
||||
raise NotFound("Child chunk not found.")
|
||||
try:
|
||||
SegmentService.delete_child_chunk(child_chunk, dataset)
|
||||
except ChildChunkDeleteIndexServiceError as e:
|
||||
@ -903,29 +867,6 @@ class ChildChunkUpdateApi(Resource):
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id_str = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
# check child chunk
|
||||
child_chunk_id_str = str(child_chunk_id)
|
||||
child_chunk = db.session.scalar(
|
||||
select(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == child_chunk_id_str,
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id_str,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if not child_chunk:
|
||||
raise NotFound("Child chunk not found.")
|
||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
@ -933,6 +874,12 @@ class ChildChunkUpdateApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user, db.session)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
segment_id_str = str(segment_id)
|
||||
segment_ref, segment = _get_segment_for_document(dataset, document, segment_id_str)
|
||||
child_chunk_id_str = str(child_chunk_id)
|
||||
child_chunk = SegmentService.get_child_chunk_by_segment_ref(child_chunk_id_str, segment_ref)
|
||||
if not child_chunk:
|
||||
raise NotFound("Child chunk not found.")
|
||||
# validate args
|
||||
try:
|
||||
payload = ChildChunkUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
@ -64,6 +64,7 @@ from services.rag_pipeline.pipeline_generate_service import PipelineGenerateServ
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService
|
||||
from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTransformService
|
||||
from services.workflow_ref_service import WorkflowRefService
|
||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -738,6 +739,7 @@ class RagPipelineByIdApi(Resource):
|
||||
return {"message": "No valid fields to update"}, 400
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow_ref = WorkflowRefService.create_pipeline_workflow_ref(pipeline, workflow_id)
|
||||
|
||||
# Create a session and manage the transaction
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
@ -747,6 +749,7 @@ class RagPipelineByIdApi(Resource):
|
||||
tenant_id=pipeline.tenant_id,
|
||||
account_id=current_user.id,
|
||||
data=update_data,
|
||||
workflow_ref=workflow_ref,
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
@ -769,6 +772,7 @@ class RagPipelineByIdApi(Resource):
|
||||
abort(400, description=f"Cannot delete workflow that is currently in use by pipeline '{pipeline.id}'")
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflow_ref = WorkflowRefService.create_pipeline_workflow_ref(pipeline, workflow_id)
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
try:
|
||||
@ -776,6 +780,7 @@ class RagPipelineByIdApi(Resource):
|
||||
session=session,
|
||||
workflow_id=workflow_id,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
workflow_ref=workflow_ref,
|
||||
)
|
||||
except WorkflowInUseError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
@ -425,6 +425,7 @@ class TrialChatTextApi(TrialAppResource):
|
||||
text=text,
|
||||
voice=voice,
|
||||
message_id=message_id,
|
||||
message_account_id=current_user.id,
|
||||
)
|
||||
RecommendedAppService.add_trial_app_record(db.session, app_id, user_id)
|
||||
return response
|
||||
|
||||
@ -117,7 +117,8 @@ def _enforce_snippet_tag_rbac_by_tag_id(tag_id: str) -> None:
|
||||
if not dify_config.RBAC_ENABLED:
|
||||
return
|
||||
|
||||
tag_type = db.session.scalar(select(Tag.type).where(Tag.id == tag_id).limit(1))
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
tag_type = db.session.scalar(select(Tag.type).where(Tag.id == tag_id, Tag.tenant_id == current_tenant_id).limit(1))
|
||||
_enforce_snippet_tag_rbac_if_needed(tag_type)
|
||||
|
||||
|
||||
|
||||
@ -184,6 +184,7 @@ class TextApi(Resource):
|
||||
voice=voice,
|
||||
end_user=end_user.external_user_id,
|
||||
message_id=message_id,
|
||||
message_end_user_id=end_user.id,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@ -939,9 +939,11 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
|
||||
payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
tag_id = payload.tag_id
|
||||
tag = TagService.update_tags(UpdateTagServicePayload(name=payload.name), tag_id, db.session)
|
||||
tag = TagService.update_tags(
|
||||
UpdateTagServicePayload(name=payload.name), tag_id, db.session, tag_type=TagType.KNOWLEDGE
|
||||
)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id, db.session)
|
||||
binding_count = TagService.get_tag_binding_count(tag_id, db.session, tag_type=TagType.KNOWLEDGE)
|
||||
|
||||
response = dump_response(
|
||||
KnowledgeTagResponse,
|
||||
@ -971,7 +973,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
def delete(self, _):
|
||||
"""Delete a knowledge type tag."""
|
||||
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag(payload.tag_id, db.session)
|
||||
TagService.delete_tag(payload.tag_id, db.session, tag_type=TagType.KNOWLEDGE)
|
||||
|
||||
return "", 204
|
||||
|
||||
|
||||
@ -37,7 +37,8 @@ from fields.segment_fields import (
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.helper import dump_response
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from services.dataset_ref_service import DatasetRefService, SegmentRef
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
|
||||
@ -127,6 +128,21 @@ register_response_schema_models(
|
||||
)
|
||||
|
||||
|
||||
def _get_segment_for_document(
|
||||
dataset: Dataset, document: Document, segment_id: str
|
||||
) -> tuple[SegmentRef, DocumentSegment]:
|
||||
dataset_ref = DatasetRefService.create_dataset_ref(dataset)
|
||||
document_ref = DatasetRefService.create_document_ref(dataset_ref, document)
|
||||
if document_ref is None:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
segment_ref = DatasetRefService.create_segment_ref(document_ref, segment_id)
|
||||
segment = SegmentService.get_segment_by_ref(segment_ref)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
return segment_ref, segment
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
||||
class SegmentApi(DatasetApiResource):
|
||||
"""Resource for segments."""
|
||||
@ -337,7 +353,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
# check dataset
|
||||
dataset = db.session.scalar(
|
||||
@ -353,10 +369,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
segment_id_str = str(segment_id)
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
_, segment = _get_segment_for_document(dataset, document, segment_id_str)
|
||||
SegmentService.delete_segment(segment, document, dataset)
|
||||
return "", 204
|
||||
|
||||
@ -415,10 +428,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
segment_id_str = str(segment_id)
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
_, segment = _get_segment_for_document(dataset, document, segment_id_str)
|
||||
|
||||
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
@ -457,7 +467,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
service_api_ns.models[SegmentDetailResponse.__name__],
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
# check dataset
|
||||
dataset = db.session.scalar(
|
||||
@ -473,10 +483,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
segment_id_str = str(segment_id)
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
_, segment = _get_segment_for_document(dataset, document, segment_id_str)
|
||||
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id_str)
|
||||
response = {
|
||||
@ -538,10 +545,7 @@ class ChildChunkApi(DatasetApiResource):
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
segment_id_str = str(segment_id)
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
_, segment = _get_segment_for_document(dataset, document, segment_id_str)
|
||||
|
||||
# check embedding model setting
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
@ -595,7 +599,7 @@ class ChildChunkApi(DatasetApiResource):
|
||||
service_api_ns.models[ChildChunkListResponse.__name__],
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
current_account_with_tenant()
|
||||
"""Get child chunks."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
# check dataset
|
||||
@ -612,10 +616,7 @@ class ChildChunkApi(DatasetApiResource):
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
segment_id_str = str(segment_id)
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
_get_segment_for_document(dataset, document, segment_id_str)
|
||||
|
||||
args = query_params_from_request(ChildChunkListQuery, use_defaults_for_malformed_ints=True)
|
||||
|
||||
@ -665,7 +666,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
current_account_with_tenant()
|
||||
"""Delete child chunk."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
# check dataset
|
||||
@ -682,27 +683,14 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
segment_id_str = str(segment_id)
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
# validate segment belongs to the specified document
|
||||
if segment.document_id != document_id_str:
|
||||
raise NotFound("Document not found.")
|
||||
segment_ref, _ = _get_segment_for_document(dataset, document, segment_id_str)
|
||||
|
||||
child_chunk_id_str = str(child_chunk_id)
|
||||
# check child chunk
|
||||
child_chunk = SegmentService.get_child_chunk_by_id(
|
||||
child_chunk_id=child_chunk_id_str, tenant_id=current_tenant_id
|
||||
)
|
||||
child_chunk = SegmentService.get_child_chunk_by_segment_ref(child_chunk_id_str, segment_ref)
|
||||
if not child_chunk:
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
# validate child chunk belongs to the specified segment
|
||||
if child_chunk.segment_id != segment.id:
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
try:
|
||||
SegmentService.delete_child_chunk(child_chunk, dataset)
|
||||
except ChildChunkDeleteIndexServiceError as e:
|
||||
@ -739,7 +727,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
current_account_with_tenant()
|
||||
"""Update child chunk."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
# check dataset
|
||||
@ -756,27 +744,14 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
segment_id_str = str(segment_id)
|
||||
# get segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
# validate segment belongs to the specified document
|
||||
if segment.document_id != document_id_str:
|
||||
raise NotFound("Segment not found.")
|
||||
segment_ref, segment = _get_segment_for_document(dataset, document, segment_id_str)
|
||||
|
||||
child_chunk_id_str = str(child_chunk_id)
|
||||
# get child chunk
|
||||
child_chunk = SegmentService.get_child_chunk_by_id(
|
||||
child_chunk_id=child_chunk_id_str, tenant_id=current_tenant_id
|
||||
)
|
||||
child_chunk = SegmentService.get_child_chunk_by_segment_ref(child_chunk_id_str, segment_ref)
|
||||
if not child_chunk:
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
# validate child chunk belongs to the specified segment
|
||||
if child_chunk.segment_id != segment.id:
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
# validate args
|
||||
payload = ChildChunkUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
|
||||
@ -137,6 +137,7 @@ class TextApi(WebApiResource):
|
||||
voice=voice,
|
||||
end_user=end_user.external_user_id,
|
||||
message_id=message_id,
|
||||
message_end_user_id=end_user.id,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@ -498,7 +498,11 @@ class LLMGenerator:
|
||||
ideal_output: str | None,
|
||||
):
|
||||
last_run: Message | None = db.session.scalar(
|
||||
select(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).limit(1)
|
||||
select(Message)
|
||||
.join(App, App.id == Message.app_id)
|
||||
.where(Message.app_id == flow_id, App.tenant_id == tenant_id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
if not last_run:
|
||||
return LLMGenerator.__instruction_modify_common(
|
||||
@ -540,7 +544,7 @@ class LLMGenerator:
|
||||
):
|
||||
session = db.session()
|
||||
|
||||
app: App | None = session.scalar(select(App).where(App.id == flow_id).limit(1))
|
||||
app: App | None = session.scalar(select(App).where(App.id == flow_id, App.tenant_id == tenant_id).limit(1))
|
||||
if not app:
|
||||
raise ValueError("App not found.")
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app)
|
||||
|
||||
@ -45,52 +45,34 @@ def upgrade():
|
||||
# PostgreSQL 18's `uuidv7` function. This capability is rarely needed in practice, as IDs can be
|
||||
# generated and controlled within the application layer.
|
||||
conn = op.get_bind()
|
||||
|
||||
|
||||
if _is_pg(conn):
|
||||
# PostgreSQL: Create uuidv7 functions.
|
||||
# PostgreSQL 18 ships a native pg_catalog.uuidv7(), so only create our own
|
||||
# implementation when the server does not already provide one. Otherwise the
|
||||
# CREATE FUNCTION below and the unqualified COMMENT statement collide with the
|
||||
# built-in and the migration fails.
|
||||
#
|
||||
# The existence check is done server-side via a DO block rather than
|
||||
# conn.execute().scalar() because the latter returns None in offline
|
||||
# migration mode (no real database connection), causing an AttributeError.
|
||||
# PostgreSQL: Create uuidv7 functions
|
||||
op.execute(sa.text(r"""
|
||||
DO $do$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_proc p
|
||||
JOIN pg_namespace n ON p.pronamespace = n.oid
|
||||
WHERE p.proname = 'uuidv7' AND n.nspname = 'pg_catalog'
|
||||
) THEN
|
||||
/* Main function to generate a uuidv7 value with millisecond precision */
|
||||
CREATE FUNCTION public.uuidv7() RETURNS uuid
|
||||
AS
|
||||
$func$
|
||||
-- Replace the first 48 bits of a uuidv4 with the current
|
||||
-- number of milliseconds since 1970-01-01 UTC
|
||||
-- and set the "ver" field to 7 by setting additional bits
|
||||
SELECT encode(
|
||||
/* Main function to generate a uuidv7 value with millisecond precision */
|
||||
CREATE FUNCTION uuidv7() RETURNS uuid
|
||||
AS
|
||||
$$
|
||||
-- Replace the first 48 bits of a uuidv4 with the current
|
||||
-- number of milliseconds since 1970-01-01 UTC
|
||||
-- and set the "ver" field to 7 by setting additional bits
|
||||
SELECT encode(
|
||||
set_bit(
|
||||
set_bit(
|
||||
set_bit(
|
||||
overlay(uuid_send(gen_random_uuid()) placing
|
||||
substring(int8send((extract(epoch from clock_timestamp()) * 1000)::bigint) from
|
||||
3)
|
||||
from 1 for 6),
|
||||
52, 1),
|
||||
53, 1), 'hex')::uuid;
|
||||
$func$ LANGUAGE SQL VOLATILE PARALLEL SAFE;
|
||||
overlay(uuid_send(gen_random_uuid()) placing
|
||||
substring(int8send((extract(epoch from clock_timestamp()) * 1000)::bigint) from
|
||||
3)
|
||||
from 1 for 6),
|
||||
52, 1),
|
||||
53, 1), 'hex')::uuid;
|
||||
$$ LANGUAGE SQL VOLATILE PARALLEL SAFE;
|
||||
|
||||
COMMENT ON FUNCTION public.uuidv7 IS
|
||||
'Generate a uuid-v7 value with a 48-bit timestamp (millisecond precision) and 74 bits of randomness';
|
||||
END IF;
|
||||
END
|
||||
$do$;
|
||||
COMMENT ON FUNCTION uuidv7 IS
|
||||
'Generate a uuid-v7 value with a 48-bit timestamp (millisecond precision) and 74 bits of randomness';
|
||||
"""))
|
||||
|
||||
op.execute(sa.text(r"""
|
||||
CREATE FUNCTION public.uuidv7_boundary(timestamptz) RETURNS uuid
|
||||
CREATE FUNCTION uuidv7_boundary(timestamptz) RETURNS uuid
|
||||
AS
|
||||
$$
|
||||
/* uuid fields: version=0b0111, variant=0b10 */
|
||||
@ -101,7 +83,7 @@ SELECT encode(
|
||||
'hex')::uuid;
|
||||
$$ LANGUAGE SQL STABLE STRICT PARALLEL SAFE;
|
||||
|
||||
COMMENT ON FUNCTION public.uuidv7_boundary(timestamptz) IS
|
||||
COMMENT ON FUNCTION uuidv7_boundary(timestamptz) IS
|
||||
'Generate a non-random uuidv7 with the given timestamp (first 48 bits) and all random bits to 0. As the smallest possible uuidv7 for that timestamp, it may be used as a boundary for partitions.';
|
||||
"""
|
||||
))
|
||||
@ -113,10 +95,7 @@ def downgrade():
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
# IF EXISTS keeps the downgrade a no-op on PostgreSQL 18, where the native
|
||||
# pg_catalog.uuidv7() was kept and no public.uuidv7() was created. Scoping the
|
||||
# drop to the public schema avoids touching the built-in.
|
||||
op.execute(sa.text("DROP FUNCTION IF EXISTS public.uuidv7()"))
|
||||
op.execute(sa.text("DROP FUNCTION IF EXISTS public.uuidv7_boundary(timestamptz)"))
|
||||
op.execute(sa.text("DROP FUNCTION uuidv7"))
|
||||
op.execute(sa.text("DROP FUNCTION uuidv7_boundary"))
|
||||
else:
|
||||
pass
|
||||
|
||||
@ -14,6 +14,7 @@ from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation
|
||||
from services.app_ref_service import AnnotationRef, AppRefService
|
||||
from services.feature_service import FeatureService
|
||||
from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task
|
||||
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
|
||||
@ -88,6 +89,17 @@ class UpdateAnnotationSettingArgs(TypedDict):
|
||||
|
||||
|
||||
class AppAnnotationService:
|
||||
@staticmethod
|
||||
def _get_annotation_by_ref(annotation_ref: AnnotationRef, session: scoped_session) -> MessageAnnotation | None:
|
||||
return session.scalar(
|
||||
select(MessageAnnotation)
|
||||
.where(
|
||||
MessageAnnotation.id == annotation_ref.annotation_id,
|
||||
MessageAnnotation.app_id == annotation_ref.app_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def up_insert_app_annotation_from_message(cls, args: UpsertAnnotationArgs, app_id: str) -> MessageAnnotation:
|
||||
# get app info
|
||||
@ -313,7 +325,9 @@ class AppAnnotationService:
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
annotation = session.get(MessageAnnotation, annotation_id)
|
||||
app_ref = AppRefService.create_app_ref(app)
|
||||
annotation_ref = AppRefService.create_annotation_ref(app_ref, annotation_id)
|
||||
annotation = cls._get_annotation_by_ref(annotation_ref, session)
|
||||
|
||||
if not annotation:
|
||||
raise NotFound("Annotation not found")
|
||||
@ -357,7 +371,9 @@ class AppAnnotationService:
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
annotation = session.get(MessageAnnotation, annotation_id)
|
||||
app_ref = AppRefService.create_app_ref(app)
|
||||
annotation_ref = AppRefService.create_annotation_ref(app_ref, annotation_id)
|
||||
annotation = cls._get_annotation_by_ref(annotation_ref, session)
|
||||
|
||||
if not annotation:
|
||||
raise NotFound("Annotation not found")
|
||||
@ -393,11 +409,13 @@ class AppAnnotationService:
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
app_ref = AppRefService.create_app_ref(app)
|
||||
|
||||
# Fetch annotations and their settings in a single query
|
||||
annotations_to_delete = db.session.execute(
|
||||
select(MessageAnnotation, AppAnnotationSetting)
|
||||
.outerjoin(AppAnnotationSetting, MessageAnnotation.app_id == AppAnnotationSetting.app_id)
|
||||
.where(MessageAnnotation.id.in_(annotation_ids))
|
||||
.where(MessageAnnotation.id.in_(annotation_ids), MessageAnnotation.app_id == app_ref.app_id)
|
||||
).all()
|
||||
|
||||
if not annotations_to_delete:
|
||||
@ -420,7 +438,10 @@ class AppAnnotationService:
|
||||
|
||||
# Step 4: Bulk delete annotations in a single query
|
||||
delete_result = db.session.execute(
|
||||
delete(MessageAnnotation).where(MessageAnnotation.id.in_(annotation_ids_to_delete))
|
||||
delete(MessageAnnotation).where(
|
||||
MessageAnnotation.id.in_(annotation_ids_to_delete),
|
||||
MessageAnnotation.app_id == app_ref.app_id,
|
||||
)
|
||||
)
|
||||
deleted_count = getattr(delete_result, "rowcount", 0)
|
||||
|
||||
@ -572,7 +593,9 @@ class AppAnnotationService:
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
annotation = db.session.get(MessageAnnotation, annotation_id)
|
||||
app_ref = AppRefService.create_app_ref(app)
|
||||
annotation_ref = AppRefService.create_annotation_ref(app_ref, annotation_id)
|
||||
annotation = cls._get_annotation_by_ref(annotation_ref, db.session)
|
||||
|
||||
if not annotation:
|
||||
raise NotFound("Annotation not found")
|
||||
|
||||
100
api/services/app_ref_service.py
Normal file
100
api/services/app_ref_service.py
Normal file
@ -0,0 +1,100 @@
|
||||
"""Typed resource references for app ownership chains."""
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
from models.model import App
|
||||
|
||||
_APP_REF_CTOR_TOKEN = object()
|
||||
|
||||
|
||||
class _AppRefBase(NamedTuple):
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
ctor_token: object
|
||||
|
||||
|
||||
class AppRef(_AppRefBase):
|
||||
"""Tenant-scoped app reference with token-gated construction."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, tenant_id: str, app_id: str, ctor_token: object) -> "AppRef":
|
||||
if ctor_token is not _APP_REF_CTOR_TOKEN:
|
||||
raise ValueError("AppRef must be created by AppRefService.")
|
||||
return super().__new__(cls, tenant_id, app_id, ctor_token)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"AppRef(tenant_id={self.tenant_id!r}, app_id={self.app_id!r})"
|
||||
|
||||
|
||||
class MessageRef(NamedTuple):
|
||||
"""Message reference bound to a trusted app reference."""
|
||||
|
||||
app: AppRef
|
||||
message_id: str
|
||||
end_user_id: str | None = None
|
||||
account_id: str | None = None
|
||||
|
||||
@property
|
||||
def tenant_id(self) -> str:
|
||||
return self.app.tenant_id
|
||||
|
||||
@property
|
||||
def app_id(self) -> str:
|
||||
return self.app.app_id
|
||||
|
||||
|
||||
class AnnotationRef(NamedTuple):
|
||||
"""Annotation reference bound to a trusted app reference."""
|
||||
|
||||
app: AppRef
|
||||
annotation_id: str
|
||||
|
||||
@property
|
||||
def tenant_id(self) -> str:
|
||||
return self.app.tenant_id
|
||||
|
||||
@property
|
||||
def app_id(self) -> str:
|
||||
return self.app.app_id
|
||||
|
||||
|
||||
class AppMCPServerRef(NamedTuple):
|
||||
"""MCP server reference bound to a trusted app reference."""
|
||||
|
||||
app: AppRef
|
||||
server_id: str
|
||||
|
||||
@property
|
||||
def tenant_id(self) -> str:
|
||||
return self.app.tenant_id
|
||||
|
||||
@property
|
||||
def app_id(self) -> str:
|
||||
return self.app.app_id
|
||||
|
||||
|
||||
class AppRefService:
|
||||
"""Factory for trusted app and child resource refs."""
|
||||
|
||||
@staticmethod
|
||||
def create_app_ref(app: App) -> AppRef:
|
||||
return AppRef(app.tenant_id, app.id, _APP_REF_CTOR_TOKEN)
|
||||
|
||||
@staticmethod
|
||||
def create_message_ref(
|
||||
app_ref: AppRef,
|
||||
message_id: str,
|
||||
*,
|
||||
end_user_id: str | None = None,
|
||||
account_id: str | None = None,
|
||||
) -> MessageRef:
|
||||
return MessageRef(app=app_ref, message_id=message_id, end_user_id=end_user_id, account_id=account_id)
|
||||
|
||||
@staticmethod
|
||||
def create_annotation_ref(app_ref: AppRef, annotation_id: str) -> AnnotationRef:
|
||||
return AnnotationRef(app=app_ref, annotation_id=annotation_id)
|
||||
|
||||
@staticmethod
|
||||
def create_mcp_server_ref(app_ref: AppRef, server_id: str) -> AppMCPServerRef:
|
||||
return AppMCPServerRef(app=app_ref, server_id=server_id)
|
||||
@ -5,6 +5,7 @@ from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
@ -13,6 +14,7 @@ from core.model_manager import ModelManager
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from models.enums import MessageStatus
|
||||
from models.model import App, AppMode, Message
|
||||
from services.app_ref_service import AppRefService, MessageRef
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
@ -29,6 +31,15 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AudioService:
|
||||
@staticmethod
|
||||
def _get_message_by_ref(session: Session | scoped_session, message_ref: MessageRef) -> Message | None:
|
||||
stmt = select(Message).where(Message.id == message_ref.message_id, Message.app_id == message_ref.app_id)
|
||||
if message_ref.end_user_id is not None:
|
||||
stmt = stmt.where(Message.from_end_user_id == message_ref.end_user_id)
|
||||
if message_ref.account_id is not None:
|
||||
stmt = stmt.where(Message.from_account_id == message_ref.account_id)
|
||||
return session.scalar(stmt.limit(1))
|
||||
|
||||
@classmethod
|
||||
def transcript_asr(cls, app_model: App, file: FileStorage | None, end_user: str | None = None):
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
@ -83,6 +94,8 @@ class AudioService:
|
||||
voice: str | None = None,
|
||||
end_user: str | None = None,
|
||||
message_id: str | None = None,
|
||||
message_end_user_id: str | None = None,
|
||||
message_account_id: str | None = None,
|
||||
is_draft: bool = False,
|
||||
):
|
||||
def invoke_tts(text_content: str, app_model: App, voice: str | None = None, is_draft: bool = False):
|
||||
@ -134,7 +147,14 @@ class AudioService:
|
||||
uuid.UUID(message_id)
|
||||
except ValueError:
|
||||
return None
|
||||
message = session.get(Message, message_id)
|
||||
app_ref = AppRefService.create_app_ref(app_model)
|
||||
message_ref = AppRefService.create_message_ref(
|
||||
app_ref,
|
||||
message_id,
|
||||
end_user_id=message_end_user_id,
|
||||
account_id=message_account_id,
|
||||
)
|
||||
message = cls._get_message_by_ref(session, message_ref)
|
||||
if message is None:
|
||||
return None
|
||||
if message.answer == "" and message.status in {MessageStatus.NORMAL, MessageStatus.PAUSED}:
|
||||
|
||||
85
api/services/dataset_ref_service.py
Normal file
85
api/services/dataset_ref_service.py
Normal file
@ -0,0 +1,85 @@
|
||||
"""Typed resource references for dataset ownership chains.
|
||||
|
||||
Controllers and other trust-boundary code should build these refs after
|
||||
resolving and authorizing the outer resource. Downstream service queries can
|
||||
then require the full tenant -> dataset -> document -> segment chain instead of
|
||||
receiving loosely related raw ids.
|
||||
"""
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
from models.dataset import Dataset, Document
|
||||
|
||||
_DATASET_REF_CTOR_TOKEN = object()
|
||||
|
||||
|
||||
class _DatasetRefBase(NamedTuple):
|
||||
tenant_id: str
|
||||
dataset_id: str
|
||||
ctor_token: object
|
||||
|
||||
|
||||
class DatasetRef(_DatasetRefBase):
|
||||
"""Tenant-scoped dataset reference with token-gated construction."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, tenant_id: str, dataset_id: str, ctor_token: object) -> "DatasetRef":
|
||||
if ctor_token is not _DATASET_REF_CTOR_TOKEN:
|
||||
raise ValueError("DatasetRef must be created by DatasetRefService.")
|
||||
return super().__new__(cls, tenant_id, dataset_id, ctor_token)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"DatasetRef(tenant_id={self.tenant_id!r}, dataset_id={self.dataset_id!r})"
|
||||
|
||||
|
||||
class DocumentRef(NamedTuple):
|
||||
"""Document reference bound to a trusted dataset reference."""
|
||||
|
||||
dataset: DatasetRef
|
||||
document_id: str
|
||||
|
||||
@property
|
||||
def tenant_id(self) -> str:
|
||||
return self.dataset.tenant_id
|
||||
|
||||
@property
|
||||
def dataset_id(self) -> str:
|
||||
return self.dataset.dataset_id
|
||||
|
||||
|
||||
class SegmentRef(NamedTuple):
|
||||
"""Segment reference bound to a trusted document reference."""
|
||||
|
||||
document: DocumentRef
|
||||
segment_id: str
|
||||
|
||||
@property
|
||||
def tenant_id(self) -> str:
|
||||
return self.document.tenant_id
|
||||
|
||||
@property
|
||||
def dataset_id(self) -> str:
|
||||
return self.document.dataset_id
|
||||
|
||||
@property
|
||||
def document_id(self) -> str:
|
||||
return self.document.document_id
|
||||
|
||||
|
||||
class DatasetRefService:
|
||||
"""Factory for trusted dataset, document, and segment refs."""
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_ref(dataset: Dataset) -> DatasetRef:
|
||||
return DatasetRef(dataset.tenant_id, dataset.id, _DATASET_REF_CTOR_TOKEN)
|
||||
|
||||
@staticmethod
|
||||
def create_document_ref(dataset_ref: DatasetRef, document: Document) -> DocumentRef | None:
|
||||
if document.tenant_id != dataset_ref.tenant_id or document.dataset_id != dataset_ref.dataset_id:
|
||||
return None
|
||||
return DocumentRef(dataset=dataset_ref, document_id=document.id)
|
||||
|
||||
@staticmethod
|
||||
def create_segment_ref(document_ref: DocumentRef, segment_id: str) -> SegmentRef:
|
||||
return SegmentRef(document=document_ref, segment_id=segment_id)
|
||||
@ -65,6 +65,7 @@ from models.model import UploadFile
|
||||
from models.provider_ids import ModelProviderID
|
||||
from models.source import DataSourceOauthBinding
|
||||
from models.workflow import Workflow
|
||||
from services.dataset_ref_service import DatasetRefService, SegmentRef
|
||||
from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy
|
||||
from services.document_indexing_proxy.duplicate_document_indexing_task_proxy import DuplicateDocumentIndexingTaskProxy
|
||||
from services.enterprise import rbac_service as enterprise_rbac_service
|
||||
@ -1926,7 +1927,15 @@ class DocumentService:
|
||||
# Check if document_ids is not empty to avoid WHERE false condition
|
||||
if not document_ids or len(document_ids) == 0:
|
||||
return
|
||||
documents = db.session.scalars(select(Document).where(Document.id.in_(document_ids))).all()
|
||||
dataset_ref = DatasetRefService.create_dataset_ref(dataset)
|
||||
documents = db.session.scalars(
|
||||
select(Document).where(
|
||||
Document.id.in_(document_ids),
|
||||
Document.tenant_id == dataset_ref.tenant_id,
|
||||
Document.dataset_id == dataset_ref.dataset_id,
|
||||
)
|
||||
).all()
|
||||
deleted_document_ids = [document.id for document in documents]
|
||||
file_ids = [
|
||||
document.data_source_info_dict.get("upload_file_id", "")
|
||||
for document in documents
|
||||
@ -1941,8 +1950,8 @@ class DocumentService:
|
||||
|
||||
# Dispatch cleanup task after commit to avoid lock contention
|
||||
# Task cleans up segments, files, and vector indexes
|
||||
if dataset.doc_form is not None:
|
||||
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
|
||||
if deleted_document_ids and dataset.doc_form is not None:
|
||||
batch_clean_document_task.delay(deleted_document_ids, dataset.id, dataset.doc_form, file_ids)
|
||||
|
||||
@staticmethod
|
||||
def rename_document(dataset_id: str, document_id: str, name: str) -> Document:
|
||||
@ -4028,6 +4037,22 @@ class SegmentService:
|
||||
)
|
||||
return result if isinstance(result, ChildChunk) else None
|
||||
|
||||
@classmethod
|
||||
def get_child_chunk_by_segment_ref(cls, child_chunk_id: str, segment_ref: SegmentRef) -> ChildChunk | None:
|
||||
"""Get a child chunk through the full tenant/dataset/document/segment chain."""
|
||||
result = db.session.scalar(
|
||||
select(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == child_chunk_id,
|
||||
ChildChunk.tenant_id == segment_ref.tenant_id,
|
||||
ChildChunk.dataset_id == segment_ref.dataset_id,
|
||||
ChildChunk.document_id == segment_ref.document_id,
|
||||
ChildChunk.segment_id == segment_ref.segment_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
return result if isinstance(result, ChildChunk) else None
|
||||
|
||||
@classmethod
|
||||
def get_segments(
|
||||
cls,
|
||||
@ -4066,6 +4091,21 @@ class SegmentService:
|
||||
)
|
||||
return result if isinstance(result, DocumentSegment) else None
|
||||
|
||||
@classmethod
|
||||
def get_segment_by_ref(cls, segment_ref: SegmentRef) -> DocumentSegment | None:
|
||||
"""Get a segment through the full tenant/dataset/document ownership chain."""
|
||||
result = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.id == segment_ref.segment_id,
|
||||
DocumentSegment.tenant_id == segment_ref.tenant_id,
|
||||
DocumentSegment.dataset_id == segment_ref.dataset_id,
|
||||
DocumentSegment.document_id == segment_ref.document_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
return result if isinstance(result, DocumentSegment) else None
|
||||
|
||||
@classmethod
|
||||
def get_segments_by_document_and_dataset(
|
||||
cls,
|
||||
|
||||
@ -82,6 +82,7 @@ from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError,
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
from services.workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader
|
||||
from services.workflow_ref_service import WorkflowRef
|
||||
from services.workflow_restore import apply_published_workflow_snapshot_to_draft
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -984,8 +985,21 @@ class RagPipelineService:
|
||||
if invoke_from:
|
||||
if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE:
|
||||
document_id = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID)
|
||||
if document_id:
|
||||
document = db.session.get(Document, document_id.value)
|
||||
dataset_id = get_system_segment(variable_pool, SystemVariableKey.DATASET_ID)
|
||||
pipeline_id = get_system_segment(variable_pool, SystemVariableKey.APP_ID)
|
||||
if document_id and dataset_id and pipeline_id:
|
||||
document = db.session.scalar(
|
||||
select(Document)
|
||||
.join(Dataset, Dataset.id == Document.dataset_id)
|
||||
.where(
|
||||
Document.id == document_id.value,
|
||||
Document.tenant_id == tenant_id,
|
||||
Document.dataset_id == dataset_id.value,
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.pipeline_id == pipeline_id.value,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
document.error = error
|
||||
@ -995,7 +1009,14 @@ class RagPipelineService:
|
||||
return workflow_node_execution
|
||||
|
||||
def update_workflow(
|
||||
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict[str, Any]
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
workflow_id: str,
|
||||
tenant_id: str,
|
||||
account_id: str,
|
||||
data: dict[str, Any],
|
||||
workflow_ref: WorkflowRef | None = None,
|
||||
) -> Workflow | None:
|
||||
"""
|
||||
Update workflow attributes
|
||||
@ -1005,9 +1026,14 @@ class RagPipelineService:
|
||||
:param tenant_id: Tenant ID
|
||||
:param account_id: Account ID (for permission check)
|
||||
:param data: Dictionary containing fields to update
|
||||
:param workflow_ref: Optional trusted owner-bound workflow reference
|
||||
:return: Updated workflow or None if not found
|
||||
"""
|
||||
stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
|
||||
lookup_workflow_id = workflow_ref.workflow_id if workflow_ref else workflow_id
|
||||
lookup_tenant_id = workflow_ref.tenant_id if workflow_ref else tenant_id
|
||||
stmt = select(Workflow).where(Workflow.id == lookup_workflow_id, Workflow.tenant_id == lookup_tenant_id)
|
||||
if workflow_ref is not None:
|
||||
stmt = stmt.where(Workflow.app_id == workflow_ref.app_id)
|
||||
workflow = session.scalar(stmt)
|
||||
|
||||
if not workflow:
|
||||
|
||||
@ -147,8 +147,14 @@ class TagService:
|
||||
return tag
|
||||
|
||||
@staticmethod
|
||||
def update_tags(payload: UpdateTagPayload, tag_id: str, session: scoped_session) -> Tag:
|
||||
tag = session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
|
||||
def update_tags(
|
||||
payload: UpdateTagPayload, tag_id: str, session: scoped_session, *, tag_type: TagType | None = None
|
||||
) -> Tag:
|
||||
current_tenant_id = current_user.current_tenant_id
|
||||
stmt = select(Tag).where(Tag.id == tag_id, Tag.tenant_id == current_tenant_id)
|
||||
if tag_type is not None:
|
||||
stmt = stmt.where(Tag.type == tag_type)
|
||||
tag = session.scalar(stmt.limit(1))
|
||||
if not tag:
|
||||
raise NotFound("Tag not found")
|
||||
if payload.name != tag.name:
|
||||
@ -169,18 +175,32 @@ class TagService:
|
||||
return tag
|
||||
|
||||
@staticmethod
|
||||
def get_tag_binding_count(tag_id: str, session: scoped_session) -> int:
|
||||
count = session.scalar(select(func.count(TagBinding.id)).where(TagBinding.tag_id == tag_id)) or 0
|
||||
def get_tag_binding_count(tag_id: str, session: scoped_session, *, tag_type: TagType | None = None) -> int:
|
||||
current_tenant_id = current_user.current_tenant_id
|
||||
stmt = (
|
||||
select(func.count(TagBinding.id))
|
||||
.join(Tag, Tag.id == TagBinding.tag_id)
|
||||
.where(TagBinding.tag_id == tag_id, Tag.tenant_id == current_tenant_id)
|
||||
)
|
||||
if tag_type is not None:
|
||||
stmt = stmt.where(Tag.type == tag_type)
|
||||
count = session.scalar(stmt) or 0
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
def delete_tag(tag_id: str, session: scoped_session):
|
||||
tag = session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
|
||||
def delete_tag(tag_id: str, session: scoped_session, *, tag_type: TagType | None = None):
|
||||
current_tenant_id = current_user.current_tenant_id
|
||||
stmt = select(Tag).where(Tag.id == tag_id, Tag.tenant_id == current_tenant_id)
|
||||
if tag_type is not None:
|
||||
stmt = stmt.where(Tag.type == tag_type)
|
||||
tag = session.scalar(stmt.limit(1))
|
||||
if not tag:
|
||||
raise NotFound("Tag not found")
|
||||
session.delete(tag)
|
||||
# delete tag binding
|
||||
tag_bindings = session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all()
|
||||
tag_bindings = session.scalars(
|
||||
select(TagBinding).where(TagBinding.tag_id == tag_id, TagBinding.tenant_id == current_tenant_id)
|
||||
).all()
|
||||
if tag_bindings:
|
||||
for tag_binding in tag_bindings:
|
||||
session.delete(tag_binding)
|
||||
|
||||
57
api/services/workflow_ref_service.py
Normal file
57
api/services/workflow_ref_service.py
Normal file
@ -0,0 +1,57 @@
|
||||
"""Typed resource references for workflow ownership chains."""
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
from models.dataset import Pipeline
|
||||
from models.model import App
|
||||
|
||||
_WORKFLOW_OWNER_REF_CTOR_TOKEN = object()
|
||||
|
||||
|
||||
class _WorkflowOwnerRefBase(NamedTuple):
|
||||
tenant_id: str
|
||||
owner_id: str
|
||||
ctor_token: object
|
||||
|
||||
|
||||
class WorkflowOwnerRef(_WorkflowOwnerRefBase):
|
||||
"""Tenant-scoped workflow owner reference with token-gated construction."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, tenant_id: str, owner_id: str, ctor_token: object) -> "WorkflowOwnerRef":
|
||||
if ctor_token is not _WORKFLOW_OWNER_REF_CTOR_TOKEN:
|
||||
raise ValueError("WorkflowOwnerRef must be created by WorkflowRefService.")
|
||||
return super().__new__(cls, tenant_id, owner_id, ctor_token)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"WorkflowOwnerRef(tenant_id={self.tenant_id!r}, owner_id={self.owner_id!r})"
|
||||
|
||||
|
||||
class WorkflowRef(NamedTuple):
|
||||
"""Workflow reference bound to a trusted owner reference."""
|
||||
|
||||
owner: WorkflowOwnerRef
|
||||
workflow_id: str
|
||||
|
||||
@property
|
||||
def tenant_id(self) -> str:
|
||||
return self.owner.tenant_id
|
||||
|
||||
@property
|
||||
def app_id(self) -> str:
|
||||
return self.owner.owner_id
|
||||
|
||||
|
||||
class WorkflowRefService:
|
||||
"""Factory for trusted app and RAG pipeline workflow refs."""
|
||||
|
||||
@staticmethod
|
||||
def create_app_workflow_ref(app: App, workflow_id: str) -> WorkflowRef:
|
||||
owner = WorkflowOwnerRef(app.tenant_id, app.id, _WORKFLOW_OWNER_REF_CTOR_TOKEN)
|
||||
return WorkflowRef(owner=owner, workflow_id=workflow_id)
|
||||
|
||||
@staticmethod
|
||||
def create_pipeline_workflow_ref(pipeline: Pipeline, workflow_id: str) -> WorkflowRef:
|
||||
owner = WorkflowOwnerRef(pipeline.tenant_id, pipeline.id, _WORKFLOW_OWNER_REF_CTOR_TOKEN)
|
||||
return WorkflowRef(owner=owner, workflow_id=workflow_id)
|
||||
@ -84,6 +84,7 @@ from services.errors.app import (
|
||||
)
|
||||
from services.human_input_service import HumanInputService
|
||||
from services.workflow.workflow_converter import WorkflowConverter
|
||||
from services.workflow_ref_service import WorkflowRef
|
||||
|
||||
from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
|
||||
from .human_input_delivery_test_service import (
|
||||
@ -1593,7 +1594,14 @@ class WorkflowService:
|
||||
raise ValueError(f"Invalid HumanInput node data: {str(e)}")
|
||||
|
||||
def update_workflow(
|
||||
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict[str, Any]
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
workflow_id: str,
|
||||
tenant_id: str,
|
||||
account_id: str,
|
||||
data: dict[str, Any],
|
||||
workflow_ref: WorkflowRef | None = None,
|
||||
) -> Workflow | None:
|
||||
"""
|
||||
Update workflow attributes
|
||||
@ -1603,9 +1611,14 @@ class WorkflowService:
|
||||
:param tenant_id: Tenant ID
|
||||
:param account_id: Account ID (for permission check)
|
||||
:param data: Dictionary containing fields to update
|
||||
:param workflow_ref: Optional trusted owner-bound workflow reference
|
||||
:return: Updated workflow or None if not found
|
||||
"""
|
||||
stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
|
||||
lookup_workflow_id = workflow_ref.workflow_id if workflow_ref else workflow_id
|
||||
lookup_tenant_id = workflow_ref.tenant_id if workflow_ref else tenant_id
|
||||
stmt = select(Workflow).where(Workflow.id == lookup_workflow_id, Workflow.tenant_id == lookup_tenant_id)
|
||||
if workflow_ref is not None:
|
||||
stmt = stmt.where(Workflow.app_id == workflow_ref.app_id)
|
||||
workflow = session.scalar(stmt)
|
||||
|
||||
if not workflow:
|
||||
@ -1622,30 +1635,37 @@ class WorkflowService:
|
||||
|
||||
return workflow
|
||||
|
||||
def delete_workflow(self, *, session: Session, workflow_id: str, tenant_id: str) -> bool:
|
||||
def delete_workflow(
|
||||
self, *, session: Session, workflow_id: str, tenant_id: str, workflow_ref: WorkflowRef | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Delete a workflow
|
||||
|
||||
:param session: SQLAlchemy database session
|
||||
:param workflow_id: Workflow ID
|
||||
:param tenant_id: Tenant ID
|
||||
:param workflow_ref: Optional trusted owner-bound workflow reference
|
||||
:return: True if successful
|
||||
:raises: ValueError if workflow not found
|
||||
:raises: WorkflowInUseError if workflow is in use
|
||||
:raises: DraftWorkflowDeletionError if workflow is a draft version
|
||||
"""
|
||||
stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
|
||||
lookup_workflow_id = workflow_ref.workflow_id if workflow_ref else workflow_id
|
||||
lookup_tenant_id = workflow_ref.tenant_id if workflow_ref else tenant_id
|
||||
stmt = select(Workflow).where(Workflow.id == lookup_workflow_id, Workflow.tenant_id == lookup_tenant_id)
|
||||
if workflow_ref is not None:
|
||||
stmt = stmt.where(Workflow.app_id == workflow_ref.app_id)
|
||||
workflow = session.scalar(stmt)
|
||||
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow with ID {workflow_id} not found")
|
||||
raise ValueError(f"Workflow with ID {lookup_workflow_id} not found")
|
||||
|
||||
# Check if workflow is a draft version
|
||||
if workflow.version == Workflow.VERSION_DRAFT:
|
||||
raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")
|
||||
|
||||
# Check if this workflow is currently referenced by an app
|
||||
app_stmt = select(App).where(App.workflow_id == workflow_id)
|
||||
app_stmt = select(App).where(App.workflow_id == lookup_workflow_id)
|
||||
app = session.scalar(app_stmt)
|
||||
if app:
|
||||
# Cannot delete a workflow that's currently in use by an app
|
||||
|
||||
@ -70,7 +70,7 @@ def test_instruction_generate_app_not_found(app: Flask, monkeypatch: pytest.Monk
|
||||
method = unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
session.get.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
@ -86,7 +86,13 @@ def test_instruction_generate_app_not_found(app: Flask, monkeypatch: pytest.Monk
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "app app-1 not found"
|
||||
session.get.assert_called_once_with(generator_module.App, "app-1")
|
||||
stmt = session.scalar.call_args.args[0]
|
||||
compiled = stmt.compile()
|
||||
statement = str(compiled)
|
||||
assert "apps.id" in statement
|
||||
assert "apps.tenant_id" in statement
|
||||
assert "app-1" in compiled.params.values()
|
||||
assert "t1" in compiled.params.values()
|
||||
|
||||
|
||||
def test_instruction_generate_workflow_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -94,7 +100,7 @@ def test_instruction_generate_workflow_not_found(app: Flask, monkeypatch: pytest
|
||||
method = unwrap(api.post)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model)
|
||||
session = SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model)
|
||||
_install_workflow_service(monkeypatch, workflow=None)
|
||||
|
||||
with app.test_request_context(
|
||||
@ -118,7 +124,7 @@ def test_instruction_generate_node_missing(app: Flask, monkeypatch: pytest.Monke
|
||||
method = unwrap(api.post)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model)
|
||||
session = SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model)
|
||||
|
||||
workflow = SimpleNamespace(graph_dict={"nodes": []})
|
||||
_install_workflow_service(monkeypatch, workflow=workflow)
|
||||
@ -144,7 +150,7 @@ def test_instruction_generate_code_node(app: Flask, monkeypatch: pytest.MonkeyPa
|
||||
method = unwrap(api.post)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model)
|
||||
session = SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model)
|
||||
|
||||
workflow = SimpleNamespace(
|
||||
graph_dict={
|
||||
|
||||
@ -130,3 +130,50 @@ class TestAppMCPServerController:
|
||||
|
||||
assert response == {"id": "server-1"}
|
||||
assert status_code == 201
|
||||
|
||||
def test_put_binds_server_lookup_to_app_ref(self):
|
||||
api = AppMCPServerController()
|
||||
method = unwrap(api.put)
|
||||
payload = {"id": "server-1", "description": "Updated", "parameters": {"timeout": 30}, "status": "active"}
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
server = SimpleNamespace(
|
||||
id="server-1",
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
name="Old",
|
||||
description="Old",
|
||||
parameters="{}",
|
||||
status="active",
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch("controllers.console.app.mcp_server.db.session.scalar", return_value=server) as scalar,
|
||||
patch("controllers.console.app.mcp_server.db.session.get") as get_mock,
|
||||
patch("controllers.console.app.mcp_server.db.session.commit") as commit,
|
||||
patch(
|
||||
"controllers.console.app.mcp_server.AppMCPServerResponse.model_validate",
|
||||
return_value=_ValidatedResponse({"id": "server-1"}),
|
||||
),
|
||||
):
|
||||
response = method(
|
||||
api,
|
||||
app_model=SimpleNamespace(
|
||||
id="app-1", tenant_id="tenant-1", name="Demo App", description="App description"
|
||||
),
|
||||
)
|
||||
|
||||
stmt = scalar.call_args.args[0]
|
||||
compiled = stmt.compile()
|
||||
statement = str(compiled)
|
||||
assert "app_mcp_servers.id" in statement
|
||||
assert "app_mcp_servers.tenant_id" in statement
|
||||
assert "app_mcp_servers.app_id" in statement
|
||||
assert payload["id"] in compiled.params.values()
|
||||
assert "tenant-1" in compiled.params.values()
|
||||
assert "app-1" in compiled.params.values()
|
||||
get_mock.assert_not_called()
|
||||
commit.assert_called_once()
|
||||
assert response == {"id": "server-1"}
|
||||
|
||||
@ -108,6 +108,15 @@ def _segment_response_dict():
|
||||
}
|
||||
|
||||
|
||||
def _bind_dataset_document(dataset, document, dataset_id: str = "ds-1", document_id: str = "doc-1"):
|
||||
dataset.id = dataset_id
|
||||
dataset.tenant_id = "tenant-1"
|
||||
document.id = document_id
|
||||
document.dataset_id = dataset_id
|
||||
document.tenant_id = "tenant-1"
|
||||
return document
|
||||
|
||||
|
||||
def test_segment_response_with_summary():
|
||||
segment = _segment()
|
||||
|
||||
@ -380,6 +389,7 @@ class TestDatasetDocumentSegmentAddApi:
|
||||
|
||||
document = MagicMock()
|
||||
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
_bind_dataset_document(dataset, document)
|
||||
|
||||
segment = _segment()
|
||||
|
||||
@ -504,6 +514,7 @@ class TestDatasetDocumentSegmentUpdateApi:
|
||||
|
||||
document = MagicMock()
|
||||
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
_bind_dataset_document(dataset, document)
|
||||
|
||||
segment = _segment()
|
||||
|
||||
@ -519,8 +530,8 @@ class TestDatasetDocumentSegmentUpdateApi:
|
||||
return_value=document,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
side_effect=[segment, None],
|
||||
"controllers.console.datasets.datasets_segments.SegmentService.get_segment_by_ref",
|
||||
return_value=segment,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
|
||||
@ -538,6 +549,7 @@ class TestDatasetDocumentSegmentUpdateApi:
|
||||
"controllers.console.datasets.datasets_segments.SummaryIndexService.get_segment_summary",
|
||||
return_value=None,
|
||||
),
|
||||
patch("models.dataset.db.session.scalar", return_value=None),
|
||||
patch("models.dataset.db.session.execute", return_value=MagicMock(all=MagicMock(return_value=[]))),
|
||||
):
|
||||
response, status = method(api, "tenant-1", user, "ds-1", "doc-1", "seg-1")
|
||||
@ -576,6 +588,10 @@ class TestDatasetDocumentSegmentUpdateApi:
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_model_setting",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.ModelManager.get_model_instance",
|
||||
side_effect=LLMBadRequestError(),
|
||||
@ -781,13 +797,15 @@ class TestChildChunkAddApi:
|
||||
api = ChildChunkAddApi()
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
dataset = MagicMock()
|
||||
document = _bind_dataset_document(dataset, MagicMock())
|
||||
pagination = MagicMock(items=[], total=0, pages=0)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=bad&limit="),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=MagicMock(),
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_model_setting",
|
||||
@ -795,10 +813,10 @@ class TestChildChunkAddApi:
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DocumentService.get_document",
|
||||
return_value=MagicMock(),
|
||||
return_value=document,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
"controllers.console.datasets.datasets_segments.SegmentService.get_segment_by_ref",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
@ -826,6 +844,7 @@ class TestChildChunkAddApi:
|
||||
dataset.indexing_technique = "economy"
|
||||
|
||||
document = MagicMock()
|
||||
_bind_dataset_document(dataset, document)
|
||||
segment = MagicMock()
|
||||
child_chunk = _child_chunk()
|
||||
|
||||
@ -841,7 +860,7 @@ class TestChildChunkAddApi:
|
||||
return_value=document,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
"controllers.console.datasets.datasets_segments.SegmentService.get_segment_by_ref",
|
||||
return_value=segment,
|
||||
),
|
||||
patch(
|
||||
@ -868,6 +887,7 @@ class TestChildChunkAddApi:
|
||||
|
||||
dataset = MagicMock(indexing_technique="economy")
|
||||
document = MagicMock()
|
||||
_bind_dataset_document(dataset, document)
|
||||
segment = MagicMock()
|
||||
|
||||
with (
|
||||
@ -882,7 +902,7 @@ class TestChildChunkAddApi:
|
||||
return_value=document,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
"controllers.console.datasets.datasets_segments.SegmentService.get_segment_by_ref",
|
||||
return_value=segment,
|
||||
),
|
||||
patch(
|
||||
@ -908,6 +928,7 @@ class TestChildChunkUpdateApi:
|
||||
|
||||
dataset = MagicMock()
|
||||
document = MagicMock()
|
||||
_bind_dataset_document(dataset, document)
|
||||
segment = MagicMock()
|
||||
child_chunk = MagicMock()
|
||||
|
||||
@ -922,8 +943,12 @@ class TestChildChunkUpdateApi:
|
||||
return_value=document,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
side_effect=[segment, child_chunk],
|
||||
"controllers.console.datasets.datasets_segments.SegmentService.get_segment_by_ref",
|
||||
return_value=segment,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.SegmentService.get_child_chunk_by_segment_ref",
|
||||
return_value=child_chunk,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
|
||||
@ -947,6 +972,7 @@ class TestChildChunkUpdateApi:
|
||||
|
||||
dataset = MagicMock()
|
||||
document = MagicMock()
|
||||
_bind_dataset_document(dataset, document)
|
||||
segment = MagicMock()
|
||||
child_chunk = MagicMock()
|
||||
|
||||
@ -961,8 +987,12 @@ class TestChildChunkUpdateApi:
|
||||
return_value=document,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
side_effect=[segment, child_chunk],
|
||||
"controllers.console.datasets.datasets_segments.SegmentService.get_segment_by_ref",
|
||||
return_value=segment,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.SegmentService.get_child_chunk_by_segment_ref",
|
||||
return_value=child_chunk,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
|
||||
|
||||
@ -245,7 +245,7 @@ class TestTextApi:
|
||||
api = TextApi()
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
end_user = SimpleNamespace(external_user_id="ext")
|
||||
end_user = SimpleNamespace(id="end-user-1", external_user_id="ext")
|
||||
|
||||
with app.test_request_context(
|
||||
"/text-to-audio",
|
||||
@ -264,7 +264,7 @@ class TestTextApi:
|
||||
api = TextApi()
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
end_user = SimpleNamespace(external_user_id="ext")
|
||||
end_user = SimpleNamespace(id="end-user-1", external_user_id="ext")
|
||||
|
||||
with app.test_request_context("/text-to-audio", method="POST", json={"text": "hello"}):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
|
||||
@ -90,6 +90,19 @@ def _child_chunk() -> ChildChunk:
|
||||
return child_chunk
|
||||
|
||||
|
||||
def _document_for_dataset(
|
||||
dataset: Dataset, document_id: str = "doc-id", doc_form: str = IndexStructureType.PARAGRAPH_INDEX
|
||||
):
|
||||
document = Mock()
|
||||
document.id = document_id
|
||||
document.dataset_id = dataset.id
|
||||
document.tenant_id = dataset.tenant_id
|
||||
document.indexing_status = "completed"
|
||||
document.enabled = True
|
||||
document.doc_form = doc_form
|
||||
return document
|
||||
|
||||
|
||||
class TestSegmentCreatePayload:
|
||||
"""Test suite for SegmentCreatePayload Pydantic model."""
|
||||
|
||||
@ -868,7 +881,9 @@ class TestSegmentApiGet:
|
||||
# Arrange
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
|
||||
mock_doc_svc.get_document.return_value = _document_for_dataset(
|
||||
mock_dataset, doc_form=IndexStructureType.PARAGRAPH_INDEX
|
||||
)
|
||||
mock_seg_svc.get_segments.return_value = ([mock_segment], 1)
|
||||
mock_get_summaries.return_value = {}
|
||||
mock_dump_segments.return_value = [_segment_response_dict()]
|
||||
@ -988,7 +1003,7 @@ class TestSegmentApiPost:
|
||||
mock_dataset.indexing_technique = "economy"
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
mock_doc = Mock()
|
||||
mock_doc = _document_for_dataset(mock_dataset)
|
||||
mock_doc.indexing_status = "completed"
|
||||
mock_doc.enabled = True
|
||||
mock_doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
@ -1040,7 +1055,7 @@ class TestSegmentApiPost:
|
||||
mock_dataset.indexing_technique = "economy"
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
mock_doc = Mock()
|
||||
mock_doc = _document_for_dataset(mock_dataset)
|
||||
mock_doc.indexing_status = "completed"
|
||||
mock_doc.enabled = True
|
||||
mock_doc_svc.get_document.return_value = mock_doc
|
||||
@ -1082,7 +1097,7 @@ class TestSegmentApiPost:
|
||||
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
mock_doc = Mock()
|
||||
mock_doc = _document_for_dataset(mock_dataset)
|
||||
mock_doc.indexing_status = "indexing" # Not completed
|
||||
mock_doc_svc.get_document.return_value = mock_doc
|
||||
|
||||
@ -1134,10 +1149,10 @@ class TestDatasetSegmentApiDelete:
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_model_setting.return_value = None
|
||||
|
||||
mock_doc = Mock()
|
||||
mock_doc = _document_for_dataset(mock_dataset)
|
||||
mock_doc_svc.get_document.return_value = mock_doc
|
||||
|
||||
mock_seg_svc.get_segment_by_id.return_value = mock_segment
|
||||
mock_seg_svc.get_segment_by_ref.return_value = mock_segment
|
||||
mock_seg_svc.delete_segment.return_value = None
|
||||
|
||||
# Act
|
||||
@ -1177,13 +1192,13 @@ class TestDatasetSegmentApiDelete:
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
mock_doc = Mock()
|
||||
mock_doc = _document_for_dataset(mock_dataset)
|
||||
mock_doc.indexing_status = "completed"
|
||||
mock_doc.enabled = True
|
||||
mock_doc.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
mock_doc_svc.get_document.return_value = mock_doc
|
||||
|
||||
mock_seg_svc.get_segment_by_id.return_value = None # Segment not found
|
||||
mock_seg_svc.get_segment_by_ref.return_value = None # Segment not found
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
@ -1329,8 +1344,10 @@ class TestDatasetSegmentApiUpdate:
|
||||
mock_dataset.indexing_technique = "economy"
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_model_setting.return_value = None
|
||||
mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
|
||||
mock_seg_svc.get_segment_by_id.return_value = mock_segment
|
||||
mock_doc_svc.get_document.return_value = _document_for_dataset(
|
||||
mock_dataset, doc_form=IndexStructureType.PARAGRAPH_INDEX
|
||||
)
|
||||
mock_seg_svc.get_segment_by_ref.return_value = mock_segment
|
||||
updated = Mock()
|
||||
updated.id = "updated-seg"
|
||||
mock_seg_svc.update_segment.return_value = updated
|
||||
@ -1419,8 +1436,8 @@ class TestDatasetSegmentApiUpdate:
|
||||
mock_dataset.indexing_technique = "economy"
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_model_setting.return_value = None
|
||||
mock_doc_svc.get_document.return_value = Mock()
|
||||
mock_seg_svc.get_segment_by_id.return_value = None
|
||||
mock_doc_svc.get_document.return_value = _document_for_dataset(mock_dataset)
|
||||
mock_seg_svc.get_segment_by_ref.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id",
|
||||
@ -1470,9 +1487,9 @@ class TestDatasetSegmentApiGetSingle:
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_model_setting.return_value = None
|
||||
mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
|
||||
mock_doc = _document_for_dataset(mock_dataset, doc_form=IndexStructureType.PARAGRAPH_INDEX)
|
||||
mock_doc_svc.get_document.return_value = mock_doc
|
||||
mock_seg_svc.get_segment_by_id.return_value = mock_segment
|
||||
mock_seg_svc.get_segment_by_ref.return_value = mock_segment
|
||||
mock_get_summary.return_value = None
|
||||
mock_dump_segment.return_value = _segment_response_dict()
|
||||
|
||||
@ -1517,9 +1534,9 @@ class TestDatasetSegmentApiGetSingle:
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_model_setting.return_value = None
|
||||
mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
|
||||
mock_doc = _document_for_dataset(mock_dataset, doc_form=IndexStructureType.PARAGRAPH_INDEX)
|
||||
mock_doc_svc.get_document.return_value = mock_doc
|
||||
mock_seg_svc.get_segment_by_id.return_value = mock_segment
|
||||
mock_seg_svc.get_segment_by_ref.return_value = mock_segment
|
||||
mock_summary_record = Mock(summary_content="This is the segment summary")
|
||||
mock_get_summary.return_value = mock_summary_record
|
||||
mock_dump_segment.return_value = _segment_response_dict("This is the segment summary")
|
||||
@ -1619,8 +1636,8 @@ class TestDatasetSegmentApiGetSingle:
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_model_setting.return_value = None
|
||||
mock_doc_svc.get_document.return_value = Mock()
|
||||
mock_seg_svc.get_segment_by_id.return_value = None
|
||||
mock_doc_svc.get_document.return_value = _document_for_dataset(mock_dataset)
|
||||
mock_seg_svc.get_segment_by_ref.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id",
|
||||
@ -1660,8 +1677,8 @@ class TestChildChunkApiGet:
|
||||
"""Test successful child chunk list retrieval."""
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_doc_svc.get_document.return_value = Mock()
|
||||
mock_seg_svc.get_segment_by_id.return_value = Mock()
|
||||
mock_doc_svc.get_document.return_value = _document_for_dataset(mock_dataset)
|
||||
mock_seg_svc.get_segment_by_ref.return_value = Mock()
|
||||
|
||||
mock_pagination = Mock()
|
||||
mock_pagination.items = [_child_chunk(), _child_chunk()]
|
||||
@ -1759,8 +1776,8 @@ class TestChildChunkApiGet:
|
||||
"""Test 404 when segment not found."""
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_doc_svc.get_document.return_value = Mock()
|
||||
mock_seg_svc.get_segment_by_id.return_value = None
|
||||
mock_doc_svc.get_document.return_value = _document_for_dataset(mock_dataset)
|
||||
mock_seg_svc.get_segment_by_ref.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id/child_chunks",
|
||||
@ -1822,8 +1839,8 @@ class TestChildChunkApiPost:
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_dataset.indexing_technique = "economy"
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_doc_svc.get_document.return_value = Mock()
|
||||
mock_seg_svc.get_segment_by_id.return_value = Mock()
|
||||
mock_doc_svc.get_document.return_value = _document_for_dataset(mock_dataset)
|
||||
mock_seg_svc.get_segment_by_ref.return_value = Mock()
|
||||
mock_child = _child_chunk()
|
||||
mock_seg_svc.create_child_chunk.return_value = mock_child
|
||||
|
||||
@ -1900,8 +1917,8 @@ class TestChildChunkApiPost:
|
||||
self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id)
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_doc_svc.get_document.return_value = Mock()
|
||||
mock_seg_svc.get_segment_by_id.return_value = None
|
||||
mock_doc_svc.get_document.return_value = _document_for_dataset(mock_dataset)
|
||||
mock_seg_svc.get_segment_by_ref.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id/child_chunks",
|
||||
@ -1954,19 +1971,19 @@ class TestDatasetChildChunkApiDelete:
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
|
||||
mock_doc = Mock()
|
||||
mock_doc = _document_for_dataset(mock_dataset)
|
||||
mock_doc_svc.get_document.return_value = mock_doc
|
||||
|
||||
segment_id = str(uuid.uuid4())
|
||||
mock_segment = Mock()
|
||||
mock_segment.id = segment_id
|
||||
mock_segment.document_id = "doc-id"
|
||||
mock_seg_svc.get_segment_by_id.return_value = mock_segment
|
||||
mock_seg_svc.get_segment_by_ref.return_value = mock_segment
|
||||
|
||||
child_chunk_id = str(uuid.uuid4())
|
||||
mock_child = Mock()
|
||||
mock_child.segment_id = segment_id
|
||||
mock_seg_svc.get_child_chunk_by_id.return_value = mock_child
|
||||
mock_seg_svc.get_child_chunk_by_segment_ref.return_value = mock_child
|
||||
mock_seg_svc.delete_child_chunk.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
@ -2003,14 +2020,14 @@ class TestDatasetChildChunkApiDelete:
|
||||
"""Test 404 when child chunk not found."""
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_doc_svc.get_document.return_value = Mock()
|
||||
mock_doc_svc.get_document.return_value = _document_for_dataset(mock_dataset)
|
||||
|
||||
segment_id = str(uuid.uuid4())
|
||||
mock_segment = Mock()
|
||||
mock_segment.id = segment_id
|
||||
mock_segment.document_id = "doc-id"
|
||||
mock_seg_svc.get_segment_by_id.return_value = mock_segment
|
||||
mock_seg_svc.get_child_chunk_by_id.return_value = None
|
||||
mock_seg_svc.get_segment_by_ref.return_value = mock_segment
|
||||
mock_seg_svc.get_child_chunk_by_segment_ref.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{segment_id}/child_chunks/cc-id",
|
||||
@ -2044,13 +2061,10 @@ class TestDatasetChildChunkApiDelete:
|
||||
"""Test 404 when segment does not belong to the document."""
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_doc_svc.get_document.return_value = Mock()
|
||||
mock_doc_svc.get_document.return_value = _document_for_dataset(mock_dataset)
|
||||
|
||||
segment_id = str(uuid.uuid4())
|
||||
mock_segment = Mock()
|
||||
mock_segment.id = segment_id
|
||||
mock_segment.document_id = "different-doc-id"
|
||||
mock_seg_svc.get_segment_by_id.return_value = mock_segment
|
||||
mock_seg_svc.get_segment_by_ref.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{segment_id}/child_chunks/cc-id",
|
||||
@ -2084,17 +2098,15 @@ class TestDatasetChildChunkApiDelete:
|
||||
"""Test 404 when child chunk does not belong to the segment."""
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_doc_svc.get_document.return_value = Mock()
|
||||
mock_doc_svc.get_document.return_value = _document_for_dataset(mock_dataset)
|
||||
|
||||
segment_id = str(uuid.uuid4())
|
||||
mock_segment = Mock()
|
||||
mock_segment.id = segment_id
|
||||
mock_segment.document_id = "doc-id"
|
||||
mock_seg_svc.get_segment_by_id.return_value = mock_segment
|
||||
mock_seg_svc.get_segment_by_ref.return_value = mock_segment
|
||||
|
||||
mock_child = Mock()
|
||||
mock_child.segment_id = "different-segment-id"
|
||||
mock_seg_svc.get_child_chunk_by_id.return_value = mock_child
|
||||
mock_seg_svc.get_child_chunk_by_segment_ref.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{segment_id}/child_chunks/cc-id",
|
||||
|
||||
@ -422,6 +422,13 @@ class TestLLMGenerator:
|
||||
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
|
||||
)
|
||||
assert result == {"modified": "prompt"}
|
||||
stmt = mock_scalar.call_args.args[0]
|
||||
compiled = stmt.compile()
|
||||
statement = str(compiled)
|
||||
assert "messages.app_id" in statement
|
||||
assert "apps.tenant_id" in statement
|
||||
assert "flow_id" in compiled.params.values()
|
||||
assert "tenant_id" in compiled.params.values()
|
||||
|
||||
def test_instruction_modify_legacy_with_last_run(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
@ -439,12 +446,26 @@ class TestLLMGenerator:
|
||||
"tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
|
||||
)
|
||||
assert result == {"modified": "prompt"}
|
||||
stmt = mock_scalar.call_args.args[0]
|
||||
compiled = stmt.compile()
|
||||
statement = str(compiled)
|
||||
assert "messages.app_id" in statement
|
||||
assert "apps.tenant_id" in statement
|
||||
assert "flow_id" in compiled.params.values()
|
||||
assert "tenant_id" in compiled.params.values()
|
||||
|
||||
def test_instruction_modify_workflow_app_not_found(self):
|
||||
with patch("extensions.ext_database.db.session") as mock_session:
|
||||
mock_session.return_value.scalar.return_value = None
|
||||
with pytest.raises(ValueError, match="App not found."):
|
||||
LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", MagicMock())
|
||||
stmt = mock_session.return_value.scalar.call_args.args[0]
|
||||
compiled = stmt.compile()
|
||||
statement = str(compiled)
|
||||
assert "apps.id" in statement
|
||||
assert "apps.tenant_id" in statement
|
||||
assert "f" in compiled.params.values()
|
||||
assert "t" in compiled.params.values()
|
||||
|
||||
def test_instruction_modify_workflow_no_workflow(self):
|
||||
with patch("extensions.ext_database.db.session") as mock_session:
|
||||
|
||||
@ -1,111 +0,0 @@
|
||||
"""Tests for the uuidv7 SQL migration's PostgreSQL 18 compatibility guard.
|
||||
|
||||
The migration file name is not a valid Python identifier (it starts with a date and
|
||||
contains hyphens), so it is loaded directly from its path. The ``models`` import at the
|
||||
top of the migration is stubbed because the migration never uses it during
|
||||
``upgrade()``/``downgrade()`` and pulling in the real package would require a full app
|
||||
context.
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
MIGRATION_PATH = (
|
||||
Path(__file__).resolve().parents[3]
|
||||
/ "migrations"
|
||||
/ "versions"
|
||||
/ "2025_07_02_2332-1c9ba48be8e4_add_uuidv7_function_in_sql.py"
|
||||
)
|
||||
|
||||
|
||||
def _load_migration():
|
||||
# The migration does `import models as models` but never references it, so a stub is
|
||||
# enough and keeps the test free of any database/app configuration.
|
||||
sys.modules.setdefault("models", types.ModuleType("models"))
|
||||
spec = importlib.util.spec_from_file_location("uuidv7_pg18_migration_under_test", MIGRATION_PATH)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def _make_bind(dialect_name):
|
||||
bind = mock.MagicMock()
|
||||
bind.dialect.name = dialect_name
|
||||
return bind
|
||||
|
||||
|
||||
def _executed_sql(fake_op):
|
||||
return [str(call.args[0]) for call in fake_op.execute.call_args_list]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def migration():
|
||||
return _load_migration()
|
||||
|
||||
|
||||
def test_upgrade_creates_both_functions_when_native_uuidv7_absent(migration):
|
||||
# PostgreSQL 13 to 17: no native pg_catalog.uuidv7(), so both functions are created.
|
||||
# The DO block contains the CREATE FUNCTION guarded by IF NOT EXISTS, and
|
||||
# uuidv7_boundary is created unconditionally.
|
||||
bind = _make_bind("postgresql")
|
||||
with mock.patch.object(migration, "op") as fake_op:
|
||||
fake_op.get_bind.return_value = bind
|
||||
migration.upgrade()
|
||||
|
||||
sql = _executed_sql(fake_op)
|
||||
assert any("CREATE FUNCTION public.uuidv7()" in stmt for stmt in sql)
|
||||
assert any("CREATE FUNCTION public.uuidv7_boundary(timestamptz)" in stmt for stmt in sql)
|
||||
|
||||
|
||||
def test_upgrade_skips_uuidv7_but_keeps_boundary_when_native_present(migration):
|
||||
# PostgreSQL 18: native pg_catalog.uuidv7() exists, so the DO block must guard
|
||||
# the CREATE FUNCTION with an IF NOT EXISTS check against pg_catalog.
|
||||
# uuidv7_boundary is still missing and has to be created unconditionally.
|
||||
bind = _make_bind("postgresql")
|
||||
with mock.patch.object(migration, "op") as fake_op:
|
||||
fake_op.get_bind.return_value = bind
|
||||
migration.upgrade()
|
||||
|
||||
sql = _executed_sql(fake_op)
|
||||
# The DO block must contain the pg_catalog existence check.
|
||||
do_block = next((stmt for stmt in sql if "DO $do$" in stmt), None)
|
||||
assert do_block is not None
|
||||
assert "pg_catalog" in do_block
|
||||
assert "uuidv7" in do_block
|
||||
assert "IF NOT EXISTS" in do_block
|
||||
# uuidv7_boundary is always created (not guarded by the DO block).
|
||||
assert any("CREATE FUNCTION public.uuidv7_boundary(timestamptz)" in stmt for stmt in sql)
|
||||
|
||||
|
||||
def test_upgrade_is_noop_on_non_postgres(migration):
|
||||
bind = _make_bind("sqlite")
|
||||
with mock.patch.object(migration, "op") as fake_op:
|
||||
fake_op.get_bind.return_value = bind
|
||||
migration.upgrade()
|
||||
|
||||
fake_op.execute.assert_not_called()
|
||||
|
||||
|
||||
def test_downgrade_uses_if_exists_and_public_schema(migration):
|
||||
bind = _make_bind("postgresql")
|
||||
with mock.patch.object(migration, "op") as fake_op:
|
||||
fake_op.get_bind.return_value = bind
|
||||
migration.downgrade()
|
||||
|
||||
sql = _executed_sql(fake_op)
|
||||
assert "DROP FUNCTION IF EXISTS public.uuidv7()" in sql
|
||||
assert "DROP FUNCTION IF EXISTS public.uuidv7_boundary(timestamptz)" in sql
|
||||
|
||||
|
||||
def test_downgrade_is_noop_on_non_postgres(migration):
|
||||
bind = _make_bind("sqlite")
|
||||
with mock.patch.object(migration, "op") as fake_op:
|
||||
fake_op.get_bind.return_value = bind
|
||||
migration.downgrade()
|
||||
|
||||
fake_op.execute.assert_not_called()
|
||||
@ -12,6 +12,7 @@ from models.dataset import Dataset, Pipeline, PipelineCustomizedTemplate, Pipeli
|
||||
from models.workflow import Workflow
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
from services.workflow_ref_service import WorkflowRefService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -369,6 +370,38 @@ def test_update_workflow_returns_none_when_not_found(
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_update_workflow_with_ref_scopes_lookup_to_pipeline(
|
||||
mocker: MockerFixture, rag_pipeline_service: RagPipelineService
|
||||
) -> None:
|
||||
workflow = SimpleNamespace(
|
||||
id="wf-1", marked_name="", marked_comment="", updated_by=None, updated_at=None, disallowed="original"
|
||||
)
|
||||
pipeline = SimpleNamespace(id="pipeline-1", tenant_id="t1")
|
||||
workflow_ref = WorkflowRefService.create_pipeline_workflow_ref(pipeline, "wf-1")
|
||||
session = mocker.Mock()
|
||||
session.scalar.return_value = workflow
|
||||
|
||||
result = rag_pipeline_service.update_workflow(
|
||||
session=session,
|
||||
workflow_id="wf-1",
|
||||
tenant_id="t1",
|
||||
account_id="u1",
|
||||
data={"marked_name": "v1"},
|
||||
workflow_ref=workflow_ref,
|
||||
)
|
||||
|
||||
stmt = session.scalar.call_args.args[0]
|
||||
compiled = stmt.compile()
|
||||
statement = str(compiled)
|
||||
assert "workflows.id" in statement
|
||||
assert "workflows.tenant_id" in statement
|
||||
assert "workflows.app_id" in statement
|
||||
assert "wf-1" in compiled.params.values()
|
||||
assert "t1" in compiled.params.values()
|
||||
assert "pipeline-1" in compiled.params.values()
|
||||
assert result is workflow
|
||||
|
||||
|
||||
# --- get_rag_pipeline_paginate_workflow_runs ---
|
||||
|
||||
|
||||
@ -1627,6 +1660,8 @@ def test_handle_node_run_result_marks_document_error_for_published_invoke(
|
||||
def __init__(self):
|
||||
self._values = {
|
||||
("sys", "invoke_from"): SimpleNamespace(value=InvokeFrom.PUBLISHED_PIPELINE),
|
||||
("sys", "app_id"): SimpleNamespace(value="pipeline-1"),
|
||||
("sys", "dataset_id"): SimpleNamespace(value="dataset-1"),
|
||||
("sys", "document_id"): SimpleNamespace(value="doc-1"),
|
||||
}
|
||||
|
||||
@ -1660,7 +1695,8 @@ def test_handle_node_run_result_marks_document_error_for_published_invoke(
|
||||
)
|
||||
|
||||
document = SimpleNamespace(indexing_status="waiting", error=None)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=document)
|
||||
scalar_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=document)
|
||||
get_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get")
|
||||
add_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.add")
|
||||
commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
|
||||
|
||||
@ -1672,6 +1708,19 @@ def test_handle_node_run_result_marks_document_error_for_published_invoke(
|
||||
)
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
stmt = scalar_mock.call_args.args[0]
|
||||
compiled = stmt.compile()
|
||||
statement = str(compiled)
|
||||
assert "documents.id" in statement
|
||||
assert "documents.tenant_id" in statement
|
||||
assert "documents.dataset_id" in statement
|
||||
assert "datasets.tenant_id" in statement
|
||||
assert "datasets.pipeline_id" in statement
|
||||
assert "doc-1" in compiled.params.values()
|
||||
assert "t1" in compiled.params.values()
|
||||
assert "dataset-1" in compiled.params.values()
|
||||
assert "pipeline-1" in compiled.params.values()
|
||||
get_mock.assert_not_called()
|
||||
assert document.indexing_status == "error"
|
||||
assert document.error == "boom"
|
||||
add_mock.assert_called_once_with(document)
|
||||
|
||||
@ -41,9 +41,10 @@ def _make_message(message_id: str = "msg-1", app_id: str = "app-1") -> MagicMock
|
||||
return message
|
||||
|
||||
|
||||
def _make_annotation(annotation_id: str = "ann-1") -> MagicMock:
|
||||
def _make_annotation(annotation_id: str = "ann-1", app_id: str = "app-1") -> MagicMock:
|
||||
annotation = MagicMock(spec=MessageAnnotation)
|
||||
annotation.id = annotation_id
|
||||
annotation.app_id = app_id
|
||||
annotation.content = ""
|
||||
annotation.question = ""
|
||||
annotation.question_text = ""
|
||||
@ -66,6 +67,15 @@ def _make_file(content: bytes) -> FileStorage:
|
||||
return FileStorage(stream=BytesIO(content))
|
||||
|
||||
|
||||
def _assert_statement_binds_annotation(stmt: Any, annotation_id: str, app_id: str) -> None:
|
||||
compiled = stmt.compile()
|
||||
statement = str(compiled)
|
||||
assert "message_annotations.id" in statement
|
||||
assert "message_annotations.app_id" in statement
|
||||
assert annotation_id in compiled.params.values()
|
||||
assert app_id in compiled.params.values()
|
||||
|
||||
|
||||
class TestAppAnnotationServiceUpInsert:
|
||||
"""Test suite for up_insert_app_annotation_from_message."""
|
||||
|
||||
@ -541,8 +551,7 @@ class TestAppAnnotationServiceDirectManipulation:
|
||||
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
|
||||
patch("services.annotation_service.db") as mock_db,
|
||||
):
|
||||
mock_db.session.scalar.return_value = app
|
||||
mock_db.session.get.return_value = None
|
||||
mock_db.session.scalar.side_effect = [app, None]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(NotFound):
|
||||
@ -576,8 +585,7 @@ class TestAppAnnotationServiceDirectManipulation:
|
||||
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
|
||||
patch("services.annotation_service.db") as mock_db,
|
||||
):
|
||||
mock_db.session.scalar.return_value = app
|
||||
mock_db.session.get.return_value = annotation
|
||||
mock_db.session.scalar.side_effect = [app, annotation]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
@ -598,8 +606,7 @@ class TestAppAnnotationServiceDirectManipulation:
|
||||
patch("services.annotation_service.db") as mock_db,
|
||||
patch("services.annotation_service.update_annotation_to_index_task") as mock_task,
|
||||
):
|
||||
mock_db.session.scalar.side_effect = [app, setting]
|
||||
mock_db.session.get.return_value = annotation
|
||||
mock_db.session.scalar.side_effect = [app, annotation, setting]
|
||||
|
||||
# Act
|
||||
result = AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id, mock_db.session)
|
||||
@ -608,6 +615,8 @@ class TestAppAnnotationServiceDirectManipulation:
|
||||
assert result == annotation
|
||||
assert annotation.content == "hello"
|
||||
assert annotation.question == "q1"
|
||||
_assert_statement_binds_annotation(mock_db.session.scalar.call_args_list[1].args[0], annotation.id, app.id)
|
||||
mock_db.session.get.assert_not_called()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_task.delay.assert_called_once_with(
|
||||
annotation.id,
|
||||
@ -632,8 +641,7 @@ class TestAppAnnotationServiceDirectManipulation:
|
||||
patch("services.annotation_service.db") as mock_db,
|
||||
patch("services.annotation_service.delete_annotation_index_task") as mock_task,
|
||||
):
|
||||
mock_db.session.scalar.side_effect = [app, setting]
|
||||
mock_db.session.get.return_value = annotation
|
||||
mock_db.session.scalar.side_effect = [app, annotation, setting]
|
||||
|
||||
scalars_result = MagicMock()
|
||||
scalars_result.all.return_value = [history1, history2]
|
||||
@ -643,6 +651,8 @@ class TestAppAnnotationServiceDirectManipulation:
|
||||
AppAnnotationService.delete_app_annotation(app.id, annotation.id, mock_db.session)
|
||||
|
||||
# Assert
|
||||
_assert_statement_binds_annotation(mock_db.session.scalar.call_args_list[1].args[0], annotation.id, app.id)
|
||||
mock_db.session.get.assert_not_called()
|
||||
mock_db.session.delete.assert_any_call(annotation)
|
||||
mock_db.session.delete.assert_any_call(history1)
|
||||
mock_db.session.delete.assert_any_call(history2)
|
||||
@ -679,8 +689,7 @@ class TestAppAnnotationServiceDirectManipulation:
|
||||
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
|
||||
patch("services.annotation_service.db") as mock_db,
|
||||
):
|
||||
mock_db.session.scalar.return_value = app
|
||||
mock_db.session.get.return_value = None
|
||||
mock_db.session.scalar.side_effect = [app, None]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(NotFound):
|
||||
@ -748,6 +757,13 @@ class TestAppAnnotationServiceDirectManipulation:
|
||||
|
||||
# Assert
|
||||
assert result == {"deleted_count": 2}
|
||||
fetch_stmt = mock_db.session.execute.call_args_list[0].args[0]
|
||||
compiled = fetch_stmt.compile()
|
||||
statement = str(compiled)
|
||||
assert "message_annotations.id IN" in statement
|
||||
assert "message_annotations.app_id" in statement
|
||||
assert ["ann-1", "ann-2"] in compiled.params.values()
|
||||
assert app.id in compiled.params.values()
|
||||
mock_task.delay.assert_called_once_with(annotation1.id, app.id, tenant_id, setting.collection_binding_id)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@ -1121,8 +1137,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings:
|
||||
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
|
||||
patch("services.annotation_service.db") as mock_db,
|
||||
):
|
||||
mock_db.session.scalar.return_value = app
|
||||
mock_db.session.get.return_value = annotation
|
||||
mock_db.session.scalar.side_effect = [app, annotation]
|
||||
mock_db.paginate.return_value = pagination
|
||||
|
||||
# Act
|
||||
@ -1131,6 +1146,8 @@ class TestAppAnnotationServiceHitHistoryAndSettings:
|
||||
# Assert
|
||||
assert items == ["h1"]
|
||||
assert total == 2
|
||||
_assert_statement_binds_annotation(mock_db.session.scalar.call_args_list[1].args[0], annotation.id, app.id)
|
||||
mock_db.session.get.assert_not_called()
|
||||
|
||||
def test_get_annotation_hit_histories_should_raise_not_found_when_annotation_missing(self) -> None:
|
||||
"""Test missing annotation raises NotFound."""
|
||||
@ -1142,8 +1159,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings:
|
||||
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
|
||||
patch("services.annotation_service.db") as mock_db,
|
||||
):
|
||||
mock_db.session.scalar.return_value = app
|
||||
mock_db.session.get.return_value = None
|
||||
mock_db.session.scalar.side_effect = [app, None]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(NotFound):
|
||||
|
||||
@ -523,7 +523,7 @@ class TestAudioServiceTTS:
|
||||
message_id = "00000000-0000-0000-0000-000000000001"
|
||||
message = factory.create_message_mock(message_id=message_id, answer="Message answer")
|
||||
session = MagicMock()
|
||||
session.get.return_value = message
|
||||
session.scalar.return_value = message
|
||||
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_instance = MagicMock()
|
||||
@ -536,11 +536,25 @@ class TestAudioServiceTTS:
|
||||
session=session,
|
||||
message_id=message_id,
|
||||
voice="message-voice",
|
||||
message_end_user_id="end-user-1",
|
||||
message_account_id="account-1",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == b"message audio"
|
||||
session.get.assert_called_once_with(Message, message_id)
|
||||
session.scalar.assert_called_once()
|
||||
session.get.assert_not_called()
|
||||
stmt = session.scalar.call_args.args[0]
|
||||
compiled = stmt.compile()
|
||||
statement = str(compiled)
|
||||
assert "messages.id" in statement
|
||||
assert "messages.app_id" in statement
|
||||
assert "messages.from_end_user_id" in statement
|
||||
assert "messages.from_account_id" in statement
|
||||
assert message_id in compiled.params.values()
|
||||
assert app.id in compiled.params.values()
|
||||
assert "end-user-1" in compiled.params.values()
|
||||
assert "account-1" in compiled.params.values()
|
||||
mock_model_instance.invoke_tts.assert_called_once_with(
|
||||
content_text="Message answer",
|
||||
voice="message-voice",
|
||||
|
||||
@ -103,6 +103,33 @@ class TestDocumentServiceMutations:
|
||||
|
||||
assert DocumentService.check_archived(document) is expected
|
||||
|
||||
def test_delete_documents_limits_query_and_cleanup_to_dataset_ref(self):
|
||||
dataset = _make_dataset(dataset_id="dataset-1", tenant_id="tenant-1")
|
||||
dataset.doc_form = "paragraph_index"
|
||||
document = _make_document(document_id="doc-1", dataset_id=dataset.id, tenant_id=dataset.tenant_id)
|
||||
document.data_source_info_dict = {}
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.db") as mock_db,
|
||||
patch("services.dataset_service.batch_clean_document_task") as clean_task,
|
||||
):
|
||||
mock_db.session.scalars.return_value.all.return_value = [document]
|
||||
|
||||
DocumentService.delete_documents(dataset, ["doc-1", "other-doc"])
|
||||
|
||||
stmt = mock_db.session.scalars.call_args.args[0]
|
||||
compiled = stmt.compile()
|
||||
statement = str(compiled)
|
||||
assert "documents.id IN" in statement
|
||||
assert "documents.tenant_id" in statement
|
||||
assert "documents.dataset_id" in statement
|
||||
assert ["doc-1", "other-doc"] in compiled.params.values()
|
||||
assert dataset.tenant_id in compiled.params.values()
|
||||
assert dataset.id in compiled.params.values()
|
||||
mock_db.session.delete.assert_called_once_with(document)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
clean_task.delay.assert_called_once_with(["doc-1"], dataset.id, dataset.doc_form, [])
|
||||
|
||||
def test_rename_document_raises_when_dataset_is_missing(self, rename_account_context):
|
||||
with patch.object(DatasetService, "get_dataset", return_value=None):
|
||||
with pytest.raises(ValueError, match="Dataset not found"):
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
"""Unit tests for SegmentService behaviors in dataset_service."""
|
||||
|
||||
from services.dataset_ref_service import DatasetRef, DatasetRefService
|
||||
|
||||
from .dataset_service_test_helpers import (
|
||||
Account,
|
||||
ChildChunk,
|
||||
@ -24,6 +26,38 @@ from .dataset_service_test_helpers import (
|
||||
)
|
||||
|
||||
|
||||
def _make_segment_ref(segment_id: str = "segment-1"):
|
||||
dataset = _make_dataset()
|
||||
document = _make_document(dataset_id=dataset.id, tenant_id=dataset.tenant_id)
|
||||
dataset_ref = DatasetRefService.create_dataset_ref(dataset)
|
||||
document_ref = DatasetRefService.create_document_ref(dataset_ref, document)
|
||||
assert document_ref is not None
|
||||
return DatasetRefService.create_segment_ref(document_ref, segment_id)
|
||||
|
||||
|
||||
class TestDatasetRefService:
|
||||
"""Unit tests for typed dataset resource refs."""
|
||||
|
||||
def test_dataset_ref_requires_service_ctor_token(self):
|
||||
with pytest.raises(ValueError, match="DatasetRef must be created by DatasetRefService"):
|
||||
DatasetRef("tenant-1", "dataset-1", object())
|
||||
|
||||
def test_create_document_ref_rejects_document_outside_dataset(self):
|
||||
dataset = _make_dataset(dataset_id="dataset-1", tenant_id="tenant-1")
|
||||
document = _make_document(document_id="doc-1", dataset_id="other-dataset", tenant_id="tenant-1")
|
||||
dataset_ref = DatasetRefService.create_dataset_ref(dataset)
|
||||
|
||||
assert DatasetRefService.create_document_ref(dataset_ref, document) is None
|
||||
|
||||
def test_create_segment_ref_carries_full_parent_chain(self):
|
||||
segment_ref = _make_segment_ref()
|
||||
|
||||
assert segment_ref.tenant_id == "tenant-1"
|
||||
assert segment_ref.dataset_id == "dataset-1"
|
||||
assert segment_ref.document_id == "doc-1"
|
||||
assert segment_ref.segment_id == "segment-1"
|
||||
|
||||
|
||||
class TestSegmentServiceChildChunks:
|
||||
"""Unit tests for child-chunk CRUD helpers."""
|
||||
|
||||
@ -257,6 +291,23 @@ class TestSegmentServiceQueries:
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_get_child_chunk_by_segment_ref_uses_full_ownership_chain(self):
|
||||
child_chunk = _make_child_chunk()
|
||||
segment_ref = _make_segment_ref()
|
||||
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
mock_db.session.scalar.return_value = child_chunk
|
||||
result = SegmentService.get_child_chunk_by_segment_ref("child-a", segment_ref)
|
||||
|
||||
assert result is child_chunk
|
||||
stmt = mock_db.session.scalar.call_args.args[0]
|
||||
sql = str(stmt.compile(compile_kwargs={"literal_binds": True}))
|
||||
assert "child_chunks.id = 'child-a'" in sql
|
||||
assert "child_chunks.tenant_id = 'tenant-1'" in sql
|
||||
assert "child_chunks.dataset_id = 'dataset-1'" in sql
|
||||
assert "child_chunks.document_id = 'doc-1'" in sql
|
||||
assert "child_chunks.segment_id = 'segment-1'" in sql
|
||||
|
||||
def test_get_segments_uses_status_and_keyword_filters(self):
|
||||
paginated = SimpleNamespace(items=["segment"], total=1)
|
||||
|
||||
@ -304,6 +355,32 @@ class TestSegmentServiceQueries:
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_get_segment_by_ref_uses_full_ownership_chain(self):
|
||||
segment = DocumentSegment(
|
||||
tenant_id="tenant-1",
|
||||
dataset_id="dataset-1",
|
||||
document_id="doc-1",
|
||||
position=1,
|
||||
content="segment",
|
||||
word_count=7,
|
||||
tokens=2,
|
||||
created_by="user-1",
|
||||
)
|
||||
segment.id = "segment-1"
|
||||
segment_ref = _make_segment_ref()
|
||||
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
mock_db.session.scalar.return_value = segment
|
||||
result = SegmentService.get_segment_by_ref(segment_ref)
|
||||
|
||||
assert result is segment
|
||||
stmt = mock_db.session.scalar.call_args.args[0]
|
||||
sql = str(stmt.compile(compile_kwargs={"literal_binds": True}))
|
||||
assert "document_segments.id = 'segment-1'" in sql
|
||||
assert "document_segments.tenant_id = 'tenant-1'" in sql
|
||||
assert "document_segments.dataset_id = 'dataset-1'" in sql
|
||||
assert "document_segments.document_id = 'doc-1'" in sql
|
||||
|
||||
def test_get_segments_by_document_and_dataset_returns_scalars_result(self):
|
||||
segment = DocumentSegment(
|
||||
tenant_id="tenant-1",
|
||||
|
||||
@ -5,7 +5,7 @@ from pytest_mock import MockerFixture
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from models.enums import TagType
|
||||
from services.tag_service import TagBindingCreatePayload, TagBindingDeletePayload, TagService
|
||||
from services.tag_service import TagBindingCreatePayload, TagBindingDeletePayload, TagService, UpdateTagPayload
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -78,6 +78,71 @@ def test_delete_tag_binding_does_not_commit_when_no_rows_deleted(mocker: MockerF
|
||||
db_session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_update_tags_scopes_lookup_to_current_tenant_and_type(current_user, db_session):
|
||||
tag = SimpleNamespace(id="tag-1", name="old", type=TagType.KNOWLEDGE)
|
||||
db_session.scalar.side_effect = [tag, None]
|
||||
|
||||
result = TagService.update_tags(UpdateTagPayload(name="new"), "tag-1", db_session, tag_type=TagType.KNOWLEDGE)
|
||||
|
||||
stmt = db_session.scalar.call_args_list[0].args[0]
|
||||
compiled = stmt.compile()
|
||||
statement = str(compiled)
|
||||
assert "tags.id" in statement
|
||||
assert "tags.tenant_id" in statement
|
||||
assert "tags.type" in statement
|
||||
assert "tag-1" in compiled.params.values()
|
||||
assert current_user.current_tenant_id in compiled.params.values()
|
||||
assert TagType.KNOWLEDGE in compiled.params.values()
|
||||
assert result is tag
|
||||
assert tag.name == "new"
|
||||
db_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_get_tag_binding_count_scopes_lookup_to_current_tenant_and_type(current_user, db_session):
|
||||
db_session.scalar.return_value = 3
|
||||
|
||||
result = TagService.get_tag_binding_count("tag-1", db_session, tag_type=TagType.KNOWLEDGE)
|
||||
|
||||
stmt = db_session.scalar.call_args.args[0]
|
||||
compiled = stmt.compile()
|
||||
statement = str(compiled)
|
||||
assert "tag_bindings.tag_id" in statement
|
||||
assert "tags.tenant_id" in statement
|
||||
assert "tags.type" in statement
|
||||
assert "tag-1" in compiled.params.values()
|
||||
assert current_user.current_tenant_id in compiled.params.values()
|
||||
assert TagType.KNOWLEDGE in compiled.params.values()
|
||||
assert result == 3
|
||||
|
||||
|
||||
def test_delete_tag_scopes_lookup_and_bindings_to_current_tenant(current_user, db_session):
|
||||
tag = SimpleNamespace(id="tag-1", name="old", type=TagType.KNOWLEDGE)
|
||||
binding = SimpleNamespace(id="binding-1")
|
||||
db_session.scalar.return_value = tag
|
||||
db_session.scalars.return_value.all.return_value = [binding]
|
||||
|
||||
TagService.delete_tag("tag-1", db_session, tag_type=TagType.KNOWLEDGE)
|
||||
|
||||
tag_stmt = db_session.scalar.call_args.args[0]
|
||||
tag_compiled = tag_stmt.compile()
|
||||
assert "tags.id" in str(tag_compiled)
|
||||
assert "tags.tenant_id" in str(tag_compiled)
|
||||
assert "tags.type" in str(tag_compiled)
|
||||
assert "tag-1" in tag_compiled.params.values()
|
||||
assert current_user.current_tenant_id in tag_compiled.params.values()
|
||||
assert TagType.KNOWLEDGE in tag_compiled.params.values()
|
||||
|
||||
binding_stmt = db_session.scalars.call_args.args[0]
|
||||
binding_compiled = binding_stmt.compile()
|
||||
assert "tag_bindings.tag_id" in str(binding_compiled)
|
||||
assert "tag_bindings.tenant_id" in str(binding_compiled)
|
||||
assert "tag-1" in binding_compiled.params.values()
|
||||
assert current_user.current_tenant_id in binding_compiled.params.values()
|
||||
db_session.delete.assert_any_call(tag)
|
||||
db_session.delete.assert_any_call(binding)
|
||||
db_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_get_target_ids_by_tag_ids_returns_empty_without_query_for_empty_input(db_session):
|
||||
result = TagService.get_target_ids_by_tag_ids(TagType.SNIPPET, "tenant-1", [], db_session)
|
||||
|
||||
|
||||
@ -36,6 +36,7 @@ from models.model import App, AppMode
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
from services.errors.app import IsDraftWorkflowError, TriggerNodeLimitExceededError, WorkflowHashNotEqualError
|
||||
from services.errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
|
||||
from services.workflow_ref_service import WorkflowRefService
|
||||
from services.workflow_service import (
|
||||
WorkflowService,
|
||||
_rebuild_file_for_user_inputs_in_start_node,
|
||||
@ -1052,6 +1053,39 @@ class TestWorkflowService:
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_update_workflow_with_ref_scopes_lookup_to_app(self, workflow_service: WorkflowService):
|
||||
"""Test update_workflow includes the trusted app owner in the lookup."""
|
||||
workflow_id = "workflow-123"
|
||||
tenant_id = "tenant-456"
|
||||
app_id = "app-789"
|
||||
account_id = "user-123"
|
||||
workflow_ref = WorkflowRefService.create_app_workflow_ref(
|
||||
SimpleNamespace(id=app_id, tenant_id=tenant_id), workflow_id
|
||||
)
|
||||
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id)
|
||||
mock_session = MagicMock()
|
||||
mock_session.scalar.return_value = mock_workflow
|
||||
|
||||
result = workflow_service.update_workflow(
|
||||
session=mock_session,
|
||||
workflow_id=workflow_id,
|
||||
tenant_id=tenant_id,
|
||||
account_id=account_id,
|
||||
data={"marked_name": "Updated Name"},
|
||||
workflow_ref=workflow_ref,
|
||||
)
|
||||
|
||||
stmt = mock_session.scalar.call_args.args[0]
|
||||
compiled = stmt.compile()
|
||||
statement = str(compiled)
|
||||
assert "workflows.id" in statement
|
||||
assert "workflows.tenant_id" in statement
|
||||
assert "workflows.app_id" in statement
|
||||
assert workflow_id in compiled.params.values()
|
||||
assert tenant_id in compiled.params.values()
|
||||
assert app_id in compiled.params.values()
|
||||
assert result == mock_workflow
|
||||
|
||||
# ==================== Delete Workflow Tests ====================
|
||||
# These tests verify workflow deletion with safety checks
|
||||
|
||||
@ -1085,6 +1119,34 @@ class TestWorkflowService:
|
||||
assert result is True
|
||||
mock_session.delete.assert_called_once_with(mock_workflow)
|
||||
|
||||
def test_delete_workflow_with_ref_scopes_lookup_to_app(self, workflow_service: WorkflowService):
|
||||
"""Test delete_workflow includes the trusted app owner in the lookup."""
|
||||
workflow_id = "workflow-123"
|
||||
tenant_id = "tenant-456"
|
||||
app_id = "app-789"
|
||||
workflow_ref = WorkflowRefService.create_app_workflow_ref(
|
||||
SimpleNamespace(id=app_id, tenant_id=tenant_id), workflow_id
|
||||
)
|
||||
mock_workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock(workflow_id=workflow_id, version="v1")
|
||||
mock_session = MagicMock()
|
||||
mock_session.scalar.side_effect = [mock_workflow, None, None]
|
||||
|
||||
result = workflow_service.delete_workflow(
|
||||
session=mock_session, workflow_id=workflow_id, tenant_id=tenant_id, workflow_ref=workflow_ref
|
||||
)
|
||||
|
||||
stmt = mock_session.scalar.call_args_list[0].args[0]
|
||||
compiled = stmt.compile()
|
||||
statement = str(compiled)
|
||||
assert "workflows.id" in statement
|
||||
assert "workflows.tenant_id" in statement
|
||||
assert "workflows.app_id" in statement
|
||||
assert workflow_id in compiled.params.values()
|
||||
assert tenant_id in compiled.params.values()
|
||||
assert app_id in compiled.params.values()
|
||||
assert result is True
|
||||
mock_session.delete.assert_called_once_with(mock_workflow)
|
||||
|
||||
def test_delete_workflow_draft_raises_error(self, workflow_service: WorkflowService):
|
||||
"""
|
||||
Test delete_workflow raises error when trying to delete draft.
|
||||
|
||||
Reference in New Issue
Block a user