Merge branch main into feat/rag-2

This commit is contained in:
twwu
2025-07-24 17:40:04 +08:00
608 changed files with 8175 additions and 3026 deletions

View File

@ -178,7 +178,7 @@ class ApiToolProviderController(ToolProviderController):
# get tenant api providers
db_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
.where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
.all()
)

View File

@ -160,7 +160,7 @@ class ToolFileManager:
with Session(self._engine, expire_on_commit=False) as session:
tool_file: ToolFile | None = (
session.query(ToolFile)
.filter(
.where(
ToolFile.id == id,
)
.first()
@ -184,7 +184,7 @@ class ToolFileManager:
with Session(self._engine, expire_on_commit=False) as session:
message_file: MessageFile | None = (
session.query(MessageFile)
.filter(
.where(
MessageFile.id == id,
)
.first()
@ -204,7 +204,7 @@ class ToolFileManager:
tool_file: ToolFile | None = (
session.query(ToolFile)
.filter(
.where(
ToolFile.id == tool_file_id,
)
.first()
@ -228,7 +228,7 @@ class ToolFileManager:
with Session(self._engine, expire_on_commit=False) as session:
tool_file: ToolFile | None = (
session.query(ToolFile)
.filter(
.where(
ToolFile.id == tool_file_id,
)
.first()

View File

@ -29,7 +29,7 @@ class ToolLabelManager:
raise ValueError("Unsupported tool type")
# delete old labels
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete()
db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete()
# insert new labels
for label in labels:
@ -57,7 +57,7 @@ class ToolLabelManager:
labels = (
db.session.query(ToolLabelBinding.label_name)
.filter(
.where(
ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type.value,
)
@ -90,7 +90,7 @@ class ToolLabelManager:
provider_ids.append(controller.provider_id)
labels: list[ToolLabelBinding] = (
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all()
db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids)).all()
)
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}

View File

@ -1,16 +1,19 @@
import json
import logging
import mimetypes
from collections.abc import Generator
import time
from collections.abc import Generator, Mapping
from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
from pydantic import TypeAdapter
from yarl import URL
import contexts
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
@ -195,7 +198,7 @@ class ToolManager:
try:
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
@ -213,7 +216,7 @@ class ToolManager:
# use the default provider
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(
.where(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
@ -226,7 +229,7 @@ class ToolManager:
else:
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)
@ -244,12 +247,47 @@ class ToolManager:
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
),
)
# decrypt the credentials
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
# check if the credentials is expired
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
# TODO: circular import
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
# refresh the credentials
tool_provider = ToolProviderID(provider_id)
provider_name = tool_provider.provider_name
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
oauth_handler = OAuthHandler()
# refresh the credentials
refreshed_credentials = oauth_handler.refresh_credentials(
tenant_id=tenant_id,
user_id=builtin_provider.user_id,
plugin_id=tool_provider.plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=decrypted_credentials,
)
# update the credentials
builtin_provider.encrypted_credentials = (
TypeAdapter(dict[str, Any])
.dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
.decode("utf-8")
)
builtin_provider.expires_at = refreshed_credentials.expires_at
db.session.commit()
decrypted_credentials = refreshed_credentials.credentials
return cast(
BuiltinTool,
builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials=encrypter.decrypt(builtin_provider.credentials),
credentials=dict(decrypted_credentials),
credential_type=CredentialType.of(builtin_provider.credential_type),
runtime_parameters={},
invoke_from=invoke_from,
@ -278,7 +316,7 @@ class ToolManager:
elif provider_type == ToolProviderType.WORKFLOW:
workflow_provider = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
)
@ -578,7 +616,7 @@ class ToolManager:
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
"""
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
@classmethod
def list_providers_from_api(
@ -626,7 +664,7 @@ class ToolManager:
# get db api providers
if "api" in filters:
db_api_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all()
)
api_provider_controllers: list[dict[str, Any]] = [
@ -649,7 +687,7 @@ class ToolManager:
if "workflow" in filters:
# get workflow providers
workflow_providers: list[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
)
workflow_provider_controllers: list[WorkflowToolProviderController] = []
@ -693,7 +731,7 @@ class ToolManager:
"""
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.filter(
.where(
ApiToolProvider.id == provider_id,
ApiToolProvider.tenant_id == tenant_id,
)
@ -730,7 +768,7 @@ class ToolManager:
"""
provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider)
.filter(
.where(
MCPToolProvider.server_identifier == provider_id,
MCPToolProvider.tenant_id == tenant_id,
)
@ -755,7 +793,7 @@ class ToolManager:
provider_name = provider
provider_obj: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.filter(
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider,
)
@ -847,7 +885,7 @@ class ToolManager:
try:
workflow_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
)
@ -864,7 +902,7 @@ class ToolManager:
try:
api_provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
.where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
.first()
)
@ -881,7 +919,7 @@ class ToolManager:
try:
mcp_provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider)
.filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
.first()
)
@ -973,7 +1011,9 @@ class ToolManager:
if variable is None:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
parameter_value = variable.value
elif tool_input.type in {"mixed", "constant"}:
elif tool_input.type == "constant":
parameter_value = tool_input.value
elif tool_input.type == "mixed":
segment_group = variable_pool.convert_template(str(tool_input.value))
parameter_value = segment_group.text
else:

View File

@ -87,7 +87,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
segments = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
@ -114,7 +114,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(Document)
.filter(
.where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
@ -163,7 +163,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
):
with flask_app.app_context():
dataset = (
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
)
if not dataset:

View File

@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Any, Optional
from typing import Optional
from msal_extensions.persistence import ABC # type: ignore
from pydantic import BaseModel, ConfigDict
@ -21,11 +21,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC):
model_config = ConfigDict(arbitrary_types_allowed=True)
@abstractmethod
def _run(
self,
*args: Any,
**kwargs: Any,
) -> Any:
def _run(self, query: str) -> str:
"""Use the tool.
Add run_manager: Optional[CallbackManagerForToolRun] = None

View File

@ -57,7 +57,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
def _run(self, query: str) -> str:
dataset = (
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
)
if not dataset:
@ -190,7 +190,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(DatasetDocument) # type: ignore
.filter(
.where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,

View File

@ -84,7 +84,7 @@ class WorkflowToolProviderController(ToolProviderController):
"""
workflow: Workflow | None = (
db.session.query(Workflow)
.filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
.where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
.first()
)
@ -190,7 +190,7 @@ class WorkflowToolProviderController(ToolProviderController):
db_providers: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.filter(
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id,
)

View File

@ -142,12 +142,12 @@ class WorkflowTool(Tool):
if not version:
workflow = (
db.session.query(Workflow)
.filter(Workflow.app_id == app_id, Workflow.version != "draft")
.where(Workflow.app_id == app_id, Workflow.version != "draft")
.order_by(Workflow.created_at.desc())
.first()
)
else:
workflow = db.session.query(Workflow).filter(Workflow.app_id == app_id, Workflow.version == version).first()
workflow = db.session.query(Workflow).where(Workflow.app_id == app_id, Workflow.version == version).first()
if not workflow:
raise ValueError("workflow not found or not published")
@ -158,7 +158,7 @@ class WorkflowTool(Tool):
"""
get the app by app id
"""
app = db.session.query(App).filter(App.id == app_id).first()
app = db.session.query(App).where(App.id == app_id).first()
if not app:
raise ValueError("app not found")