mirror of
https://github.com/langgenius/dify.git
synced 2026-05-02 16:38:04 +08:00
Merge branch 'main' into feat/mcp-06-18
This commit is contained in:
@ -157,7 +157,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
returns the tool that the provider can provide
|
||||
"""
|
||||
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None) # type: ignore
|
||||
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)
|
||||
|
||||
@property
|
||||
def need_credentials(self) -> bool:
|
||||
|
||||
@ -43,7 +43,7 @@ class TTSTool(BuiltinTool):
|
||||
content_text=tool_parameters.get("text"), # type: ignore
|
||||
user=user_id,
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
voice=voice, # type: ignore
|
||||
voice=voice,
|
||||
)
|
||||
buffer = io.BytesIO()
|
||||
for chunk in tts:
|
||||
|
||||
@ -34,6 +34,7 @@ class LocaltimeToTimestampTool(BuiltinTool):
|
||||
|
||||
yield self.create_text_message(f"{timestamp}")
|
||||
|
||||
# TODO: this method's type is messy
|
||||
@staticmethod
|
||||
def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
|
||||
try:
|
||||
|
||||
@ -48,6 +48,6 @@ class TimezoneConversionTool(BuiltinTool):
|
||||
datetime_with_tz = input_timezone.localize(local_time)
|
||||
# timezone convert
|
||||
converted_datetime = datetime_with_tz.astimezone(output_timezone)
|
||||
return converted_datetime.strftime(format=time_format) # type: ignore
|
||||
return converted_datetime.strftime(time_format)
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(str(e))
|
||||
|
||||
@ -113,7 +113,7 @@ class MCPToolProviderController(ToolProviderController):
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_tool(self, tool_name: str) -> MCPTool: # type: ignore
|
||||
def get_tool(self, tool_name: str) -> MCPTool:
|
||||
"""
|
||||
return tool with given name
|
||||
"""
|
||||
@ -136,7 +136,7 @@ class MCPToolProviderController(ToolProviderController):
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
|
||||
def get_tools(self) -> list[MCPTool]: # type: ignore
|
||||
def get_tools(self) -> list[MCPTool]:
|
||||
"""
|
||||
get all tools
|
||||
"""
|
||||
|
||||
@ -26,7 +26,7 @@ class ToolLabelManager:
|
||||
labels = cls.filter_tool_labels(labels)
|
||||
|
||||
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
provider_id = controller.provider_id # ty: ignore [unresolved-attribute]
|
||||
provider_id = controller.provider_id
|
||||
else:
|
||||
raise ValueError("Unsupported tool type")
|
||||
|
||||
@ -51,7 +51,7 @@ class ToolLabelManager:
|
||||
Get tool labels
|
||||
"""
|
||||
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
provider_id = controller.provider_id # ty: ignore [unresolved-attribute]
|
||||
provider_id = controller.provider_id
|
||||
elif isinstance(controller, BuiltinToolProviderController):
|
||||
return controller.tool_labels
|
||||
else:
|
||||
@ -85,7 +85,7 @@ class ToolLabelManager:
|
||||
provider_ids = []
|
||||
for controller in tool_providers:
|
||||
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
|
||||
provider_ids.append(controller.provider_id) # ty: ignore [unresolved-attribute]
|
||||
provider_ids.append(controller.provider_id)
|
||||
|
||||
labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()
|
||||
|
||||
|
||||
@ -331,7 +331,8 @@ class ToolManager:
|
||||
workflow_provider_stmt = select(WorkflowToolProvider).where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
|
||||
)
|
||||
workflow_provider = db.session.scalar(workflow_provider_stmt)
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
workflow_provider = session.scalar(workflow_provider_stmt)
|
||||
|
||||
if workflow_provider is None:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
@ -193,18 +193,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
document = db.session.scalar(dataset_document_stmt) # type: ignore
|
||||
document = db.session.scalar(dataset_document_stmt)
|
||||
if dataset and document:
|
||||
source = RetrievalSourceMetadata(
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
document_id=document.id, # type: ignore
|
||||
document_name=document.name, # type: ignore
|
||||
data_source_type=document.data_source_type, # type: ignore
|
||||
document_id=document.id,
|
||||
document_name=document.name,
|
||||
data_source_type=document.data_source_type,
|
||||
segment_id=segment.id,
|
||||
retriever_from=self.retriever_from,
|
||||
score=record.score or 0.0,
|
||||
doc_metadata=document.doc_metadata, # type: ignore
|
||||
doc_metadata=document.doc_metadata,
|
||||
)
|
||||
|
||||
if self.retriever_from == "dev":
|
||||
|
||||
@ -62,6 +62,11 @@ class ApiBasedToolSchemaParser:
|
||||
root = root[ref]
|
||||
interface["operation"]["parameters"][i] = root
|
||||
for parameter in interface["operation"]["parameters"]:
|
||||
# Handle complex type defaults that are not supported by PluginParameter
|
||||
default_value = None
|
||||
if "schema" in parameter and "default" in parameter["schema"]:
|
||||
default_value = ApiBasedToolSchemaParser._sanitize_default_value(parameter["schema"]["default"])
|
||||
|
||||
tool_parameter = ToolParameter(
|
||||
name=parameter["name"],
|
||||
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
|
||||
@ -72,9 +77,7 @@ class ApiBasedToolSchemaParser:
|
||||
required=parameter.get("required", False),
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=parameter.get("description"),
|
||||
default=parameter["schema"]["default"]
|
||||
if "schema" in parameter and "default" in parameter["schema"]
|
||||
else None,
|
||||
default=default_value,
|
||||
placeholder=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
@ -134,6 +137,11 @@ class ApiBasedToolSchemaParser:
|
||||
required = body_schema.get("required", [])
|
||||
properties = body_schema.get("properties", {})
|
||||
for name, property in properties.items():
|
||||
# Handle complex type defaults that are not supported by PluginParameter
|
||||
default_value = ApiBasedToolSchemaParser._sanitize_default_value(
|
||||
property.get("default", None)
|
||||
)
|
||||
|
||||
tool = ToolParameter(
|
||||
name=name,
|
||||
label=I18nObject(en_US=name, zh_Hans=name),
|
||||
@ -144,12 +152,11 @@ class ApiBasedToolSchemaParser:
|
||||
required=name in required,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=property.get("description", ""),
|
||||
default=property.get("default", None),
|
||||
default=default_value,
|
||||
placeholder=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
|
||||
if typ:
|
||||
@ -197,6 +204,22 @@ class ApiBasedToolSchemaParser:
|
||||
|
||||
return bundles
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_default_value(value):
|
||||
"""
|
||||
Sanitize default values for PluginParameter compatibility.
|
||||
Complex types (list, dict) are converted to None to avoid validation errors.
|
||||
|
||||
Args:
|
||||
value: The default value from OpenAPI schema
|
||||
|
||||
Returns:
|
||||
None for complex types (list, dict), otherwise the original value
|
||||
"""
|
||||
if isinstance(value, (list, dict)):
|
||||
return None
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
|
||||
parameter = parameter or {}
|
||||
@ -217,7 +240,11 @@ class ApiBasedToolSchemaParser:
|
||||
return ToolParameter.ToolParameterType.STRING
|
||||
elif typ == "array":
|
||||
items = parameter.get("items") or parameter.get("schema", {}).get("items")
|
||||
return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None
|
||||
if items and items.get("format") == "binary":
|
||||
return ToolParameter.ToolParameterType.FILES
|
||||
else:
|
||||
# For regular arrays, return ARRAY type instead of None
|
||||
return ToolParameter.ToolParameterType.ARRAY
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@ -6,8 +6,8 @@ from typing import Any, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import chardet
|
||||
import cloudscraper # type: ignore
|
||||
from readabilipy import simple_json_from_html_string # type: ignore
|
||||
import cloudscraper
|
||||
from readabilipy import simple_json_from_html_string
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.rag.extractor import extract_processor
|
||||
@ -63,8 +63,8 @@ def get_url(url: str, user_agent: str | None = None) -> str:
|
||||
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
|
||||
elif response.status_code == 403:
|
||||
scraper = cloudscraper.create_scraper()
|
||||
scraper.perform_request = ssrf_proxy.make_request # type: ignore
|
||||
response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) # type: ignore
|
||||
scraper.perform_request = ssrf_proxy.make_request
|
||||
response = scraper.get(url, headers=headers, timeout=(120, 300))
|
||||
|
||||
if response.status_code != 200:
|
||||
return f"URL returned status code {response.status_code}."
|
||||
|
||||
@ -3,7 +3,7 @@ from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml # type: ignore
|
||||
import yaml
|
||||
from yaml import YAMLError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
@ -20,6 +21,7 @@ from core.tools.entities.tool_entities import (
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, AppMode
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.workflow import Workflow
|
||||
@ -44,29 +46,34 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
|
||||
app = db_provider.app
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
provider = session.get(WorkflowToolProvider, db_provider.id) if db_provider.id else None
|
||||
if not provider:
|
||||
raise ValueError("workflow provider not found")
|
||||
app = session.get(App, provider.app_id)
|
||||
if not app:
|
||||
raise ValueError("app not found")
|
||||
|
||||
if not app:
|
||||
raise ValueError("app not found")
|
||||
user = session.get(Account, provider.user_id) if provider.user_id else None
|
||||
|
||||
controller = WorkflowToolProviderController(
|
||||
entity=ToolProviderEntity(
|
||||
identity=ToolProviderIdentity(
|
||||
author=db_provider.user.name if db_provider.user_id and db_provider.user else "",
|
||||
name=db_provider.label,
|
||||
label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label),
|
||||
description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
|
||||
icon=db_provider.icon,
|
||||
controller = WorkflowToolProviderController(
|
||||
entity=ToolProviderEntity(
|
||||
identity=ToolProviderIdentity(
|
||||
author=user.name if user else "",
|
||||
name=provider.label,
|
||||
label=I18nObject(en_US=provider.label, zh_Hans=provider.label),
|
||||
description=I18nObject(en_US=provider.description, zh_Hans=provider.description),
|
||||
icon=provider.icon,
|
||||
),
|
||||
credentials_schema=[],
|
||||
plugin_id=None,
|
||||
),
|
||||
credentials_schema=[],
|
||||
plugin_id=None,
|
||||
),
|
||||
provider_id=db_provider.id or "",
|
||||
)
|
||||
provider_id=provider.id or "",
|
||||
)
|
||||
|
||||
# init tools
|
||||
|
||||
controller.tools = [controller._get_db_provider_tool(db_provider, app)]
|
||||
controller.tools = [
|
||||
controller._get_db_provider_tool(provider, app, session=session, user=user),
|
||||
]
|
||||
|
||||
return controller
|
||||
|
||||
@ -74,7 +81,14 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.WORKFLOW
|
||||
|
||||
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
|
||||
def _get_db_provider_tool(
|
||||
self,
|
||||
db_provider: WorkflowToolProvider,
|
||||
app: App,
|
||||
*,
|
||||
session: Session,
|
||||
user: Account | None = None,
|
||||
) -> WorkflowTool:
|
||||
"""
|
||||
get db provider tool
|
||||
:param db_provider: the db provider
|
||||
@ -82,7 +96,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
:return: the tool
|
||||
"""
|
||||
workflow: Workflow | None = (
|
||||
db.session.query(Workflow)
|
||||
session.query(Workflow)
|
||||
.where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
|
||||
.first()
|
||||
)
|
||||
@ -99,9 +113,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
|
||||
|
||||
def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
|
||||
return next(filter(lambda x: x.variable == variable_name, variables), None) # type: ignore
|
||||
|
||||
user = db_provider.user
|
||||
return next(filter(lambda x: x.variable == variable_name, variables), None)
|
||||
|
||||
workflow_tool_parameters = []
|
||||
for parameter in parameters:
|
||||
@ -187,22 +199,25 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
if self.tools is not None:
|
||||
return self.tools
|
||||
|
||||
db_providers: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.app_id == self.provider_id,
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
db_provider: WorkflowToolProvider | None = (
|
||||
session.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.app_id == self.provider_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not db_providers:
|
||||
return []
|
||||
if not db_providers.app:
|
||||
raise ValueError("app not found")
|
||||
if not db_provider:
|
||||
return []
|
||||
|
||||
app = db_providers.app
|
||||
self.tools = [self._get_db_provider_tool(db_providers, app)]
|
||||
app = session.get(App, db_provider.app_id)
|
||||
if not app:
|
||||
raise ValueError("app not found")
|
||||
|
||||
user = session.get(Account, db_provider.user_id) if db_provider.user_id else None
|
||||
self.tools = [self._get_db_provider_tool(db_provider, app, session=session, user=user)]
|
||||
|
||||
return self.tools
|
||||
|
||||
|
||||
@ -1,12 +1,14 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import has_request_context
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import (
|
||||
@ -48,6 +50,7 @@ class WorkflowTool(Tool):
|
||||
self.workflow_entities = workflow_entities
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
self.label = label
|
||||
self._latest_usage = LLMUsage.empty_usage()
|
||||
|
||||
super().__init__(entity=entity, runtime=runtime)
|
||||
|
||||
@ -83,10 +86,11 @@ class WorkflowTool(Tool):
|
||||
assert self.runtime.invoke_from is not None
|
||||
|
||||
user = self._resolve_user(user_id=user_id)
|
||||
|
||||
if user is None:
|
||||
raise ToolInvokeError("User not found")
|
||||
|
||||
self._latest_usage = LLMUsage.empty_usage()
|
||||
|
||||
result = generator.generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
@ -110,9 +114,68 @@ class WorkflowTool(Tool):
|
||||
for file in files:
|
||||
yield self.create_file_message(file) # type: ignore
|
||||
|
||||
self._latest_usage = self._derive_usage_from_result(data)
|
||||
|
||||
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
|
||||
yield self.create_json_message(outputs)
|
||||
|
||||
@property
|
||||
def latest_usage(self) -> LLMUsage:
|
||||
return self._latest_usage
|
||||
|
||||
@classmethod
|
||||
def _derive_usage_from_result(cls, data: Mapping[str, Any]) -> LLMUsage:
|
||||
usage_dict = cls._extract_usage_dict(data)
|
||||
if usage_dict is not None:
|
||||
return LLMUsage.from_metadata(cast(LLMUsageMetadata, dict(usage_dict)))
|
||||
|
||||
total_tokens = data.get("total_tokens")
|
||||
total_price = data.get("total_price")
|
||||
if total_tokens is None and total_price is None:
|
||||
return LLMUsage.empty_usage()
|
||||
|
||||
usage_metadata: dict[str, Any] = {}
|
||||
if total_tokens is not None:
|
||||
try:
|
||||
usage_metadata["total_tokens"] = int(str(total_tokens))
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
if total_price is not None:
|
||||
usage_metadata["total_price"] = str(total_price)
|
||||
currency = data.get("currency")
|
||||
if currency is not None:
|
||||
usage_metadata["currency"] = currency
|
||||
|
||||
if not usage_metadata:
|
||||
return LLMUsage.empty_usage()
|
||||
|
||||
return LLMUsage.from_metadata(cast(LLMUsageMetadata, usage_metadata))
|
||||
|
||||
@classmethod
|
||||
def _extract_usage_dict(cls, payload: Mapping[str, Any]) -> Mapping[str, Any] | None:
|
||||
usage_candidate = payload.get("usage")
|
||||
if isinstance(usage_candidate, Mapping):
|
||||
return usage_candidate
|
||||
|
||||
metadata_candidate = payload.get("metadata")
|
||||
if isinstance(metadata_candidate, Mapping):
|
||||
usage_candidate = metadata_candidate.get("usage")
|
||||
if isinstance(usage_candidate, Mapping):
|
||||
return usage_candidate
|
||||
|
||||
for value in payload.values():
|
||||
if isinstance(value, Mapping):
|
||||
found = cls._extract_usage_dict(value)
|
||||
if found is not None:
|
||||
return found
|
||||
elif isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
|
||||
for item in value:
|
||||
if isinstance(item, Mapping):
|
||||
found = cls._extract_usage_dict(item)
|
||||
if found is not None:
|
||||
return found
|
||||
return None
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
|
||||
"""
|
||||
fork a new tool with metadata
|
||||
@ -179,16 +242,17 @@ class WorkflowTool(Tool):
|
||||
"""
|
||||
get the workflow by app id and version
|
||||
"""
|
||||
if not version:
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.where(Workflow.app_id == app_id, Workflow.version != Workflow.VERSION_DRAFT)
|
||||
.order_by(Workflow.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
|
||||
workflow = db.session.scalar(stmt)
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
if not version:
|
||||
stmt = (
|
||||
select(Workflow)
|
||||
.where(Workflow.app_id == app_id, Workflow.version != Workflow.VERSION_DRAFT)
|
||||
.order_by(Workflow.created_at.desc())
|
||||
)
|
||||
workflow = session.scalars(stmt).first()
|
||||
else:
|
||||
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
|
||||
workflow = session.scalar(stmt)
|
||||
|
||||
if not workflow:
|
||||
raise ValueError("workflow not found or not published")
|
||||
@ -200,7 +264,8 @@ class WorkflowTool(Tool):
|
||||
get the app by app id
|
||||
"""
|
||||
stmt = select(App).where(App.id == app_id)
|
||||
app = db.session.scalar(stmt)
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
app = session.scalar(stmt)
|
||||
if not app:
|
||||
raise ValueError("app not found")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user