Merge branch 'feat/queue-based-graph-engine' into feat/rag-2

# Conflicts:
#	api/core/memory/token_buffer_memory.py
#	api/core/rag/extractor/notion_extractor.py
#	api/core/repositories/sqlalchemy_workflow_node_execution_repository.py
#	api/core/variables/variables.py
#	api/core/workflow/graph/graph.py
#	api/core/workflow/graph_engine/entities/event.py
#	api/services/dataset_service.py
#	web/app/components/app-sidebar/index.tsx
#	web/app/components/base/tag-management/selector.tsx
#	web/app/components/base/toast/index.tsx
#	web/app/components/datasets/create/website/index.tsx
#	web/app/components/datasets/create/website/jina-reader/base/options-wrap.tsx
#	web/app/components/workflow/header/version-history-button.tsx
#	web/app/components/workflow/hooks/use-inspect-vars-crud-common.ts
#	web/app/components/workflow/hooks/use-workflow-interactions.ts
#	web/app/components/workflow/panel/version-history-panel/index.tsx
#	web/service/base.ts
This commit is contained in:
jyong
2025-09-03 15:01:06 +08:00
572 changed files with 16030 additions and 7973 deletions

View File

@ -130,15 +130,19 @@ class InsertExploreAppApi(Resource):
app.is_public = False
with Session(db.engine) as session:
installed_apps = session.execute(
select(InstalledApp).where(
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
installed_apps = (
session.execute(
select(InstalledApp).where(
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
)
)
).all()
.scalars()
.all()
)
for installed_app in installed_apps:
db.session.delete(installed_app)
for installed_app in installed_apps:
session.delete(installed_app)
db.session.delete(recommended_app)
db.session.commit()

View File

@ -84,7 +84,7 @@ class BaseApiKeyListResource(Resource):
flask_restx.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded",
custom="max_keys_exceeded",
)
key = ApiToken.generate_api_key(self.token_prefix, 24)

View File

@ -237,9 +237,14 @@ class AppExportApi(Resource):
# Add include_secret params
parser = reqparse.RequestParser()
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
parser.add_argument("workflow_id", type=str, location="args")
args = parser.parse_args()
return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])}
return {
"data": AppDslService.export_dsl(
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
)
}
class AppNameApi(Resource):

View File

@ -130,7 +130,7 @@ class MessageFeedbackApi(Resource):
message_id = str(args["message_id"])
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
if not message:
raise NotFound("Message Not Exists.")

View File

@ -532,7 +532,7 @@ class PublishedWorkflowApi(Resource):
)
app_model.workflow_id = workflow.id
db.session.commit()
db.session.commit() # NOTE: this is necessary for update app_model.workflow_id
workflow_created_at = TimestampField().format(workflow.created_at)

View File

