mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
Merge branch 'main' into feat/mcp
This commit is contained in:
@ -6,12 +6,12 @@ from sqlalchemy.orm import Session
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRunStatus
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
|
||||
@ -38,7 +38,7 @@ class WorkflowAppLogApi(Resource):
|
||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.status = WorkflowRunStatus(args.status) if args.status else None
|
||||
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
if args.created_at__before:
|
||||
args.created_at__before = isoparse(args.created_at__before)
|
||||
|
||||
|
||||
@ -2,12 +2,14 @@ from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from flask import request
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from flask_restful import reqparse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.login import _get_user
|
||||
from models.account import Account, Tenant
|
||||
from models.model import EndUser
|
||||
from services.account_service import AccountService
|
||||
@ -80,7 +82,12 @@ def get_user_tenant(view: Optional[Callable] = None):
|
||||
raise ValueError("tenant not found")
|
||||
|
||||
kwargs["tenant_model"] = tenant_model
|
||||
kwargs["user_model"] = get_user(tenant_id, user_id)
|
||||
|
||||
user = get_user(tenant_id, user_id)
|
||||
kwargs["user_model"] = user
|
||||
|
||||
current_app.login_manager._update_request_context_with_user(user) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
|
||||
@ -9,13 +9,13 @@ from fields.annotation_fields import (
|
||||
annotation_fields,
|
||||
)
|
||||
from libs.login import current_user
|
||||
from models.model import App, EndUser
|
||||
from models.model import App
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
||||
class AnnotationReplyActionApi(Resource):
|
||||
@validate_app_token
|
||||
def post(self, app_model: App, end_user: EndUser, action):
|
||||
def post(self, app_model: App, action):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("score_threshold", required=True, type=float, location="json")
|
||||
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
|
||||
@ -32,7 +32,7 @@ class AnnotationReplyActionApi(Resource):
|
||||
|
||||
class AnnotationReplyActionStatusApi(Resource):
|
||||
@validate_app_token
|
||||
def get(self, app_model: App, end_user: EndUser, job_id, action):
|
||||
def get(self, app_model: App, job_id, action):
|
||||
job_id = str(job_id)
|
||||
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
|
||||
cache_result = redis_client.get(app_annotation_job_key)
|
||||
@ -50,7 +50,7 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
|
||||
class AnnotationListApi(Resource):
|
||||
@validate_app_token
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
def get(self, app_model: App):
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
keyword = request.args.get("keyword", default="", type=str)
|
||||
@ -67,7 +67,7 @@ class AnnotationListApi(Resource):
|
||||
|
||||
@validate_app_token
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
def post(self, app_model: App):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("question", required=True, type=str, location="json")
|
||||
parser.add_argument("answer", required=True, type=str, location="json")
|
||||
@ -79,7 +79,7 @@ class AnnotationListApi(Resource):
|
||||
class AnnotationUpdateDeleteApi(Resource):
|
||||
@validate_app_token
|
||||
@marshal_with(annotation_fields)
|
||||
def put(self, app_model: App, end_user: EndUser, annotation_id):
|
||||
def put(self, app_model: App, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
@ -92,7 +92,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
return annotation
|
||||
|
||||
@validate_app_token
|
||||
def delete(self, app_model: App, end_user: EndUser, annotation_id):
|
||||
def delete(self, app_model: App, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
@ -24,12 +24,13 @@ from core.errors.error import (
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.workflow import WorkflowRun, WorkflowRunStatus
|
||||
from models.workflow import WorkflowRun
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
@ -138,7 +139,7 @@ class WorkflowAppLogApi(Resource):
|
||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.status = WorkflowRunStatus(args.status) if args.status else None
|
||||
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
if args.created_at__before:
|
||||
args.created_at__before = isoparse(args.created_at__before)
|
||||
|
||||
|
||||
@ -1,19 +1,21 @@
|
||||
from flask import request
|
||||
from flask_restful import marshal, reqparse
|
||||
from flask_restful import marshal, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services.dataset_service
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from controllers.service_api.wraps import DatasetApiResource, validate_dataset_token
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import tag_fields
|
||||
from libs.login import current_user
|
||||
from models.dataset import Dataset, DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.tag_service import TagService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
@ -320,5 +322,135 @@ class DatasetApi(DatasetApiResource):
|
||||
raise DatasetInUseError()
|
||||
|
||||
|
||||
class DatasetTagsApi(DatasetApiResource):
|
||||
@validate_dataset_token
|
||||
@marshal_with(tag_fields)
|
||||
def get(self, _, dataset_id):
|
||||
"""Get all knowledge type tags."""
|
||||
tags = TagService.get_tags("knowledge", current_user.current_tenant_id)
|
||||
|
||||
return tags, 200
|
||||
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
"""Add a knowledge type tag."""
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 50 characters.",
|
||||
type=DatasetTagsApi._validate_tag_name,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
tag = TagService.save_tags(args)
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
|
||||
return response, 200
|
||||
|
||||
@validate_dataset_token
|
||||
def patch(self, _, dataset_id):
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 50 characters.",
|
||||
type=DatasetTagsApi._validate_tag_name,
|
||||
)
|
||||
parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
|
||||
args = parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
tag = TagService.update_tags(args, args.get("tag_id"))
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(args.get("tag_id"))
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
|
||||
return response, 200
|
||||
|
||||
@validate_dataset_token
|
||||
def delete(self, _, dataset_id):
|
||||
"""Delete a knowledge type tag."""
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
|
||||
args = parser.parse_args()
|
||||
TagService.delete_tag(args.get("tag_id"))
|
||||
|
||||
return 204
|
||||
|
||||
@staticmethod
|
||||
def _validate_tag_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 50:
|
||||
raise ValueError("Name must be between 1 to 50 characters.")
|
||||
return name
|
||||
|
||||
|
||||
class DatasetTagBindingApi(DatasetApiResource):
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
|
||||
)
|
||||
parser.add_argument(
|
||||
"target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
TagService.save_tag_binding(args)
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
|
||||
parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
|
||||
|
||||
args = parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
TagService.delete_tag_binding(args)
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
class DatasetTagsBindingStatusApi(DatasetApiResource):
|
||||
@validate_dataset_token
|
||||
def get(self, _, *args, **kwargs):
|
||||
"""Get all knowledge type tags."""
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
|
||||
tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
|
||||
response = {"data": tags_list, "total": len(tags)}
|
||||
return response, 200
|
||||
|
||||
|
||||
api.add_resource(DatasetListApi, "/datasets")
|
||||
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
|
||||
api.add_resource(DatasetTagsApi, "/datasets/tags")
|
||||
api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding")
|
||||
api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding")
|
||||
api.add_resource(DatasetTagsBindingStatusApi, "/datasets/<uuid:dataset_id>/tags")
|
||||
|
||||
@ -175,8 +175,11 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
if not dataset.indexing_technique and not args.get("indexing_technique"):
|
||||
|
||||
indexing_technique = args.get("indexing_technique") or dataset.indexing_technique
|
||||
if not indexing_technique:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
args["indexing_technique"] = indexing_technique
|
||||
|
||||
# save file info
|
||||
file = request.files["file"]
|
||||
@ -206,12 +209,16 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
knowledge_config = KnowledgeConfig(**args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None
|
||||
if not knowledge_config.original_document_id and not dataset_process_rule and not knowledge_config.process_rule:
|
||||
raise ValueError("process_rule is required.")
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=dataset.created_by_account,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
dataset_process_rule=dataset_process_rule,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
|
||||
@ -208,6 +208,28 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
)
|
||||
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
def get(self, tenant_id, dataset_id, document_id, segment_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
class ChildChunkApi(DatasetApiResource):
|
||||
"""Resource for child chunks."""
|
||||
|
||||
Reference in New Issue
Block a user