mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
Merge branch main into feat/rag-2
This commit is contained in:
@ -178,7 +178,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
# get tenant api providers
|
||||
db_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
|
||||
.where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
@ -160,7 +160,7 @@ class ToolFileManager:
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.filter(
|
||||
.where(
|
||||
ToolFile.id == id,
|
||||
)
|
||||
.first()
|
||||
@ -184,7 +184,7 @@ class ToolFileManager:
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
message_file: MessageFile | None = (
|
||||
session.query(MessageFile)
|
||||
.filter(
|
||||
.where(
|
||||
MessageFile.id == id,
|
||||
)
|
||||
.first()
|
||||
@ -204,7 +204,7 @@ class ToolFileManager:
|
||||
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.filter(
|
||||
.where(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
@ -228,7 +228,7 @@ class ToolFileManager:
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.filter(
|
||||
.where(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
|
||||
@ -29,7 +29,7 @@ class ToolLabelManager:
|
||||
raise ValueError("Unsupported tool type")
|
||||
|
||||
# delete old labels
|
||||
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete()
|
||||
db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete()
|
||||
|
||||
# insert new labels
|
||||
for label in labels:
|
||||
@ -57,7 +57,7 @@ class ToolLabelManager:
|
||||
|
||||
labels = (
|
||||
db.session.query(ToolLabelBinding.label_name)
|
||||
.filter(
|
||||
.where(
|
||||
ToolLabelBinding.tool_id == provider_id,
|
||||
ToolLabelBinding.tool_type == controller.provider_type.value,
|
||||
)
|
||||
@ -90,7 +90,7 @@ class ToolLabelManager:
|
||||
provider_ids.append(controller.provider_id)
|
||||
|
||||
labels: list[ToolLabelBinding] = (
|
||||
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all()
|
||||
db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids)).all()
|
||||
)
|
||||
|
||||
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
|
||||
|
||||
@ -1,16 +1,19 @@
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
from collections.abc import Generator
|
||||
import time
|
||||
from collections.abc import Generator, Mapping
|
||||
from os import listdir, path
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from yarl import URL
|
||||
|
||||
import contexts
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
@ -195,7 +198,7 @@ class ToolManager:
|
||||
try:
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.id == credential_id,
|
||||
)
|
||||
@ -213,7 +216,7 @@ class ToolManager:
|
||||
# use the default provider
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == str(provider_id_entity))
|
||||
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
||||
@ -226,7 +229,7 @@ class ToolManager:
|
||||
else:
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
|
||||
.where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
@ -244,12 +247,47 @@ class ToolManager:
|
||||
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
|
||||
),
|
||||
)
|
||||
|
||||
# decrypt the credentials
|
||||
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
|
||||
|
||||
# check if the credentials is expired
|
||||
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
|
||||
# TODO: circular import
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
# refresh the credentials
|
||||
tool_provider = ToolProviderID(provider_id)
|
||||
provider_name = tool_provider.provider_name
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
|
||||
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
|
||||
oauth_handler = OAuthHandler()
|
||||
# refresh the credentials
|
||||
refreshed_credentials = oauth_handler.refresh_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=builtin_provider.user_id,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=system_credentials or {},
|
||||
credentials=decrypted_credentials,
|
||||
)
|
||||
# update the credentials
|
||||
builtin_provider.encrypted_credentials = (
|
||||
TypeAdapter(dict[str, Any])
|
||||
.dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
|
||||
.decode("utf-8")
|
||||
)
|
||||
builtin_provider.expires_at = refreshed_credentials.expires_at
|
||||
db.session.commit()
|
||||
decrypted_credentials = refreshed_credentials.credentials
|
||||
|
||||
return cast(
|
||||
BuiltinTool,
|
||||
builtin_tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials=encrypter.decrypt(builtin_provider.credentials),
|
||||
credentials=dict(decrypted_credentials),
|
||||
credential_type=CredentialType.of(builtin_provider.credential_type),
|
||||
runtime_parameters={},
|
||||
invoke_from=invoke_from,
|
||||
@ -278,7 +316,7 @@ class ToolManager:
|
||||
elif provider_type == ToolProviderType.WORKFLOW:
|
||||
workflow_provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -578,7 +616,7 @@ class ToolManager:
|
||||
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
|
||||
"""
|
||||
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
|
||||
return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
|
||||
return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
|
||||
|
||||
@classmethod
|
||||
def list_providers_from_api(
|
||||
@ -626,7 +664,7 @@ class ToolManager:
|
||||
# get db api providers
|
||||
if "api" in filters:
|
||||
db_api_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
api_provider_controllers: list[dict[str, Any]] = [
|
||||
@ -649,7 +687,7 @@ class ToolManager:
|
||||
if "workflow" in filters:
|
||||
# get workflow providers
|
||||
workflow_providers: list[WorkflowToolProvider] = (
|
||||
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
workflow_provider_controllers: list[WorkflowToolProviderController] = []
|
||||
@ -693,7 +731,7 @@ class ToolManager:
|
||||
"""
|
||||
provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
.where(
|
||||
ApiToolProvider.id == provider_id,
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
)
|
||||
@ -730,7 +768,7 @@ class ToolManager:
|
||||
"""
|
||||
provider: MCPToolProvider | None = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.filter(
|
||||
.where(
|
||||
MCPToolProvider.server_identifier == provider_id,
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
)
|
||||
@ -755,7 +793,7 @@ class ToolManager:
|
||||
provider_name = provider
|
||||
provider_obj: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider,
|
||||
)
|
||||
@ -847,7 +885,7 @@ class ToolManager:
|
||||
try:
|
||||
workflow_provider: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -864,7 +902,7 @@ class ToolManager:
|
||||
try:
|
||||
api_provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
|
||||
.where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -881,7 +919,7 @@ class ToolManager:
|
||||
try:
|
||||
mcp_provider: MCPToolProvider | None = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
|
||||
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -973,7 +1011,9 @@ class ToolManager:
|
||||
if variable is None:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
parameter_value = variable.value
|
||||
elif tool_input.type in {"mixed", "constant"}:
|
||||
elif tool_input.type == "constant":
|
||||
parameter_value = tool_input.value
|
||||
elif tool_input.type == "mixed":
|
||||
segment_group = variable_pool.convert_template(str(tool_input.value))
|
||||
parameter_value = segment_group.text
|
||||
else:
|
||||
|
||||
@ -87,7 +87,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
.where(
|
||||
DocumentSegment.dataset_id.in_(self.dataset_ids),
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.status == "completed",
|
||||
@ -114,7 +114,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
document = (
|
||||
db.session.query(Document)
|
||||
.filter(
|
||||
.where(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
@ -163,7 +163,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
):
|
||||
with flask_app.app_context():
|
||||
dataset = (
|
||||
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
|
||||
db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
from msal_extensions.persistence import ABC # type: ignore
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
@ -21,11 +21,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@abstractmethod
|
||||
def _run(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
def _run(self, query: str) -> str:
|
||||
"""Use the tool.
|
||||
|
||||
Add run_manager: Optional[CallbackManagerForToolRun] = None
|
||||
|
||||
@ -57,7 +57,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
dataset = (
|
||||
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
|
||||
db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
@ -190,7 +190,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
document = (
|
||||
db.session.query(DatasetDocument) # type: ignore
|
||||
.filter(
|
||||
.where(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
|
||||
@ -84,7 +84,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
"""
|
||||
workflow: Workflow | None = (
|
||||
db.session.query(Workflow)
|
||||
.filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
|
||||
.where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -190,7 +190,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
|
||||
db_providers: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.app_id == self.provider_id,
|
||||
)
|
||||
|
||||
@ -142,12 +142,12 @@ class WorkflowTool(Tool):
|
||||
if not version:
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(Workflow.app_id == app_id, Workflow.version != "draft")
|
||||
.where(Workflow.app_id == app_id, Workflow.version != "draft")
|
||||
.order_by(Workflow.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
workflow = db.session.query(Workflow).filter(Workflow.app_id == app_id, Workflow.version == version).first()
|
||||
workflow = db.session.query(Workflow).where(Workflow.app_id == app_id, Workflow.version == version).first()
|
||||
|
||||
if not workflow:
|
||||
raise ValueError("workflow not found or not published")
|
||||
@ -158,7 +158,7 @@ class WorkflowTool(Tool):
|
||||
"""
|
||||
get the app by app id
|
||||
"""
|
||||
app = db.session.query(App).filter(App.id == app_id).first()
|
||||
app = db.session.query(App).where(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError("app not found")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user