@ -27,7 +27,9 @@ class WorkflowAppLogApi(Resource):
"""
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
parser.add_argument(
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
)
parser.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
)

View File

@ -81,7 +81,7 @@ class OAuthDataSourceBinding(Resource):
return {"error": "Invalid code"}, 400
try:
oauth_provider.get_access_token(code)
except requests.exceptions.HTTPError as e:
except requests.HTTPError as e:
logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
)
@ -104,7 +104,7 @@ class OAuthDataSourceSync(Resource):
return {"error": "Invalid provider"}, 400
try:
oauth_provider.sync_data_source(binding_id)
except requests.exceptions.HTTPError as e:
except requests.HTTPError as e:
logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
)

View File

@ -130,7 +130,7 @@ class ResetPasswordSendEmailApi(Resource):
language = "en-US"
try:
account = AccountService.get_user_through_email(args["email"])
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
@ -162,7 +162,7 @@ class EmailCodeLoginSendEmailApi(Resource):
language = "en-US"
try:
account = AccountService.get_user_through_email(args["email"])
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
@ -200,7 +200,7 @@ class EmailCodeLoginApi(Resource):
AccountService.revoke_email_code_login_token(args["token"])
try:
account = AccountService.get_user_through_email(user_email)
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()
if account:
tenants = TenantService.get_join_tenants(account)
@ -223,7 +223,7 @@ class EmailCodeLoginApi(Resource):
)
except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace()
except AccountRegisterError as are:
except AccountRegisterError:
raise AccountInFreezeError()
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()

View File

@ -80,7 +80,7 @@ class OAuthCallback(Resource):
try:
token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token)
except requests.exceptions.RequestException as e:
except requests.RequestException as e:
error_text = e.response.text if e.response else str(e)
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
return {"error": "OAuth process failed"}, 400

View File

@ -44,22 +44,19 @@ def oauth_server_access_token_required(view):
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
raise BadRequest("Invalid oauth_provider_app")
if not request.headers.get("Authorization"):
raise BadRequest("Authorization is required")
authorization_header = request.headers.get("Authorization")
if not authorization_header:
raise BadRequest("Authorization header is required")
parts = authorization_header.split(" ")
parts = authorization_header.strip().split(" ")
if len(parts) != 2:
raise BadRequest("Invalid Authorization header format")
token_type = parts[0]
if token_type != "Bearer":
token_type = parts[0].strip()
if token_type.lower() != "bearer":
raise BadRequest("token_type is invalid")
access_token = parts[1]
access_token = parts[1].strip()
if not access_token:
raise BadRequest("access_token is required")
@ -125,7 +122,10 @@ class OAuthServerUserTokenApi(Resource):
parser.add_argument("refresh_token", type=str, required=False, location="json")
parsed_args = parser.parse_args()
grant_type = OAuthGrantType(parsed_args["grant_type"])
try:
grant_type = OAuthGrantType(parsed_args["grant_type"])
except ValueError:
raise BadRequest("invalid grant_type")
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
if not parsed_args["code"]:
@ -163,8 +163,6 @@ class OAuthServerUserTokenApi(Resource):
"refresh_token": refresh_token,
}
)
else:
raise BadRequest("invalid grant_type")
class OAuthServerUserAccountApi(Resource):

View File

@ -14,6 +14,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
@ -248,7 +249,7 @@ class DataSourceNotionApi(Resource):
credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type="notion_import",
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": credential_id,
"notion_workspace_id": workspace_id,

View File

@ -21,6 +21,7 @@ from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
@ -431,7 +432,9 @@ class DatasetIndexingEstimateApi(Resource):
if file_details:
for file_detail in file_details:
extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
datasource_type=DatasourceType.FILE.value,
upload_file=file_detail,
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)
elif args["info_list"]["data_source_type"] == "notion_import":
@ -441,7 +444,7 @@ class DatasetIndexingEstimateApi(Resource):
credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type="notion_import",
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
@ -456,7 +459,7 @@ class DatasetIndexingEstimateApi(Resource):
website_info_list = args["info_list"]["website_info_list"]
for url in website_info_list["urls"]:
extract_setting = ExtractSetting(
datasource_type="website_crawl",
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": website_info_list["provider"],
"job_id": website_info_list["job_id"],

View File

@ -41,6 +41,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from fields.document_fields import (
@ -356,9 +357,6 @@ class DatasetInitApi(Resource):
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
knowledge_config = KnowledgeConfig(**args)
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
@ -430,7 +428,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file, document_model=document.doc_form
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
)
indexing_runner = IndexingRunner()
@ -490,13 +488,13 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
)
extract_settings.append(extract_setting)
elif document.data_source_type == "notion_import":
extract_setting = ExtractSetting(
datasource_type="notion_import",
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
@ -509,7 +507,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
extract_settings.append(extract_setting)
elif document.data_source_type == "website_crawl":
extract_setting = ExtractSetting(
datasource_type="website_crawl",
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],

View File

@ -61,7 +61,6 @@ class ConversationApi(InstalledAppResource):
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
WebConversationService.unpin(app_model, conversation_id, current_user)
return {"result": "success"}, 204

View File

@ -67,7 +67,7 @@ class ModelProviderCredentialApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()
@ -94,7 +94,7 @@ class ModelProviderCredentialApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()

View File

@ -219,7 +219,11 @@ class ModelProviderModelCredentialApi(Resource):
model_load_balancing_service = ModelLoadBalancingService()
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
tenant_id=tenant_id,
provider=provider,
model=args["model"],
model_type=args["model_type"],
config_from=args.get("config_from", ""),
)
if args.get("config_from", "") == "predefined-model":
@ -263,7 +267,7 @@ class ModelProviderModelCredentialApi(Resource):
choices=[mt.value for mt in ModelType],
location="json",
)
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
@ -309,7 +313,7 @@ class ModelProviderModelCredentialApi(Resource):
)
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()

View File

@ -1,8 +1,12 @@
from base64 import b64encode
from collections.abc import Callable
from functools import wraps
from hashlib import sha1
from hmac import new as hmac_new
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
from flask import abort, request
from configs import dify_config
@ -10,9 +14,9 @@ from extensions.ext_database import db
from models.model import EndUser
def billing_inner_api_only(view):
def billing_inner_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.INNER_API:
abort(404)
@ -26,9 +30,9 @@ def billing_inner_api_only(view):
return decorated
def enterprise_inner_api_only(view):
def enterprise_inner_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.INNER_API:
abort(404)
@ -78,9 +82,9 @@ def enterprise_inner_api_user_auth(view):
return decorated
def plugin_inner_api_only(view):
def plugin_inner_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.PLUGIN_DAEMON_KEY:
abort(404)

View File

@ -55,7 +55,7 @@ class AudioApi(Resource):
file = request.files["file"]
try:
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user.id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:

View File

@ -59,7 +59,7 @@ class FilePreviewApi(Resource):
args = file_preview_parser.parse_args()
# Validate file ownership and get file objects
message_file, upload_file = self._validate_file_ownership(file_id, app_model.id)
_, upload_file = self._validate_file_ownership(file_id, app_model.id)
# Get file content generator
try:

View File

@ -413,7 +413,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
documents, _ = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=dataset.created_by_account,

View File

@ -1,7 +1,7 @@
import time
from collections.abc import Callable
from datetime import timedelta
from enum import Enum
from enum import StrEnum, auto
from functools import wraps
from typing import Optional
@ -23,14 +23,14 @@ from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService
class WhereisUserArg(Enum):
class WhereisUserArg(StrEnum):
"""
Enum for whereis_user_arg.
"""
QUERY = "query"
JSON = "json"
FORM = "form"
QUERY = auto()
JSON = auto()
FORM = auto()
class FetchUserArg(BaseModel):
@ -291,27 +291,28 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
if not user_id:
user_id = "DEFAULT-USER"
end_user = (
db.session.query(EndUser)
.where(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
EndUser.type == "service_api",
with Session(db.engine, expire_on_commit=False) as session:
end_user = (
session.query(EndUser)
.where(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
EndUser.type == "service_api",
)
.first()
)
.first()
)
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="service_api",
is_anonymous=user_id == "DEFAULT-USER",
session_id=user_id,
)
db.session.add(end_user)
db.session.commit()
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="service_api",
is_anonymous=user_id == "DEFAULT-USER",
session_id=user_id,
)
session.add(end_user)
session.commit()
return end_user

View File

@ -73,8 +73,6 @@ class ConversationApi(WebApiResource):
ConversationService.delete(app_model, conversation_id, end_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
WebConversationService.unpin(app_model, conversation_id, end_user)
return {"result": "success"}, 204

View File

@ -4,6 +4,7 @@ from functools import wraps
from flask import request
from flask_restx import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
@ -49,18 +50,19 @@ def decode_jwt_token():
decoded = PassportService().verify(tk)
app_code = decoded.get("app_code")
app_id = decoded.get("app_id")
app_model = db.session.scalar(select(App).where(App.id == app_id))
site = db.session.scalar(select(Site).where(Site.code == app_code))
if not app_model:
raise NotFound()
if not app_code or not site:
raise BadRequest("Site URL is no longer valid.")
if app_model.enable_site is False:
raise BadRequest("Site is disabled.")
end_user_id = decoded.get("end_user_id")
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
if not end_user:
raise NotFound()
with Session(db.engine, expire_on_commit=False) as session:
app_model = session.scalar(select(App).where(App.id == app_id))
site = session.scalar(select(Site).where(Site.code == app_code))
if not app_model:
raise NotFound()
if not app_code or not site:
raise BadRequest("Site URL is no longer valid.")
if app_model.enable_site is False:
raise BadRequest("Site is disabled.")
end_user_id = decoded.get("end_user_id")
end_user = session.scalar(select(EndUser).where(EndUser.id == end_user_id))
if not end_user:
raise NotFound()
# for enterprise webapp auth
app_web_auth_enabled = False