mirror of
https://github.com/langgenius/dify.git
synced 2026-05-24 19:07:53 +08:00
Compare commits
1 Commits
main
...
codex/fix-
| Author | SHA1 | Date | |
|---|---|---|---|
| 5d5da26948 |
@ -1,5 +1,3 @@
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
@ -82,7 +80,7 @@ class AgentRosterDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id: UUID):
|
||||
def get(self, agent_id):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return _agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id))
|
||||
|
||||
@ -91,7 +89,7 @@ class AgentRosterDetailApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def patch(self, agent_id: UUID):
|
||||
def patch(self, agent_id):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
return _agent_roster_service().update_roster_agent(
|
||||
@ -102,7 +100,7 @@ class AgentRosterDetailApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, agent_id: UUID):
|
||||
def delete(self, agent_id):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id)
|
||||
return "", 204
|
||||
@ -113,7 +111,7 @@ class AgentRosterVersionsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id: UUID):
|
||||
def get(self, agent_id):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return {"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))}
|
||||
|
||||
@ -123,7 +121,7 @@ class AgentRosterVersionDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id: UUID, version_id: UUID):
|
||||
def get(self, agent_id, version_id):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return _agent_roster_service().get_agent_version_detail(
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
import flask_restx
|
||||
from flask_restx import Resource
|
||||
@ -156,7 +155,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc(description="Get all API keys for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||
def get(self, resource_id: UUID):
|
||||
def get(self, resource_id): # type: ignore
|
||||
"""Get all API keys for an app"""
|
||||
return super().get(resource_id)
|
||||
|
||||
@ -165,7 +164,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id: UUID):
|
||||
def post(self, resource_id): # type: ignore
|
||||
"""Create a new API key for an app"""
|
||||
return super().post(resource_id)
|
||||
|
||||
@ -181,9 +180,9 @@ class AppApiKeyResource(BaseApiKeyResource):
|
||||
@console_ns.doc(description="Delete an API key for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
|
||||
@console_ns.response(204, "API key deleted successfully")
|
||||
def delete(self, resource_id: UUID, api_key_id: UUID):
|
||||
def delete(self, resource_id, api_key_id):
|
||||
"""Delete an API key for an app"""
|
||||
return super().delete(str(resource_id), str(api_key_id))
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
@ -196,7 +195,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc(description="Get all API keys for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||
def get(self, resource_id: UUID):
|
||||
def get(self, resource_id): # type: ignore
|
||||
"""Get all API keys for a dataset"""
|
||||
return super().get(resource_id)
|
||||
|
||||
@ -205,7 +204,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id: UUID):
|
||||
def post(self, resource_id): # type: ignore
|
||||
"""Create a new API key for a dataset"""
|
||||
return super().post(resource_id)
|
||||
|
||||
@ -221,9 +220,9 @@ class DatasetApiKeyResource(BaseApiKeyResource):
|
||||
@console_ns.doc(description="Delete an API key for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
|
||||
@console_ns.response(204, "API key deleted successfully")
|
||||
def delete(self, resource_id: UUID, api_key_id: UUID):
|
||||
def delete(self, resource_id, api_key_id):
|
||||
"""Delete an API key for a dataset"""
|
||||
return super().delete(str(resource_id), str(api_key_id))
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_model = Dataset
|
||||
|
||||
@ -159,15 +159,13 @@ class AppAnnotationSettingUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, app_id: UUID, annotation_setting_id: UUID):
|
||||
annotation_setting_id_str = str(annotation_setting_id)
|
||||
def post(self, app_id: UUID, annotation_setting_id):
|
||||
annotation_setting_id = str(annotation_setting_id)
|
||||
|
||||
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold}
|
||||
result = AppAnnotationService.update_app_annotation_setting(
|
||||
str(app_id), annotation_setting_id_str, setting_args
|
||||
)
|
||||
result = AppAnnotationService.update_app_annotation_setting(str(app_id), annotation_setting_id, setting_args)
|
||||
return result, 200
|
||||
|
||||
|
||||
@ -183,9 +181,9 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def get(self, app_id: UUID, job_id: UUID, action: str):
|
||||
job_id_str = str(job_id)
|
||||
app_annotation_job_key = f"{action}_app_annotation_job_{job_id_str}"
|
||||
def get(self, app_id: UUID, job_id, action):
|
||||
job_id = str(job_id)
|
||||
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
|
||||
cache_result = redis_client.get(app_annotation_job_key)
|
||||
if cache_result is None:
|
||||
raise ValueError("The job does not exist.")
|
||||
@ -193,10 +191,10 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
job_status = cache_result.decode()
|
||||
error_msg = ""
|
||||
if job_status == "error":
|
||||
app_annotation_error_key = f"{action}_app_annotation_error_{job_id_str}"
|
||||
app_annotation_error_key = f"{action}_app_annotation_error_{str(job_id)}"
|
||||
error_msg = redis_client.get(app_annotation_error_key).decode()
|
||||
|
||||
return {"job_id": job_id_str, "job_status": job_status, "error_msg": error_msg}, 200
|
||||
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations")
|
||||
|
||||
@ -97,7 +97,7 @@ class AppImportConfirmApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, import_id: str):
|
||||
def post(self, import_id):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
|
||||
@ -131,7 +131,7 @@ class CompletionMessageStopApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model, task_id: str):
|
||||
def post(self, app_model, task_id):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
@ -212,7 +212,7 @@ class ChatMessageStopApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model, task_id: str):
|
||||
def post(self, app_model, task_id):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import abort, request
|
||||
@ -165,10 +164,10 @@ class CompletionConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def get(self, app_model, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
def get(self, app_model, conversation_id):
|
||||
conversation_id = str(conversation_id)
|
||||
return ConversationMessageDetailResponse.model_validate(
|
||||
_get_conversation(app_model, conversation_id_str), from_attributes=True
|
||||
_get_conversation(app_model, conversation_id), from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("delete_completion_conversation")
|
||||
@ -182,12 +181,12 @@ class CompletionConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def delete(self, app_model, conversation_id: UUID):
|
||||
def delete(self, app_model, conversation_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
conversation_id_str = str(conversation_id)
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
try:
|
||||
ConversationService.delete(app_model, conversation_id_str, current_user)
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
@ -318,10 +317,10 @@ class ChatConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
def get(self, app_model, conversation_id):
|
||||
conversation_id = str(conversation_id)
|
||||
return ConversationDetailResponse.model_validate(
|
||||
_get_conversation(app_model, conversation_id_str), from_attributes=True
|
||||
_get_conversation(app_model, conversation_id), from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("delete_chat_conversation")
|
||||
@ -335,12 +334,12 @@ class ChatConversationDetailApi(Resource):
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_model, conversation_id: UUID):
|
||||
def delete(self, app_model, conversation_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
conversation_id_str = str(conversation_id)
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
try:
|
||||
ConversationService.delete(app_model, conversation_id_str, current_user)
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
@ -163,7 +162,7 @@ class AppMCPServerRefreshController(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, server_id: UUID):
|
||||
def get(self, server_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
server = db.session.scalar(
|
||||
select(AppMCPServer)
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -337,13 +336,13 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model, message_id: UUID):
|
||||
def get(self, app_model, message_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
message_id_str = str(message_id)
|
||||
message_id = str(message_id)
|
||||
|
||||
try:
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model, message_id=message_id_str, user=current_user, invoke_from=InvokeFrom.DEBUGGER
|
||||
app_model=app_model, message_id=message_id, user=current_user, invoke_from=InvokeFrom.DEBUGGER
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message not found")
|
||||
@ -418,10 +417,10 @@ class MessageApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model, message_id: str):
|
||||
message_id_str = str(message_id)
|
||||
message_id = str(message_id)
|
||||
|
||||
message = db.session.scalar(
|
||||
select(Message).where(Message.id == message_id_str, Message.app_id == app_model.id).limit(1)
|
||||
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
|
||||
)
|
||||
|
||||
if not message:
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Literal, cast
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -368,14 +367,14 @@ class WorkflowRunDetailApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
def get(self, app_model: App, run_id):
|
||||
"""
|
||||
Get workflow run detail
|
||||
"""
|
||||
run_id_str = str(run_id)
|
||||
run_id = str(run_id)
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
workflow_run = workflow_run_service.get_workflow_run(app_model=app_model, run_id=run_id_str)
|
||||
workflow_run = workflow_run_service.get_workflow_run(app_model=app_model, run_id=run_id)
|
||||
if workflow_run is None:
|
||||
raise NotFoundError("Workflow run not found")
|
||||
|
||||
@ -397,17 +396,17 @@ class WorkflowRunNodeExecutionListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
def get(self, app_model: App, run_id):
|
||||
"""
|
||||
Get workflow run node execution list
|
||||
"""
|
||||
run_id_str = str(run_id)
|
||||
run_id = str(run_id)
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
user = cast("Account | EndUser", current_user)
|
||||
node_executions = workflow_run_service.get_workflow_run_node_executions(
|
||||
app_model=app_model,
|
||||
run_id=run_id_str,
|
||||
run_id=run_id,
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from uuid import UUID
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -89,10 +87,10 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
@console_ns.response(204, "Binding deleted successfully")
|
||||
def delete(self, binding_id: UUID):
|
||||
def delete(self, binding_id):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(current_tenant_id, str(binding_id))
|
||||
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
@ -159,15 +158,16 @@ class OAuthDataSourceSync(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str, binding_id: UUID):
|
||||
binding_id_str = str(binding_id)
|
||||
def get(self, provider, binding_id):
|
||||
provider = str(provider)
|
||||
binding_id = str(binding_id)
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
||||
if not oauth_provider:
|
||||
return {"error": "Invalid provider"}, 400
|
||||
try:
|
||||
oauth_provider.sync_data_source(binding_id_str)
|
||||
oauth_provider.sync_data_source(binding_id)
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.exception(
|
||||
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal, cast
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
@ -294,7 +293,7 @@ class DataSourceNotionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
|
||||
def get(self, page_id: UUID, page_type: str):
|
||||
def get(self, page_id, page_type):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict())
|
||||
@ -307,11 +306,11 @@ class DataSourceNotionApi(Resource):
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
||||
page_id_str = str(page_id)
|
||||
page_id = str(page_id)
|
||||
|
||||
extractor = NotionExtractor(
|
||||
notion_workspace_id="",
|
||||
notion_obj_id=page_id_str,
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=credential.get("integration_secret"),
|
||||
tenant_id=current_tenant_id,
|
||||
@ -368,7 +367,7 @@ class DataSourceNotionDatasetSyncApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -386,7 +385,7 @@ class DataSourceNotionDocumentSyncApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
def get(self, dataset_id, document_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -512,7 +511,7 @@ class DatasetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -566,7 +565,7 @@ class DatasetApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id: UUID):
|
||||
def patch(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -614,7 +613,7 @@ class DatasetApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Dataset deleted successfully")
|
||||
def delete(self, dataset_id: UUID):
|
||||
def delete(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@ -644,7 +643,7 @@ class DatasetUseCheckApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset_is_using = DatasetService.dataset_use_check(dataset_id_str)
|
||||
@ -664,7 +663,7 @@ class DatasetQueryApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -804,7 +803,7 @@ class DatasetRelatedAppListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -840,11 +839,11 @@ class DatasetIndexingStatusApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset_id = str(dataset_id)
|
||||
documents = db.session.scalars(
|
||||
select(Document).where(Document.dataset_id == dataset_id_str, Document.tenant_id == current_tenant_id)
|
||||
select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == current_tenant_id)
|
||||
).all()
|
||||
documents_status = []
|
||||
for document in documents:
|
||||
@ -952,15 +951,15 @@ class DatasetApiDeleteApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, api_key_id: UUID):
|
||||
def delete(self, api_key_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
api_key_id_str = str(api_key_id)
|
||||
api_key_id = str(api_key_id)
|
||||
key = db.session.scalar(
|
||||
select(ApiToken)
|
||||
.where(
|
||||
ApiToken.tenant_id == current_tenant_id,
|
||||
ApiToken.type == self.resource_type,
|
||||
ApiToken.id == api_key_id_str,
|
||||
ApiToken.id == api_key_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
@ -985,7 +984,7 @@ class DatasetEnableApiApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, dataset_id: UUID, status: str):
|
||||
def post(self, dataset_id, status):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
DatasetService.update_dataset_api_status(dataset_id_str, status == "enable")
|
||||
@ -1037,7 +1036,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, vector_type: str):
|
||||
def get(self, vector_type):
|
||||
return dump_response(
|
||||
RetrievalSettingResponse,
|
||||
_get_retrieval_methods_by_vector_type(vector_type, is_mock=True),
|
||||
@ -1054,7 +1053,7 @@ class DatasetErrorDocs(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -1079,7 +1078,7 @@ class DatasetPermissionUserListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -1109,7 +1108,7 @@ class DatasetAutoDisableLogApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
|
||||
@ -5,7 +5,6 @@ from collections.abc import Sequence
|
||||
from contextlib import ExitStack
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, cast
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request, send_file
|
||||
@ -316,9 +315,9 @@ class DatasetDocumentListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset_id = str(dataset_id)
|
||||
raw_args = request.args.to_dict()
|
||||
param = DocumentDatasetListParam.model_validate(raw_args)
|
||||
page = param.page
|
||||
@ -343,7 +342,7 @@ class DatasetDocumentListApi(Resource):
|
||||
)
|
||||
except (ArgumentTypeError, ValueError, Exception):
|
||||
fetch = False
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
@ -352,7 +351,7 @@ class DatasetDocumentListApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
query = select(Document).where(Document.dataset_id == dataset_id_str, Document.tenant_id == current_tenant_id)
|
||||
query = select(Document).where(Document.dataset_id == str(dataset_id), Document.tenant_id == current_tenant_id)
|
||||
|
||||
if status:
|
||||
query = DocumentService.apply_display_status_filter(query, status)
|
||||
@ -373,7 +372,7 @@ class DatasetDocumentListApi(Resource):
|
||||
sa.select(
|
||||
DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count")
|
||||
)
|
||||
.where(DocumentSegment.dataset_id == dataset_id_str)
|
||||
.where(DocumentSegment.dataset_id == str(dataset_id))
|
||||
.group_by(DocumentSegment.document_id)
|
||||
.subquery()
|
||||
)
|
||||
@ -445,11 +444,11 @@ class DatasetDocumentListApi(Resource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
|
||||
@console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__])
|
||||
def post(self, dataset_id: UUID):
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset_id = str(dataset_id)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -473,7 +472,7 @@ class DatasetDocumentListApi(Resource):
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -491,9 +490,9 @@ class DatasetDocumentListApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Documents deleted successfully")
|
||||
def delete(self, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
def delete(self, dataset_id):
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
@ -583,11 +582,11 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
def get(self, dataset_id, document_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
@ -625,7 +624,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
data_process_rule_dict,
|
||||
document.doc_form,
|
||||
"English",
|
||||
dataset_id_str,
|
||||
dataset_id,
|
||||
)
|
||||
return estimate_response.model_dump(), 200
|
||||
except LLMBadRequestError:
|
||||
@ -648,10 +647,11 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, batch: str):
|
||||
def get(self, dataset_id, batch):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
documents = self.get_batch_documents(dataset_id_str, batch)
|
||||
dataset_id = str(dataset_id)
|
||||
batch = str(batch)
|
||||
documents = self.get_batch_documents(dataset_id, batch)
|
||||
if not documents:
|
||||
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
|
||||
data_process_rule = documents[0].dataset_process_rule
|
||||
@ -725,7 +725,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
data_process_rule_dict,
|
||||
document.doc_form,
|
||||
"English",
|
||||
dataset_id_str,
|
||||
dataset_id,
|
||||
)
|
||||
return response.model_dump(), 200
|
||||
except LLMBadRequestError:
|
||||
@ -745,9 +745,10 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, batch: str):
|
||||
dataset_id_str = str(dataset_id)
|
||||
documents = self.get_batch_documents(dataset_id_str, batch)
|
||||
def get(self, dataset_id, batch):
|
||||
dataset_id = str(dataset_id)
|
||||
batch = str(batch)
|
||||
documents = self.get_batch_documents(dataset_id, batch)
|
||||
documents_status = []
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
@ -799,16 +800,16 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
def get(self, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
completed_segments = (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document_id_str),
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
@ -817,7 +818,7 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
total_segments = (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.document_id == str(document_id_str),
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
@ -860,10 +861,10 @@ class DocumentApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
def get(self, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
metadata = request.args.get("metadata", "all")
|
||||
if metadata not in self.METADATA_CHOICES:
|
||||
@ -872,7 +873,7 @@ class DocumentApi(DocumentResource):
|
||||
if metadata == "only":
|
||||
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
|
||||
elif metadata == "without":
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id_str)
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
response = {
|
||||
"id": document.id,
|
||||
@ -906,7 +907,7 @@ class DocumentApi(DocumentResource):
|
||||
"need_summary": document.need_summary if document.need_summary is not None else False,
|
||||
}
|
||||
else:
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id_str)
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
response = {
|
||||
"id": document.id,
|
||||
@ -949,16 +950,16 @@ class DocumentApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Document deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
def delete(self, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
try:
|
||||
DocumentService.delete_document(document)
|
||||
@ -1002,10 +1003,10 @@ class DocumentBatchDownloadZipApi(DocumentResource):
|
||||
payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset_id = str(dataset_id)
|
||||
document_ids: list[str] = [str(document_id) for document_id in payload.document_ids]
|
||||
upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
|
||||
dataset_id=dataset_id_str,
|
||||
dataset_id=dataset_id,
|
||||
document_ids=document_ids,
|
||||
tenant_id=current_tenant_id,
|
||||
current_user=current_user,
|
||||
@ -1043,11 +1044,11 @@ class DocumentProcessingApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["pause", "resume"]):
|
||||
def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||
if not current_user.is_dataset_editor:
|
||||
@ -1091,11 +1092,11 @@ class DocumentMetadataApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, dataset_id: UUID, document_id: UUID):
|
||||
def put(self, dataset_id, document_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {})
|
||||
|
||||
@ -1140,10 +1141,10 @@ class DocumentStatusApi(DocumentResource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def patch(self, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]):
|
||||
def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
@ -1178,16 +1179,16 @@ class DocumentPauseApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Document paused successfully")
|
||||
def patch(self, dataset_id: UUID, document_id: UUID):
|
||||
def patch(self, dataset_id, document_id):
|
||||
"""pause document."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
document = DocumentService.get_document(dataset.id, document_id_str)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
|
||||
# 404 if document not found
|
||||
if document is None:
|
||||
@ -1213,14 +1214,14 @@ class DocumentRecoverApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Document resumed successfully")
|
||||
def patch(self, dataset_id: UUID, document_id: UUID):
|
||||
def patch(self, dataset_id, document_id):
|
||||
"""recover document."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
document = DocumentService.get_document(dataset.id, document_id_str)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
|
||||
# 404 if document not found
|
||||
if document is None:
|
||||
@ -1246,11 +1247,11 @@ class DocumentRetryApi(DocumentResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[DocumentRetryPayload.__name__])
|
||||
@console_ns.response(204, "Documents retry started successfully")
|
||||
def post(self, dataset_id: UUID):
|
||||
def post(self, dataset_id):
|
||||
"""retry document."""
|
||||
payload = DocumentRetryPayload.model_validate(console_ns.payload or {})
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
retry_documents = []
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -1276,7 +1277,7 @@ class DocumentRetryApi(DocumentResource):
|
||||
logger.exception("Failed to retry document, document id: %s", document_id)
|
||||
continue
|
||||
# retry document
|
||||
DocumentService.retry_document(dataset_id_str, retry_documents)
|
||||
DocumentService.retry_document(dataset_id, retry_documents)
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -1288,7 +1289,7 @@ class DocumentRenameApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID):
|
||||
def post(self, dataset_id, document_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.is_dataset_editor:
|
||||
@ -1300,7 +1301,7 @@ class DocumentRenameApi(DocumentResource):
|
||||
payload = DocumentRenamePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
document = DocumentService.rename_document(str(dataset_id), str(document_id), payload.name)
|
||||
document = DocumentService.rename_document(dataset_id, document_id, payload.name)
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
@ -1313,15 +1314,15 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
def get(self, dataset_id, document_id):
|
||||
"""sync website document."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
if document.tenant_id != current_tenant_id:
|
||||
@ -1332,7 +1333,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
||||
if DocumentService.check_archived(document):
|
||||
raise ArchivedDocumentImmutableError()
|
||||
# sync document
|
||||
DocumentService.sync_website_document(dataset_id_str, document)
|
||||
DocumentService.sync_website_document(dataset_id, document)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -1342,19 +1343,19 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
def get(self, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
document = DocumentService.get_document(dataset.id, document_id_str)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
log = db.session.scalar(
|
||||
select(DocumentPipelineExecutionLog)
|
||||
.where(DocumentPipelineExecutionLog.document_id == document_id_str)
|
||||
.where(DocumentPipelineExecutionLog.document_id == document_id)
|
||||
.order_by(DocumentPipelineExecutionLog.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
@ -1391,7 +1392,7 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id: UUID):
|
||||
def post(self, dataset_id):
|
||||
"""
|
||||
Generate summary index for specified documents.
|
||||
|
||||
@ -1400,10 +1401,10 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
then asynchronously generates summary indexes for the provided documents.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset_id = str(dataset_id)
|
||||
|
||||
# Get dataset
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
@ -1437,7 +1438,7 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
raise ValueError("Summary index is not enabled for this dataset. Please enable it in the dataset settings.")
|
||||
|
||||
# Verify all documents exist and belong to the dataset
|
||||
documents = DocumentService.get_documents_by_ids(dataset_id_str, document_list)
|
||||
documents = DocumentService.get_documents_by_ids(dataset_id, document_list)
|
||||
|
||||
if len(documents) != len(document_list):
|
||||
found_ids = {doc.id for doc in documents}
|
||||
@ -1451,7 +1452,7 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
if documents_to_update:
|
||||
document_ids_to_update = [str(doc.id) for doc in documents_to_update]
|
||||
DocumentService.update_documents_need_summary(
|
||||
dataset_id=dataset_id_str,
|
||||
dataset_id=dataset_id,
|
||||
document_ids=document_ids_to_update,
|
||||
need_summary=True,
|
||||
)
|
||||
@ -1464,11 +1465,11 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
continue
|
||||
|
||||
# Dispatch async task
|
||||
generate_summary_index_task.delay(dataset_id_str, document.id)
|
||||
generate_summary_index_task.delay(dataset_id, document.id)
|
||||
logger.info(
|
||||
"Dispatched summary generation task for document %s in dataset %s",
|
||||
document.id,
|
||||
dataset_id_str,
|
||||
dataset_id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
@ -1484,7 +1485,7 @@ class DocumentSummaryStatusApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
def get(self, dataset_id, document_id):
|
||||
"""
|
||||
Get summary index generation status for a document.
|
||||
|
||||
@ -1498,11 +1499,11 @@ class DocumentSummaryStatusApi(DocumentResource):
|
||||
- summaries: List of summary records with status and content preview
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
|
||||
# Get dataset
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
@ -1516,8 +1517,8 @@ class DocumentSummaryStatusApi(DocumentResource):
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
result = SummaryIndexService.get_document_summary_status_detail(
|
||||
document_id=document_id_str,
|
||||
dataset_id=dataset_id_str,
|
||||
document_id=document_id,
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
|
||||
return result, 200
|
||||
|
||||
@ -1,6 +1,4 @@
|
||||
import uuid
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal
|
||||
@ -115,12 +113,12 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
def get(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
@ -129,7 +127,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
@ -150,7 +148,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
query = (
|
||||
select(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == document_id_str,
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.tenant_id == current_tenant_id,
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
@ -203,9 +201,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
if segment_ids:
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
summary_records = SummaryIndexService.get_segments_summaries(
|
||||
segment_ids=segment_ids, dataset_id=dataset_id_str
|
||||
)
|
||||
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
|
||||
# Only include enabled summaries (already filtered by service)
|
||||
summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()}
|
||||
|
||||
@ -230,19 +226,19 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Segments deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID):
|
||||
def delete(self, dataset_id, document_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
segment_ids = request.args.getlist("segment_id")
|
||||
@ -266,15 +262,15 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["enable", "disable"]):
|
||||
def patch(self, dataset_id, document_id, action):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check user's model setting
|
||||
@ -325,17 +321,17 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID):
|
||||
def post(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check document
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
if not current_user.is_dataset_editor:
|
||||
@ -365,7 +361,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
payload_dict = payload.model_dump(exclude_none=True)
|
||||
SegmentService.segment_create_args_validate(payload_dict, document)
|
||||
segment = SegmentService.create_segment(payload_dict, document, dataset)
|
||||
return {"data": _get_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
|
||||
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
|
||||
@ -376,19 +372,19 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
def patch(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
@ -408,10 +404,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
segment_id_str = str(segment_id)
|
||||
segment_id = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
@ -432,33 +428,33 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
segment = SegmentService.update_segment(
|
||||
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
|
||||
)
|
||||
return {"data": _get_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
|
||||
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Segment deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
def delete(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id_str = str(segment_id)
|
||||
segment_id = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
@ -487,17 +483,17 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[BatchImportPayload.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID):
|
||||
def post(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check document
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
@ -521,8 +517,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
batch_create_segment_to_index_task.delay(
|
||||
str(job_id),
|
||||
upload_file_id,
|
||||
dataset_id_str,
|
||||
document_id_str,
|
||||
dataset_id,
|
||||
document_id,
|
||||
current_tenant_id,
|
||||
current_user.id,
|
||||
)
|
||||
@ -534,7 +530,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, job_id=None, dataset_id: UUID | None = None, document_id: UUID | None = None):
|
||||
def get(self, job_id=None, dataset_id=None, document_id=None):
|
||||
if job_id is None:
|
||||
raise NotFound("The job does not exist.")
|
||||
job_id = str(job_id)
|
||||
@ -555,24 +551,24 @@ class ChildChunkAddApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
def post(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check document
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id_str = str(segment_id)
|
||||
segment_id = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
@ -610,26 +606,26 @@ class ChildChunkAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
def get(self, dataset_id, document_id, segment_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id_str = str(segment_id)
|
||||
segment_id = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
@ -646,9 +642,7 @@ class ChildChunkAddApi(Resource):
|
||||
limit = min(args.limit, 100)
|
||||
keyword = args.keyword
|
||||
|
||||
child_chunks = SegmentService.get_child_chunks(
|
||||
segment_id_str, document_id_str, dataset_id_str, page, limit, keyword
|
||||
)
|
||||
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
|
||||
return {
|
||||
"data": marshal(child_chunks.items, child_chunk_fields),
|
||||
"total": child_chunks.total,
|
||||
@ -662,26 +656,26 @@ class ChildChunkAddApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
def patch(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id_str = str(segment_id)
|
||||
segment_id = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
@ -711,39 +705,39 @@ class ChildChunkUpdateApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Child chunk deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
|
||||
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id_str = str(segment_id)
|
||||
segment_id = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
# check child chunk
|
||||
child_chunk_id_str = str(child_chunk_id)
|
||||
child_chunk_id = str(child_chunk_id)
|
||||
child_chunk = db.session.scalar(
|
||||
select(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id_str),
|
||||
ChildChunk.id == str(child_chunk_id),
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id_str,
|
||||
ChildChunk.document_id == document_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
@ -768,39 +762,39 @@ class ChildChunkUpdateApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
|
||||
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id_str = str(segment_id)
|
||||
segment_id = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
# check child chunk
|
||||
child_chunk_id_str = str(child_chunk_id)
|
||||
child_chunk_id = str(child_chunk_id)
|
||||
child_chunk = db.session.scalar(
|
||||
select(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id_str),
|
||||
ChildChunk.id == str(child_chunk_id),
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id_str,
|
||||
ChildChunk.document_id == document_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal
|
||||
from pydantic import BaseModel, Field
|
||||
@ -177,11 +175,11 @@ class ExternalApiTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, external_knowledge_api_id: UUID):
|
||||
def get(self, external_knowledge_api_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(
|
||||
external_knowledge_api_id_str, current_tenant_id
|
||||
external_knowledge_api_id, current_tenant_id
|
||||
)
|
||||
if external_knowledge_api is None:
|
||||
raise NotFound("API template not found.")
|
||||
@ -192,9 +190,9 @@ class ExternalApiTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
|
||||
def patch(self, external_knowledge_api_id: UUID):
|
||||
def patch(self, external_knowledge_api_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
|
||||
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
|
||||
ExternalDatasetService.validate_api_list(payload.settings)
|
||||
@ -202,7 +200,7 @@ class ExternalApiTemplateApi(Resource):
|
||||
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
|
||||
tenant_id=current_tenant_id,
|
||||
user_id=current_user.id,
|
||||
external_knowledge_api_id=external_knowledge_api_id_str,
|
||||
external_knowledge_api_id=external_knowledge_api_id,
|
||||
args=payload.model_dump(),
|
||||
)
|
||||
|
||||
@ -212,14 +210,14 @@ class ExternalApiTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(204, "External knowledge API deleted successfully")
|
||||
def delete(self, external_knowledge_api_id: UUID):
|
||||
def delete(self, external_knowledge_api_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
||||
raise Forbidden()
|
||||
|
||||
ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id_str)
|
||||
ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id)
|
||||
return "", 204
|
||||
|
||||
|
||||
@ -232,12 +230,12 @@ class ExternalApiUseCheckApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, external_knowledge_api_id: UUID):
|
||||
def get(self, external_knowledge_api_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
|
||||
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
|
||||
external_knowledge_api_id_str, current_tenant_id
|
||||
external_knowledge_api_id, current_tenant_id
|
||||
)
|
||||
return {"is_using": external_knowledge_api_is_using, "count": count}, 200
|
||||
|
||||
@ -288,7 +286,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id: UUID):
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
|
||||
@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import Field, field_validator
|
||||
@ -119,7 +118,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id: UUID):
|
||||
def post(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
@ -43,7 +42,7 @@ class DatasetMetadataCreateApi(Resource):
|
||||
@enterprise_license_required
|
||||
@console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
|
||||
def post(self, dataset_id: UUID):
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
|
||||
|
||||
@ -63,7 +62,7 @@ class DatasetMetadataCreateApi(Resource):
|
||||
@console_ns.response(
|
||||
200, "Metadata retrieved successfully", console_ns.models[DatasetMetadataListResponse.__name__]
|
||||
)
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -80,7 +79,7 @@ class DatasetMetadataApi(Resource):
|
||||
@enterprise_license_required
|
||||
@console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
|
||||
def patch(self, dataset_id: UUID, metadata_id: UUID):
|
||||
def patch(self, dataset_id, metadata_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
|
||||
name = payload.name
|
||||
@ -100,7 +99,7 @@ class DatasetMetadataApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(204, "Metadata deleted successfully")
|
||||
def delete(self, dataset_id: UUID, metadata_id: UUID):
|
||||
def delete(self, dataset_id, metadata_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
metadata_id_str = str(metadata_id)
|
||||
@ -137,7 +136,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(204, "Action completed successfully")
|
||||
def post(self, dataset_id: UUID, action: Literal["enable", "disable"]):
|
||||
def post(self, dataset_id, action: Literal["enable", "disable"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -165,7 +164,7 @@ class DocumentMetadataEditApi(Resource):
|
||||
204,
|
||||
"Documents metadata updated successfully",
|
||||
)
|
||||
def post(self, dataset_id: UUID):
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
|
||||
@ -105,7 +105,7 @@ class RagPipelineImportConfirmApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_model)
|
||||
def post(self, import_id: str):
|
||||
def post(self, import_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal, cast
|
||||
from uuid import UUID
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource
|
||||
@ -876,14 +875,14 @@ class RagPipelineWorkflowRunDetailApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def get(self, pipeline: Pipeline, run_id: UUID):
|
||||
def get(self, pipeline: Pipeline, run_id):
|
||||
"""
|
||||
Get workflow run detail
|
||||
"""
|
||||
run_id_str = str(run_id)
|
||||
run_id = str(run_id)
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow_run = rag_pipeline_service.get_rag_pipeline_workflow_run(pipeline=pipeline, run_id=run_id_str)
|
||||
workflow_run = rag_pipeline_service.get_rag_pipeline_workflow_run(pipeline=pipeline, run_id=run_id)
|
||||
if workflow_run is None:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
@ -905,13 +904,13 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
|
||||
"""
|
||||
Get workflow run node execution list
|
||||
"""
|
||||
run_id_str = str(run_id)
|
||||
run_id = str(run_id)
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
user = cast("Account | EndUser", current_user)
|
||||
node_executions = rag_pipeline_service.get_rag_pipeline_workflow_run_node_executions(
|
||||
pipeline=pipeline,
|
||||
run_id=run_id_str,
|
||||
run_id=run_id,
|
||||
user=user,
|
||||
)
|
||||
|
||||
@ -961,15 +960,15 @@ class RagPipelineTransformApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id: UUID):
|
||||
def post(self, dataset_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
||||
raise Forbidden()
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset_id = str(dataset_id)
|
||||
rag_pipeline_transform_service = RagPipelineTransformService()
|
||||
result = rag_pipeline_transform_service.transform_dataset(dataset_id_str)
|
||||
result = rag_pipeline_transform_service.transform_dataset(dataset_id)
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@ -133,7 +133,7 @@ class CompletionApi(InstalledAppResource):
|
||||
)
|
||||
class CompletionStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app, task_id: str):
|
||||
def post(self, installed_app, task_id):
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
@ -209,7 +209,7 @@ class ChatApi(InstalledAppResource):
|
||||
)
|
||||
class ChatStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app, task_id: str):
|
||||
def post(self, installed_app, task_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
@ -92,7 +91,7 @@ class ConversationListApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationApi(InstalledAppResource):
|
||||
@console_ns.response(204, "Conversation deleted successfully")
|
||||
def delete(self, installed_app, c_id: UUID):
|
||||
def delete(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -115,7 +114,7 @@ class ConversationApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationRenameApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
|
||||
def post(self, installed_app, c_id: UUID):
|
||||
def post(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -146,7 +145,7 @@ class ConversationRenameApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationPinApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
|
||||
def patch(self, installed_app, c_id: UUID):
|
||||
def patch(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -170,7 +169,7 @@ class ConversationPinApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationUnPinApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
|
||||
def patch(self, installed_app, c_id: UUID):
|
||||
def patch(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
@ -96,18 +95,18 @@ class MessageListApi(InstalledAppResource):
|
||||
class MessageFeedbackApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
|
||||
@console_ns.response(200, "Feedback submitted successfully", console_ns.models[ResultResponse.__name__])
|
||||
def post(self, installed_app, message_id: UUID):
|
||||
def post(self, installed_app, message_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
|
||||
message_id_str = str(message_id)
|
||||
message_id = str(message_id)
|
||||
|
||||
payload = MessageFeedbackPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
MessageService.create_feedback(
|
||||
app_model=app_model,
|
||||
message_id=message_id_str,
|
||||
message_id=message_id,
|
||||
user=current_user,
|
||||
rating=FeedbackRating(payload.rating) if payload.rating else None,
|
||||
content=payload.content,
|
||||
@ -124,13 +123,13 @@ class MessageFeedbackApi(InstalledAppResource):
|
||||
)
|
||||
class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[MoreLikeThisQuery.__name__])
|
||||
def get(self, installed_app, message_id: UUID):
|
||||
def get(self, installed_app, message_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
message_id_str = str(message_id)
|
||||
message_id = str(message_id)
|
||||
|
||||
args = MoreLikeThisQuery.model_validate(request.args.to_dict())
|
||||
|
||||
@ -140,7 +139,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
response = AppGenerateService.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
message_id=message_id_str,
|
||||
message_id=message_id,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
streaming=streaming,
|
||||
)
|
||||
@ -170,18 +169,18 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
)
|
||||
class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SuggestedQuestionsResponse.__name__])
|
||||
def get(self, installed_app, message_id: UUID):
|
||||
def get(self, installed_app, message_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
message_id_str = str(message_id)
|
||||
message_id = str(message_id)
|
||||
|
||||
try:
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model, user=current_user, message_id=message_id_str, invoke_from=InvokeFrom.EXPLORE
|
||||
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message not found")
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from pydantic import TypeAdapter
|
||||
from werkzeug.exceptions import NotFound
|
||||
@ -67,15 +65,15 @@ class SavedMessageListApi(InstalledAppResource):
|
||||
)
|
||||
class SavedMessageApi(InstalledAppResource):
|
||||
@console_ns.response(204, "Saved message deleted successfully")
|
||||
def delete(self, installed_app, message_id: UUID):
|
||||
def delete(self, installed_app, message_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
|
||||
message_id_str = str(message_id)
|
||||
message_id = str(message_id)
|
||||
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
SavedMessageService.delete(app_model, current_user, message_id_str)
|
||||
SavedMessageService.delete(app_model, current_user, message_id)
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -153,7 +152,7 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, id: UUID):
|
||||
def get(self, id):
|
||||
api_based_extension_id = str(id)
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@ -169,7 +168,7 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, id: UUID):
|
||||
def post(self, id):
|
||||
api_based_extension_id = str(id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@ -197,7 +196,7 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, id: UUID):
|
||||
def delete(self, id):
|
||||
api_based_extension_id = str(id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -107,10 +106,10 @@ class FilePreviewApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
|
||||
def get(self, file_id: UUID):
|
||||
file_id_str = str(file_id)
|
||||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
text = FileService(db.engine).get_file_preview(file_id_str, tenant_id)
|
||||
text = FileService(db.engine).get_file_preview(file_id, tenant_id)
|
||||
return {"content": text}
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -132,17 +131,17 @@ class TagUpdateDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, tag_id: UUID):
|
||||
def patch(self, tag_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
tag_id_str = str(tag_id)
|
||||
tag_id = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagUpdateRequestPayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id_str)
|
||||
tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id_str)
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
response = TagResponse.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
@ -155,10 +154,10 @@ class TagUpdateDeleteApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@console_ns.response(204, "Tag deleted successfully")
|
||||
def delete(self, tag_id: UUID):
|
||||
tag_id_str = str(tag_id)
|
||||
def delete(self, tag_id):
|
||||
tag_id = str(tag_id)
|
||||
|
||||
TagService.delete_tag(tag_id_str)
|
||||
TagService.delete_tag(tag_id)
|
||||
|
||||
return "", 204
|
||||
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from urllib import parse
|
||||
from uuid import UUID
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource
|
||||
@ -176,7 +175,7 @@ class MemberCancelInviteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, member_id: UUID):
|
||||
def delete(self, member_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
@ -209,7 +208,7 @@ class MemberUpdateRoleApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, member_id: UUID):
|
||||
def put(self, member_id):
|
||||
payload = console_ns.payload or {}
|
||||
args = MemberRoleUpdatePayload.model_validate(payload)
|
||||
new_role = args.role
|
||||
@ -352,7 +351,7 @@ class OwnerTransfer(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_allow_transfer_owner
|
||||
def post(self, member_id: UUID):
|
||||
def post(self, member_id):
|
||||
payload = console_ns.payload or {}
|
||||
args = OwnerTransferPayload.model_validate(payload)
|
||||
|
||||
|
||||
@ -532,7 +532,7 @@ class ModelProviderAvailableModelApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, model_type: str):
|
||||
def get(self, model_type):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from urllib.parse import quote
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
@ -50,8 +49,8 @@ class ImagePreviewApi(Resource):
|
||||
415: "Unsupported file type",
|
||||
}
|
||||
)
|
||||
def get(self, file_id: UUID):
|
||||
file_id_str = str(file_id)
|
||||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
|
||||
args = FileSignatureQuery.model_validate(request.args.to_dict(flat=True))
|
||||
timestamp = args.timestamp
|
||||
@ -60,7 +59,7 @@ class ImagePreviewApi(Resource):
|
||||
|
||||
try:
|
||||
generator, mimetype = FileService(db.engine).get_image_preview(
|
||||
file_id=file_id_str,
|
||||
file_id=file_id,
|
||||
timestamp=timestamp,
|
||||
nonce=nonce,
|
||||
sign=sign,
|
||||
@ -92,14 +91,14 @@ class FilePreviewApi(Resource):
|
||||
415: "Unsupported file type",
|
||||
}
|
||||
)
|
||||
def get(self, file_id: UUID):
|
||||
file_id_str = str(file_id)
|
||||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
|
||||
args = FilePreviewQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
generator, upload_file = FileService(db.engine).get_file_generator_by_file_id(
|
||||
file_id=file_id_str,
|
||||
file_id=file_id,
|
||||
timestamp=args.timestamp,
|
||||
nonce=args.nonce,
|
||||
sign=args.sign,
|
||||
@ -160,10 +159,10 @@ class WorkspaceWebappLogoApi(Resource):
|
||||
415: "Unsupported file type",
|
||||
}
|
||||
)
|
||||
def get(self, workspace_id: UUID):
|
||||
workspace_id_str = str(workspace_id)
|
||||
def get(self, workspace_id):
|
||||
workspace_id = str(workspace_id)
|
||||
|
||||
custom_config = TenantService.get_custom_config(workspace_id_str)
|
||||
custom_config = TenantService.get_custom_config(workspace_id)
|
||||
webapp_logo_file_id = custom_config.get("replace_webapp_logo") if custom_config is not None else None
|
||||
|
||||
if not webapp_logo_file_id:
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from urllib.parse import quote
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
@ -46,19 +45,17 @@ class ToolFileApi(Resource):
|
||||
415: "Unsupported file type",
|
||||
}
|
||||
)
|
||||
def get(self, file_id: UUID, extension: str):
|
||||
file_id_str = str(file_id)
|
||||
def get(self, file_id, extension):
|
||||
file_id = str(file_id)
|
||||
|
||||
args = ToolFileQuery.model_validate(request.args.to_dict())
|
||||
if not verify_tool_file_signature(
|
||||
file_id=file_id_str, timestamp=args.timestamp, nonce=args.nonce, sign=args.sign
|
||||
):
|
||||
if not verify_tool_file_signature(file_id=file_id, timestamp=args.timestamp, nonce=args.nonce, sign=args.sign):
|
||||
raise Forbidden("Invalid request.")
|
||||
|
||||
try:
|
||||
tool_file_manager = ToolFileManager()
|
||||
stream, tool_file = tool_file_manager.get_file_generator_by_tool_file_id(
|
||||
file_id_str,
|
||||
file_id,
|
||||
)
|
||||
|
||||
if not stream or not tool_file:
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -79,10 +78,10 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
def get(self, app_model: App, job_id: UUID, action: str):
|
||||
def get(self, app_model: App, job_id, action):
|
||||
"""Get the status of an annotation reply action job."""
|
||||
job_id_str = str(job_id)
|
||||
app_annotation_job_key = f"{action}_app_annotation_job_{job_id_str}"
|
||||
job_id = str(job_id)
|
||||
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
|
||||
cache_result = redis_client.get(app_annotation_job_key)
|
||||
if cache_result is None:
|
||||
raise ValueError("The job does not exist.")
|
||||
@ -90,10 +89,10 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
job_status = cache_result.decode()
|
||||
error_msg = ""
|
||||
if job_status == "error":
|
||||
app_annotation_error_key = f"{action}_app_annotation_error_{job_id_str}"
|
||||
app_annotation_error_key = f"{action}_app_annotation_error_{str(job_id)}"
|
||||
error_msg = redis_client.get(app_annotation_error_key).decode()
|
||||
|
||||
return {"job_id": job_id_str, "job_status": job_status, "error_msg": error_msg}, 200
|
||||
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
|
||||
|
||||
|
||||
@service_api_ns.route("/apps/annotations")
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -196,7 +195,7 @@ class ConversationDetailApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""Delete a specific conversation."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -225,7 +224,7 @@ class ConversationRenameApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
def post(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
def post(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""Rename a conversation or auto-generate a name."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -267,7 +266,7 @@ class ConversationVariablesApi(Resource):
|
||||
service_api_ns.models[ConversationVariableInfiniteScrollPaginationResponse.__name__],
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
def get(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
def get(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""List all variables for a conversation.
|
||||
|
||||
Conversational variables are only available for chat applications.
|
||||
@ -313,7 +312,7 @@ class ConversationVariableDetailApi(Resource):
|
||||
service_api_ns.models[ConversationVariableResponse.__name__],
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
def put(self, app_model: App, end_user: EndUser, c_id: UUID, variable_id: UUID):
|
||||
def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
|
||||
"""Update a conversation variable's value.
|
||||
|
||||
Allows updating the value of a specific conversation variable.
|
||||
@ -324,13 +323,13 @@ class ConversationVariableDetailApi(Resource):
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
variable_id_str = str(variable_id)
|
||||
variable_id = str(variable_id)
|
||||
|
||||
payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
variable = ConversationService.update_conversation_variable(
|
||||
app_model, conversation_id, variable_id_str, end_user, payload.value
|
||||
app_model, conversation_id, variable_id, end_user, payload.value
|
||||
)
|
||||
return ConversationVariableResponse.model_validate(variable, from_attributes=True).model_dump(mode="json")
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -95,19 +94,19 @@ class MessageFeedbackApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, message_id: UUID):
|
||||
def post(self, app_model: App, end_user: EndUser, message_id):
|
||||
"""Submit feedback for a message.
|
||||
|
||||
Allows users to rate messages as like/dislike and provide optional feedback content.
|
||||
"""
|
||||
message_id_str = str(message_id)
|
||||
message_id = str(message_id)
|
||||
|
||||
payload = MessageFeedbackPayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
MessageService.create_feedback(
|
||||
app_model=app_model,
|
||||
message_id=message_id_str,
|
||||
message_id=message_id,
|
||||
user=end_user,
|
||||
rating=FeedbackRating(payload.rating) if payload.rating else None,
|
||||
content=payload.content,
|
||||
@ -160,19 +159,19 @@ class MessageSuggestedApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
|
||||
def get(self, app_model: App, end_user: EndUser, message_id: UUID):
|
||||
def get(self, app_model: App, end_user: EndUser, message_id):
|
||||
"""Get suggested follow-up questions for a message.
|
||||
|
||||
Returns AI-generated follow-up questions based on the message content.
|
||||
"""
|
||||
message_id_str = str(message_id)
|
||||
message_id = str(message_id)
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
try:
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model, user=end_user, message_id=message_id_str, invoke_from=InvokeFrom.SERVICE_API
|
||||
app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
||||
@ -337,7 +336,7 @@ class DatasetApi(DatasetApiResource):
|
||||
"Dataset retrieved successfully",
|
||||
service_api_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
|
||||
)
|
||||
def get(self, _, dataset_id: UUID):
|
||||
def get(self, _, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -404,7 +403,7 @@ class DatasetApi(DatasetApiResource):
|
||||
service_api_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, _, dataset_id: UUID):
|
||||
def patch(self, _, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -480,7 +479,7 @@ class DatasetApi(DatasetApiResource):
|
||||
}
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, _, dataset_id: UUID):
|
||||
def delete(self, _, dataset_id):
|
||||
"""
|
||||
Deletes a dataset given its ID.
|
||||
|
||||
@ -535,7 +534,7 @@ class DocumentStatusApi(DatasetApiResource):
|
||||
400: "Bad request - invalid action",
|
||||
}
|
||||
)
|
||||
def patch(self, tenant_id, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]):
|
||||
def patch(self, tenant_id, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
|
||||
"""
|
||||
Batch update document status.
|
||||
|
||||
|
||||
@ -374,7 +374,7 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_resource_check("documents", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id: UUID):
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create document by upload file."""
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
@ -395,6 +395,7 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
args["doc_language"] = "English"
|
||||
|
||||
# get dataset info
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
|
||||
indexing_technique = args.get("indexing_technique") or dataset.indexing_technique
|
||||
@ -585,17 +586,17 @@ class DocumentListApi(DatasetApiResource):
|
||||
404: "Dataset not found",
|
||||
}
|
||||
)
|
||||
def get(self, tenant_id, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
def get(self, tenant_id, dataset_id):
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
query_params = DocumentListQuery.model_validate(request.args.to_dict())
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
query = select(Document).where(Document.dataset_id == dataset_id_str, Document.tenant_id == tenant_id)
|
||||
query = select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == tenant_id)
|
||||
|
||||
if query_params.status:
|
||||
query = DocumentService.apply_display_status_filter(query, query_params.status)
|
||||
@ -645,7 +646,7 @@ class DocumentBatchDownloadZipApi(DatasetApiResource):
|
||||
}
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id: UUID):
|
||||
def post(self, tenant_id, dataset_id):
|
||||
payload = DocumentBatchDownloadZipPayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
|
||||
@ -680,17 +681,18 @@ class DocumentIndexingStatusApi(DatasetApiResource):
|
||||
404: "Dataset or documents not found",
|
||||
}
|
||||
)
|
||||
def get(self, tenant_id, dataset_id: UUID, batch: str):
|
||||
dataset_id_str = str(dataset_id)
|
||||
def get(self, tenant_id, dataset_id, batch):
|
||||
dataset_id = str(dataset_id)
|
||||
batch = str(batch)
|
||||
tenant_id = str(tenant_id)
|
||||
# get dataset
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# get documents
|
||||
documents = DocumentService.get_batch_documents(dataset_id_str, batch)
|
||||
documents = DocumentService.get_batch_documents(dataset_id, batch)
|
||||
if not documents:
|
||||
raise NotFound("Documents not found.")
|
||||
documents_status = []
|
||||
@ -755,7 +757,7 @@ class DocumentDownloadApi(DatasetApiResource):
|
||||
service_api_ns.models[UrlResponse.__name__],
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def get(self, tenant_id, dataset_id: UUID, document_id: UUID):
|
||||
def get(self, tenant_id, dataset_id, document_id):
|
||||
dataset = self.get_dataset(str(dataset_id), str(tenant_id))
|
||||
document = DocumentService.get_document(dataset.id, str(document_id))
|
||||
|
||||
@ -783,13 +785,13 @@ class DocumentApi(DatasetApiResource):
|
||||
404: "Document not found",
|
||||
}
|
||||
)
|
||||
def get(self, tenant_id, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
def get(self, tenant_id, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
|
||||
dataset = self.get_dataset(dataset_id_str, tenant_id)
|
||||
dataset = self.get_dataset(dataset_id, tenant_id)
|
||||
|
||||
document = DocumentService.get_document(dataset.id, document_id_str)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
@ -806,15 +808,15 @@ class DocumentApi(DatasetApiResource):
|
||||
has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True
|
||||
if has_summary_index and document.need_summary is True:
|
||||
summary_index_status = SummaryIndexService.get_document_summary_index_status(
|
||||
document_id=document_id_str,
|
||||
dataset_id=dataset_id_str,
|
||||
document_id=document_id,
|
||||
dataset_id=dataset_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
if metadata == "only":
|
||||
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
|
||||
elif metadata == "without":
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id_str)
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
@ -849,7 +851,7 @@ class DocumentApi(DatasetApiResource):
|
||||
"need_summary": document.need_summary if document.need_summary is not None else False,
|
||||
}
|
||||
else:
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id_str)
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
@ -916,21 +918,21 @@ class DocumentApi(DatasetApiResource):
|
||||
}
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id, dataset_id: UUID, document_id: UUID):
|
||||
def delete(self, tenant_id, dataset_id, document_id):
|
||||
"""Delete document."""
|
||||
document_id_str = str(document_id)
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
|
||||
# get dataset info
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
document = DocumentService.get_document(dataset.id, document_id_str)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
|
||||
# 404 if document not found
|
||||
if document is None:
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from uuid import UUID
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
|
||||
from controllers.service_api import service_api_ns
|
||||
@ -22,7 +20,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__])
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id: UUID):
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Perform hit testing on a dataset.
|
||||
|
||||
Tests retrieval performance for the specified dataset.
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask_login import current_user
|
||||
from werkzeug.exceptions import NotFound
|
||||
@ -58,7 +57,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
201, "Metadata created successfully", service_api_ns.models[DatasetMetadataResponse.__name__]
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id: UUID):
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create metadata for a dataset."""
|
||||
metadata_args = MetadataArgs.model_validate(service_api_ns.payload or {})
|
||||
|
||||
@ -84,7 +83,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
@service_api_ns.response(
|
||||
200, "Metadata retrieved successfully", service_api_ns.models[DatasetMetadataListResponse.__name__]
|
||||
)
|
||||
def get(self, tenant_id, dataset_id: UUID):
|
||||
def get(self, tenant_id, dataset_id):
|
||||
"""Get all metadata for a dataset."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -111,7 +110,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
200, "Metadata updated successfully", service_api_ns.models[DatasetMetadataResponse.__name__]
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, tenant_id, dataset_id: UUID, metadata_id: UUID):
|
||||
def patch(self, tenant_id, dataset_id, metadata_id):
|
||||
"""Update metadata name."""
|
||||
payload = MetadataUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
@ -137,7 +136,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
)
|
||||
@service_api_ns.response(204, "Metadata deleted successfully")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id, dataset_id: UUID, metadata_id: UUID):
|
||||
def delete(self, tenant_id, dataset_id, metadata_id):
|
||||
"""Delete metadata."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
metadata_id_str = str(metadata_id)
|
||||
@ -165,7 +164,7 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
|
||||
"Built-in fields retrieved successfully",
|
||||
service_api_ns.models[DatasetMetadataBuiltInFieldsResponse.__name__],
|
||||
)
|
||||
def get(self, tenant_id, dataset_id: UUID):
|
||||
def get(self, tenant_id, dataset_id):
|
||||
"""Get all built-in metadata fields."""
|
||||
built_in_fields = MetadataService.get_built_in_fields()
|
||||
return dump_response(DatasetMetadataBuiltInFieldsResponse, {"fields": built_in_fields}), 200
|
||||
@ -187,7 +186,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
|
||||
200, "Action completed successfully", service_api_ns.models[DatasetMetadataActionResponse.__name__]
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id: UUID, action: Literal["enable", "disable"]):
|
||||
def post(self, tenant_id, dataset_id, action: Literal["enable", "disable"]):
|
||||
"""Enable or disable built-in metadata field."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -222,7 +221,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
|
||||
service_api_ns.models[DatasetMetadataActionResponse.__name__],
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id: UUID):
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Update metadata for multiple documents."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
|
||||
@ -140,7 +140,7 @@ class CompletionStopApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(200, "Success", web_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, app_model, end_user, task_id: str):
|
||||
def post(self, app_model, end_user, task_id):
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -226,7 +226,7 @@ class ChatStopApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(200, "Success", web_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, app_model, end_user, task_id: str):
|
||||
def post(self, app_model, end_user, task_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
@ -127,7 +126,7 @@ class ConversationApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def delete(self, app_model, end_user, c_id: UUID):
|
||||
def delete(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -166,7 +165,7 @@ class ConversationRenameApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model, end_user, c_id: UUID):
|
||||
def post(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -204,7 +203,7 @@ class ConversationPinApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(200, "Conversation pinned successfully", web_ns.models[ResultResponse.__name__])
|
||||
def patch(self, app_model, end_user, c_id: UUID):
|
||||
def patch(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -235,7 +234,7 @@ class ConversationUnPinApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(200, "Conversation unpinned successfully", web_ns.models[ResultResponse.__name__])
|
||||
def patch(self, app_model, end_user, c_id: UUID):
|
||||
def patch(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
@ -133,15 +132,15 @@ class MessageFeedbackApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(200, "Feedback submitted successfully", web_ns.models[ResultResponse.__name__])
|
||||
def post(self, app_model, end_user, message_id: UUID):
|
||||
message_id_str = str(message_id)
|
||||
def post(self, app_model, end_user, message_id):
|
||||
message_id = str(message_id)
|
||||
|
||||
payload = MessageFeedbackPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
try:
|
||||
MessageService.create_feedback(
|
||||
app_model=app_model,
|
||||
message_id=message_id_str,
|
||||
message_id=message_id,
|
||||
user=end_user,
|
||||
rating=FeedbackRating(payload.rating) if payload.rating else None,
|
||||
content=payload.content,
|
||||
@ -167,11 +166,11 @@ class MessageMoreLikeThisApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self, app_model, end_user, message_id: UUID):
|
||||
def get(self, app_model, end_user, message_id):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
message_id_str = str(message_id)
|
||||
message_id = str(message_id)
|
||||
|
||||
raw_args = request.args.to_dict()
|
||||
query = MessageMoreLikeThisQuery.model_validate(raw_args)
|
||||
@ -182,7 +181,7 @@ class MessageMoreLikeThisApi(WebApiResource):
|
||||
response = AppGenerateService.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
message_id=message_id_str,
|
||||
message_id=message_id,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=streaming,
|
||||
)
|
||||
@ -223,16 +222,16 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self, app_model, end_user, message_id: UUID):
|
||||
def get(self, app_model, end_user, message_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
message_id_str = str(message_id)
|
||||
message_id = str(message_id)
|
||||
|
||||
try:
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model, user=end_user, message_id=message_id_str, invoke_from=InvokeFrom.WEB_APP
|
||||
app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP
|
||||
)
|
||||
# questions is a list of strings, not a list of Message objects
|
||||
except MessageNotExistsError:
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from pydantic import TypeAdapter
|
||||
from werkzeug.exceptions import NotFound
|
||||
@ -106,12 +104,12 @@ class SavedMessageApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def delete(self, app_model, end_user, message_id: UUID):
|
||||
message_id_str = str(message_id)
|
||||
def delete(self, app_model, end_user, message_id):
|
||||
message_id = str(message_id)
|
||||
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
SavedMessageService.delete(app_model, end_user, message_id_str)
|
||||
SavedMessageService.delete(app_model, end_user, message_id)
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import enum
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, cast, override
|
||||
from typing import Any, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import BaseModel
|
||||
@ -18,7 +18,6 @@ class StringUUID(TypeDecorator[uuid.UUID | str | None]):
|
||||
impl = CHAR
|
||||
cache_ok = True
|
||||
|
||||
@override
|
||||
def process_bind_param(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
@ -29,14 +28,12 @@ class StringUUID(TypeDecorator[uuid.UUID | str | None]):
|
||||
return value.hex
|
||||
return value
|
||||
|
||||
@override
|
||||
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
|
||||
if dialect.name == "postgresql":
|
||||
return dialect.type_descriptor(UUID())
|
||||
else:
|
||||
return dialect.type_descriptor(CHAR(36))
|
||||
|
||||
@override
|
||||
def process_result_value(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
@ -47,13 +44,11 @@ class LongText(TypeDecorator[str | None]):
|
||||
impl = TEXT
|
||||
cache_ok = True
|
||||
|
||||
@override
|
||||
def process_bind_param(self, value: str | None, dialect: Dialect) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return value
|
||||
|
||||
@override
|
||||
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
|
||||
if dialect.name == "postgresql":
|
||||
return dialect.type_descriptor(TEXT())
|
||||
@ -62,7 +57,6 @@ class LongText(TypeDecorator[str | None]):
|
||||
else:
|
||||
return dialect.type_descriptor(TEXT())
|
||||
|
||||
@override
|
||||
def process_result_value(self, value: str | None, dialect: Dialect) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
@ -83,7 +77,6 @@ class JSONModelColumn[T: BaseModel](TypeDecorator[T | None]):
|
||||
self._model_class = model_class
|
||||
super().__init__()
|
||||
|
||||
@override
|
||||
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
|
||||
if dialect.name == "postgresql":
|
||||
return dialect.type_descriptor(TEXT())
|
||||
@ -92,7 +85,6 @@ class JSONModelColumn[T: BaseModel](TypeDecorator[T | None]):
|
||||
else:
|
||||
return dialect.type_descriptor(TEXT())
|
||||
|
||||
@override
|
||||
def process_bind_param(self, value: T | dict[str, Any] | str | None, dialect: Dialect) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
@ -104,7 +96,6 @@ class JSONModelColumn[T: BaseModel](TypeDecorator[T | None]):
|
||||
model = self._model_class.model_validate(value)
|
||||
return json.dumps(model.model_dump(mode="json"), ensure_ascii=False, sort_keys=True, separators=(",", ":"))
|
||||
|
||||
@override
|
||||
def process_result_value(self, value: str | None, dialect: Dialect) -> T | None:
|
||||
if value is None or value == "":
|
||||
return None
|
||||
@ -115,13 +106,11 @@ class BinaryData(TypeDecorator[bytes | None]):
|
||||
impl = LargeBinary
|
||||
cache_ok = True
|
||||
|
||||
@override
|
||||
def process_bind_param(self, value: bytes | None, dialect: Dialect) -> bytes | None:
|
||||
if value is None:
|
||||
return value
|
||||
return value
|
||||
|
||||
@override
|
||||
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
|
||||
if dialect.name == "postgresql":
|
||||
return dialect.type_descriptor(BYTEA())
|
||||
@ -130,7 +119,6 @@ class BinaryData(TypeDecorator[bytes | None]):
|
||||
else:
|
||||
return dialect.type_descriptor(LargeBinary())
|
||||
|
||||
@override
|
||||
def process_result_value(self, value: bytes | None, dialect: Dialect) -> bytes | None:
|
||||
if value is None:
|
||||
return value
|
||||
@ -145,7 +133,6 @@ class AdjustedJSON(TypeDecorator[dict | list | None]):
|
||||
self.astext_type = astext_type
|
||||
super().__init__()
|
||||
|
||||
@override
|
||||
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
|
||||
if dialect.name == "postgresql":
|
||||
if self.astext_type:
|
||||
@ -157,13 +144,11 @@ class AdjustedJSON(TypeDecorator[dict | list | None]):
|
||||
else:
|
||||
return dialect.type_descriptor(sa.JSON())
|
||||
|
||||
@override
|
||||
def process_bind_param(
|
||||
self, value: dict[str, Any] | list[Any] | None, dialect: Dialect
|
||||
) -> dict[str, Any] | list[Any] | None:
|
||||
return value
|
||||
|
||||
@override
|
||||
def process_result_value(
|
||||
self, value: dict[str, Any] | list[Any] | None, dialect: Dialect
|
||||
) -> dict[str, Any] | list[Any] | None:
|
||||
@ -188,7 +173,6 @@ class EnumText[T: enum.StrEnum](TypeDecorator[T | None]):
|
||||
# leave some rooms for future longer enum values.
|
||||
self._length = max(max_enum_value_len, 20)
|
||||
|
||||
@override
|
||||
def process_bind_param(self, value: T | str | None, dialect: Dialect) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
@ -198,11 +182,9 @@ class EnumText[T: enum.StrEnum](TypeDecorator[T | None]):
|
||||
self._enum_class(value)
|
||||
return value
|
||||
|
||||
@override
|
||||
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
|
||||
return dialect.type_descriptor(VARCHAR(self._length))
|
||||
|
||||
@override
|
||||
def process_result_value(self, value: str | None, dialect: Dialect) -> T | None:
|
||||
if value is None or value == "":
|
||||
return None
|
||||
@ -215,7 +197,6 @@ class EnumText[T: enum.StrEnum](TypeDecorator[T | None]):
|
||||
return cast(T, value_of(value))
|
||||
raise
|
||||
|
||||
@override
|
||||
def compare_values(self, x: T | None, y: T | None) -> bool:
|
||||
if x is None or y is None:
|
||||
return x is y
|
||||
|
||||
@ -5,6 +5,7 @@ import { useSuspenseQuery } from '@tanstack/react-query'
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import DifyLogo from '@/app/components/base/logo/dify-logo'
|
||||
import Link from '@/next/link'
|
||||
import { useRouter } from '@/next/navigation'
|
||||
import { systemFeaturesQueryOptions } from '@/service/system-features'
|
||||
import Avatar from './avatar'
|
||||
@ -17,21 +18,26 @@ const Header = () => {
|
||||
const goToStudio = useCallback(() => {
|
||||
router.push('/apps')
|
||||
}, [router])
|
||||
const logoLabel = systemFeatures.branding.enabled && systemFeatures.branding.application_title ? systemFeatures.branding.application_title : 'Dify'
|
||||
|
||||
return (
|
||||
<div className="flex flex-1 items-center justify-between px-4">
|
||||
<div className="flex items-center gap-3">
|
||||
<div className="flex cursor-pointer items-center" onClick={goToStudio}>
|
||||
<Link
|
||||
href="/apps"
|
||||
className="flex items-center rounded-sm hover:opacity-80 focus-visible:ring-1 focus-visible:ring-components-input-border-active focus-visible:outline-hidden"
|
||||
aria-label={logoLabel}
|
||||
>
|
||||
{systemFeatures.branding.enabled && systemFeatures.branding.login_page_logo
|
||||
? (
|
||||
<img
|
||||
src={systemFeatures.branding.login_page_logo}
|
||||
className="block h-[22px] w-auto object-contain"
|
||||
alt="Dify logo"
|
||||
alt=""
|
||||
/>
|
||||
)
|
||||
: <DifyLogo />}
|
||||
</div>
|
||||
: <DifyLogo alt="" />}
|
||||
</Link>
|
||||
<div className="h-4 w-px origin-center rotate-[11.31deg] bg-divider-regular" />
|
||||
<p className="relative mt-[-2px] title-3xl-semi-bold text-text-primary">{t('account.account', { ns: 'common' })}</p>
|
||||
</div>
|
||||
|
||||
@ -58,6 +58,12 @@ describe('DifyLogo', () => {
|
||||
const img = screen.getByRole('img', { name: /dify logo/i })
|
||||
expect(img).toHaveClass('custom-test-class')
|
||||
})
|
||||
|
||||
it('applies custom alt text', () => {
|
||||
const { container } = render(<DifyLogo alt="" />)
|
||||
const img = container.querySelector('img')
|
||||
expect(img).toHaveAttribute('alt', '')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Theme behavior', () => {
|
||||
|
||||
@ -23,12 +23,14 @@ type DifyLogoProps = {
|
||||
style?: LogoStyle
|
||||
size?: LogoSize
|
||||
className?: string
|
||||
alt?: string
|
||||
}
|
||||
|
||||
const DifyLogo: FC<DifyLogoProps> = ({
|
||||
style = 'default',
|
||||
size = 'medium',
|
||||
className,
|
||||
alt = 'Dify logo',
|
||||
}) => {
|
||||
const { theme } = useTheme()
|
||||
const themedStyle = (theme === 'dark' && style === 'default') ? 'monochromeWhite' : style
|
||||
@ -37,7 +39,7 @@ const DifyLogo: FC<DifyLogoProps> = ({
|
||||
<img
|
||||
src={`${basePath}${logoPathMap[themedStyle]}`}
|
||||
className={cn('block object-contain', logoSizeMap[size], className)}
|
||||
alt="Dify logo"
|
||||
alt={alt}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { ReactElement } from 'react'
|
||||
import type { AnchorHTMLAttributes, ReactElement } from 'react'
|
||||
import { fireEvent, screen } from '@testing-library/react'
|
||||
import { vi } from 'vitest'
|
||||
import { renderWithSystemFeatures } from '@/__tests__/utils/mock-system-features'
|
||||
@ -55,7 +55,7 @@ vi.mock('@/context/workspace-context-provider', () => ({
|
||||
}))
|
||||
|
||||
vi.mock('@/next/link', () => ({
|
||||
default: ({ children, href }: { children?: React.ReactNode, href?: string }) => <a href={href}>{children}</a>,
|
||||
default: ({ children, href, ...props }: AnchorHTMLAttributes<HTMLAnchorElement> & { href?: string }) => <a href={href} {...props}>{children}</a>,
|
||||
}))
|
||||
|
||||
let mockIsWorkspaceEditor = false
|
||||
@ -122,7 +122,9 @@ describe('Header', () => {
|
||||
it('should render header with main nav components', () => {
|
||||
renderHeader()
|
||||
|
||||
expect(screen.getByRole('img', { name: /dify logo/i })).toBeInTheDocument()
|
||||
expect(screen.getByRole('link', { name: 'Dify' })).toHaveAttribute('href', '/apps')
|
||||
expect(screen.queryByRole('heading', { level: 1 })).not.toBeInTheDocument()
|
||||
expect(screen.queryByRole('img', { name: /dify logo/i })).not.toBeInTheDocument()
|
||||
expect(screen.getByTestId('workplace-selector')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('app-nav')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('account-dropdown')).toBeInTheDocument()
|
||||
@ -166,7 +168,7 @@ describe('Header', () => {
|
||||
mockMedia = 'mobile'
|
||||
renderHeader()
|
||||
|
||||
expect(screen.getByRole('img', { name: /dify logo/i })).toBeInTheDocument()
|
||||
expect(screen.getByRole('link', { name: 'Dify' })).toHaveAttribute('href', '/apps')
|
||||
expect(screen.queryByTestId('env-nav')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
@ -177,8 +179,8 @@ describe('Header', () => {
|
||||
|
||||
renderHeader()
|
||||
|
||||
expect(screen.getByText('Acme Workspace')).toBeInTheDocument()
|
||||
expect(screen.getByRole('img', { name: /logo/i })).toBeInTheDocument()
|
||||
expect(screen.getByRole('link', { name: 'Acme Workspace' })).toHaveAttribute('href', '/apps')
|
||||
expect(screen.queryByRole('img', { name: /logo/i })).not.toBeInTheDocument()
|
||||
expect(screen.queryByRole('img', { name: /dify logo/i })).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
@ -189,18 +191,18 @@ describe('Header', () => {
|
||||
|
||||
renderHeader()
|
||||
|
||||
expect(screen.getByText('Custom Title')).toBeInTheDocument()
|
||||
expect(screen.getByRole('img', { name: /dify logo/i })).toBeInTheDocument()
|
||||
expect(screen.getByRole('link', { name: 'Custom Title' })).toHaveAttribute('href', '/apps')
|
||||
expect(screen.queryByRole('img', { name: /dify logo/i })).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show default Dify text when branding enabled but no application_title', () => {
|
||||
it('should use default Dify link label when branding enabled but no application_title', () => {
|
||||
mockBrandingEnabled = true
|
||||
mockBrandingTitle = null
|
||||
mockBrandingLogo = null
|
||||
|
||||
renderHeader()
|
||||
|
||||
expect(screen.getByText('Dify')).toBeInTheDocument()
|
||||
expect(screen.getByRole('link', { name: 'Dify' })).toHaveAttribute('href', '/apps')
|
||||
})
|
||||
|
||||
it('should show dataset nav for editor who is not dataset operator', () => {
|
||||
|
||||
@ -44,21 +44,23 @@ const Header = () => {
|
||||
setShowAccountSettingModal({ payload: ACCOUNT_SETTING_TAB.BILLING })
|
||||
}, [isFreePlan, setShowAccountSettingModal, setShowPricingModal])
|
||||
|
||||
const logoLabel = isBrandingEnabled && systemFeatures.branding.application_title ? systemFeatures.branding.application_title : 'Dify'
|
||||
const renderLogo = () => (
|
||||
<h1>
|
||||
<Link href="/apps" className="flex h-8 shrink-0 items-center justify-center overflow-hidden px-0.5 indent-[-9999px] whitespace-nowrap">
|
||||
{isBrandingEnabled && systemFeatures.branding.application_title ? systemFeatures.branding.application_title : 'Dify'}
|
||||
{systemFeatures.branding.enabled && systemFeatures.branding.workspace_logo
|
||||
? (
|
||||
<img
|
||||
src={systemFeatures.branding.workspace_logo}
|
||||
className="block h-[22px] w-auto object-contain"
|
||||
alt="logo"
|
||||
/>
|
||||
)
|
||||
: <DifyLogo />}
|
||||
</Link>
|
||||
</h1>
|
||||
<Link
|
||||
href="/apps"
|
||||
className="flex h-8 shrink-0 items-center justify-center overflow-hidden rounded-sm px-0.5 hover:opacity-80 focus-visible:ring-1 focus-visible:ring-components-input-border-active focus-visible:outline-hidden"
|
||||
aria-label={logoLabel}
|
||||
>
|
||||
{systemFeatures.branding.enabled && systemFeatures.branding.workspace_logo
|
||||
? (
|
||||
<img
|
||||
src={systemFeatures.branding.workspace_logo}
|
||||
className="block h-[22px] w-auto object-contain"
|
||||
alt=""
|
||||
/>
|
||||
)
|
||||
: <DifyLogo alt="" />}
|
||||
</Link>
|
||||
)
|
||||
|
||||
if (isMobile) {
|
||||
|
||||
Reference in New Issue
Block a user