mirror of
https://github.com/langgenius/dify.git
synced 2026-02-23 03:17:57 +08:00
Merge remote-tracking branch 'origin/main' into feat/trigger
This commit is contained in:
@ -1,5 +1,8 @@
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Optional, Union
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.service_api.wraps import create_or_update_end_user_for_user_id
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
@ -24,7 +27,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
app = cls._get_app(app_id, tenant_id)
|
||||
|
||||
"""Retrieve app parameters."""
|
||||
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app.workflow
|
||||
if workflow is None:
|
||||
raise ValueError("unexpected app type")
|
||||
@ -50,8 +53,8 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
app_id: str,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: Optional[str],
|
||||
query: Optional[str],
|
||||
conversation_id: str | None,
|
||||
query: str | None,
|
||||
stream: bool,
|
||||
inputs: Mapping,
|
||||
files: list[dict],
|
||||
@ -67,7 +70,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
|
||||
conversation_id = conversation_id or ""
|
||||
|
||||
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.AGENT_CHAT.value, AppMode.CHAT.value}:
|
||||
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.AGENT_CHAT, AppMode.CHAT}:
|
||||
if not query:
|
||||
raise ValueError("missing query")
|
||||
|
||||
@ -93,7 +96,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke chat app
|
||||
"""
|
||||
if app.mode == AppMode.ADVANCED_CHAT.value:
|
||||
if app.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow = app.workflow
|
||||
if not workflow:
|
||||
raise ValueError("unexpected app type")
|
||||
@ -111,7 +114,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
)
|
||||
elif app.mode == AppMode.AGENT_CHAT.value:
|
||||
elif app.mode == AppMode.AGENT_CHAT:
|
||||
return AgentChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
user=user,
|
||||
@ -124,7 +127,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
)
|
||||
elif app.mode == AppMode.CHAT.value:
|
||||
elif app.mode == AppMode.CHAT:
|
||||
return ChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
user=user,
|
||||
@ -164,7 +167,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
call_depth=1,
|
||||
workflow_thread_pool_id=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -192,10 +194,12 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
get the user by user id
|
||||
"""
|
||||
|
||||
user = db.session.query(EndUser).where(EndUser.id == user_id).first()
|
||||
if not user:
|
||||
user = db.session.query(Account).where(Account.id == user_id).first()
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(EndUser).where(EndUser.id == user_id)
|
||||
user = session.scalar(stmt)
|
||||
if not user:
|
||||
stmt = select(Account).where(Account.id == user_id)
|
||||
user = session.scalar(stmt)
|
||||
|
||||
if not user:
|
||||
raise ValueError("user not found")
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Generic, Optional, TypeVar
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -23,5 +23,5 @@ T = TypeVar("T", bound=dict | Mapping | str | bool | int | BaseModel)
|
||||
|
||||
|
||||
class BaseBackwardsInvocationResponse(BaseModel, Generic[T]):
|
||||
data: Optional[T] = None
|
||||
data: T | None = None
|
||||
error: str = ""
|
||||
|
||||
@ -6,7 +6,7 @@ from models.account import Tenant
|
||||
|
||||
class PluginEncrypter:
|
||||
@classmethod
|
||||
def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict:
|
||||
def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt):
|
||||
encrypter, cache = create_provider_encrypter(
|
||||
tenant_id=tenant.id,
|
||||
config=payload.config,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.parameter_extractor.entities import (
|
||||
ModelConfig as ParameterExtractorModelConfig,
|
||||
)
|
||||
@ -27,7 +27,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
|
||||
model_config: ParameterExtractorModelConfig,
|
||||
instruction: str,
|
||||
query: str,
|
||||
) -> dict:
|
||||
):
|
||||
"""
|
||||
Invoke parameter extractor node.
|
||||
|
||||
@ -78,7 +78,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
|
||||
classes: list[ClassConfig],
|
||||
instruction: str,
|
||||
query: str,
|
||||
) -> dict:
|
||||
):
|
||||
"""
|
||||
Invoke question classifier node.
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
@ -23,7 +23,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
|
||||
provider: str,
|
||||
tool_name: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
credential_id: Optional[str] = None,
|
||||
credential_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
invoke tool
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
@ -24,7 +23,7 @@ class EndpointProviderDeclaration(BaseModel):
|
||||
"""
|
||||
|
||||
settings: list[ProviderConfig] = Field(default_factory=list)
|
||||
endpoints: Optional[list[EndpointDeclaration]] = Field(default_factory=list[EndpointDeclaration])
|
||||
endpoints: list[EndpointDeclaration] | None = Field(default_factory=list[EndpointDeclaration])
|
||||
|
||||
|
||||
class EndpointEntity(BasePluginEntity):
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
@ -19,11 +17,11 @@ class MarketplacePluginDeclaration(BaseModel):
|
||||
resource: PluginResourceRequirements = Field(
|
||||
..., description="Specification of computational resources needed to run the plugin"
|
||||
)
|
||||
endpoint: Optional[EndpointProviderDeclaration] = Field(
|
||||
endpoint: EndpointProviderDeclaration | None = Field(
|
||||
None, description="Configuration for the plugin's API endpoint, if applicable"
|
||||
)
|
||||
model: Optional[ProviderEntity] = Field(None, description="Details of the AI model used by the plugin, if any")
|
||||
tool: Optional[ToolProviderEntity] = Field(
|
||||
model: ProviderEntity | None = Field(None, description="Details of the AI model used by the plugin, if any")
|
||||
tool: ToolProviderEntity | None = Field(
|
||||
None, description="Information about the tool functionality provided by the plugin, if any"
|
||||
)
|
||||
latest_version: str = Field(
|
||||
|
||||
21
api/core/plugin/entities/oauth.py
Normal file
21
api/core/plugin/entities/oauth.py
Normal file
@ -0,0 +1,21 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
|
||||
|
||||
class OAuthSchema(BaseModel):
|
||||
"""
|
||||
OAuth schema
|
||||
"""
|
||||
|
||||
client_schema: Sequence[ProviderConfig] = Field(
|
||||
default_factory=list,
|
||||
description="client schema like client_id, client_secret, etc.",
|
||||
)
|
||||
|
||||
credentials_schema: Sequence[ProviderConfig] = Field(
|
||||
default_factory=list,
|
||||
description="credentials schema like access_token, refresh_token, etc.",
|
||||
)
|
||||
@ -1,19 +1,17 @@
|
||||
import enum
|
||||
from typing import Any, Optional, Union
|
||||
import json
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.entities.parameter_entities import CommonParameterType
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.workflow.nodes.base.entities import NumberType
|
||||
|
||||
|
||||
class PluginParameterOption(BaseModel):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
label: I18nObject = Field(..., description="The label of the option")
|
||||
icon: Optional[str] = Field(
|
||||
default=None, description="The icon of the option, can be a url or a base64 encoded image"
|
||||
)
|
||||
icon: str | None = Field(default=None, description="The icon of the option, can be a url or a base64 encoded image")
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
@ -24,44 +22,44 @@ class PluginParameterOption(BaseModel):
|
||||
return value
|
||||
|
||||
|
||||
class PluginParameterType(enum.StrEnum):
|
||||
class PluginParameterType(StrEnum):
|
||||
"""
|
||||
all available parameter types
|
||||
"""
|
||||
|
||||
STRING = CommonParameterType.STRING.value
|
||||
NUMBER = CommonParameterType.NUMBER.value
|
||||
BOOLEAN = CommonParameterType.BOOLEAN.value
|
||||
SELECT = CommonParameterType.SELECT.value
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
|
||||
FILE = CommonParameterType.FILE.value
|
||||
FILES = CommonParameterType.FILES.value
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
|
||||
ANY = CommonParameterType.ANY.value
|
||||
DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT.value
|
||||
STRING = CommonParameterType.STRING
|
||||
NUMBER = CommonParameterType.NUMBER
|
||||
BOOLEAN = CommonParameterType.BOOLEAN
|
||||
SELECT = CommonParameterType.SELECT
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT
|
||||
FILE = CommonParameterType.FILE
|
||||
FILES = CommonParameterType.FILES
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
|
||||
ANY = CommonParameterType.ANY
|
||||
DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES
|
||||
|
||||
# MCP object and array type parameters
|
||||
ARRAY = CommonParameterType.ARRAY.value
|
||||
OBJECT = CommonParameterType.OBJECT.value
|
||||
ARRAY = CommonParameterType.ARRAY
|
||||
OBJECT = CommonParameterType.OBJECT
|
||||
|
||||
|
||||
class MCPServerParameterType(enum.StrEnum):
|
||||
class MCPServerParameterType(StrEnum):
|
||||
"""
|
||||
MCP server got complex parameter types
|
||||
"""
|
||||
|
||||
ARRAY = "array"
|
||||
OBJECT = "object"
|
||||
ARRAY = auto()
|
||||
OBJECT = auto()
|
||||
|
||||
|
||||
class PluginParameterAutoGenerate(BaseModel):
|
||||
class Type(enum.StrEnum):
|
||||
PROMPT_INSTRUCTION = "prompt_instruction"
|
||||
class Type(StrEnum):
|
||||
PROMPT_INSTRUCTION = auto()
|
||||
|
||||
type: Type
|
||||
|
||||
@ -73,15 +71,15 @@ class PluginParameterTemplate(BaseModel):
|
||||
class PluginParameter(BaseModel):
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
label: I18nObject = Field(..., description="The label presented to the user")
|
||||
placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user")
|
||||
placeholder: I18nObject | None = Field(default=None, description="The placeholder presented to the user")
|
||||
scope: str | None = None
|
||||
auto_generate: Optional[PluginParameterAutoGenerate] = None
|
||||
template: Optional[PluginParameterTemplate] = None
|
||||
auto_generate: PluginParameterAutoGenerate | None = None
|
||||
template: PluginParameterTemplate | None = None
|
||||
required: bool = False
|
||||
default: Optional[Union[float, int, str]] = None
|
||||
min: Optional[Union[float, int]] = None
|
||||
max: Optional[Union[float, int]] = None
|
||||
precision: Optional[int] = None
|
||||
default: Union[float, int, str] | None = None
|
||||
min: Union[float, int] | None = None
|
||||
max: Union[float, int] | None = None
|
||||
precision: int | None = None
|
||||
options: list[PluginParameterOption] = Field(default_factory=list)
|
||||
|
||||
@field_validator("options", mode="before")
|
||||
@ -92,7 +90,7 @@ class PluginParameter(BaseModel):
|
||||
return v
|
||||
|
||||
|
||||
def as_normal_type(typ: enum.StrEnum):
|
||||
def as_normal_type(typ: StrEnum):
|
||||
if typ.value in {
|
||||
PluginParameterType.SECRET_INPUT,
|
||||
PluginParameterType.SELECT,
|
||||
@ -101,7 +99,7 @@ def as_normal_type(typ: enum.StrEnum):
|
||||
return typ.value
|
||||
|
||||
|
||||
def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
|
||||
def cast_parameter_value(typ: StrEnum, value: Any, /):
|
||||
try:
|
||||
match typ.value:
|
||||
case PluginParameterType.STRING | PluginParameterType.SECRET_INPUT | PluginParameterType.SELECT:
|
||||
@ -154,7 +152,7 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
|
||||
raise ValueError("The tools selector must be a list.")
|
||||
return value
|
||||
case PluginParameterType.ANY:
|
||||
if value and not isinstance(value, str | dict | list | NumberType):
|
||||
if value and not isinstance(value, str | dict | list | int | float):
|
||||
raise ValueError("The var selector must be a string, dictionary, list or number.")
|
||||
return value
|
||||
case PluginParameterType.ARRAY:
|
||||
@ -162,8 +160,6 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
|
||||
# Try to parse JSON string for arrays
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
import json
|
||||
|
||||
parsed_value = json.loads(value)
|
||||
if isinstance(parsed_value, list):
|
||||
return parsed_value
|
||||
@ -176,8 +172,6 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
|
||||
# Try to parse JSON string for objects
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
import json
|
||||
|
||||
parsed_value = json.loads(value)
|
||||
if isinstance(parsed_value, dict):
|
||||
return parsed_value
|
||||
@ -193,7 +187,7 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
|
||||
raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.")
|
||||
|
||||
|
||||
def init_frontend_parameter(rule: PluginParameter, type: enum.StrEnum, value: Any):
|
||||
def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any):
|
||||
"""
|
||||
init frontend parameter by rule
|
||||
"""
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
import datetime
|
||||
import enum
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from werkzeug.exceptions import NotFound
|
||||
from packaging.version import InvalidVersion, Version
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from core.agent.plugin_entities import AgentStrategyProviderEntity
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntity
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.plugin.entities.base import BasePluginEntity
|
||||
from core.plugin.entities.endpoint import EndpointProviderDeclaration
|
||||
@ -16,11 +16,11 @@ from core.tools.entities.tool_entities import ToolProviderEntity
|
||||
from core.trigger.entities.entities import TriggerProviderEntity
|
||||
|
||||
|
||||
class PluginInstallationSource(enum.StrEnum):
|
||||
Github = "github"
|
||||
Marketplace = "marketplace"
|
||||
Package = "package"
|
||||
Remote = "remote"
|
||||
class PluginInstallationSource(StrEnum):
|
||||
Github = auto()
|
||||
Marketplace = auto()
|
||||
Package = auto()
|
||||
Remote = auto()
|
||||
|
||||
|
||||
class PluginResourceRequirements(BaseModel):
|
||||
@ -28,84 +28,109 @@ class PluginResourceRequirements(BaseModel):
|
||||
|
||||
class Permission(BaseModel):
|
||||
class Tool(BaseModel):
|
||||
enabled: Optional[bool] = Field(default=False)
|
||||
enabled: bool | None = Field(default=False)
|
||||
|
||||
class Model(BaseModel):
|
||||
enabled: Optional[bool] = Field(default=False)
|
||||
llm: Optional[bool] = Field(default=False)
|
||||
text_embedding: Optional[bool] = Field(default=False)
|
||||
rerank: Optional[bool] = Field(default=False)
|
||||
tts: Optional[bool] = Field(default=False)
|
||||
speech2text: Optional[bool] = Field(default=False)
|
||||
moderation: Optional[bool] = Field(default=False)
|
||||
enabled: bool | None = Field(default=False)
|
||||
llm: bool | None = Field(default=False)
|
||||
text_embedding: bool | None = Field(default=False)
|
||||
rerank: bool | None = Field(default=False)
|
||||
tts: bool | None = Field(default=False)
|
||||
speech2text: bool | None = Field(default=False)
|
||||
moderation: bool | None = Field(default=False)
|
||||
|
||||
class Node(BaseModel):
|
||||
enabled: Optional[bool] = Field(default=False)
|
||||
enabled: bool | None = Field(default=False)
|
||||
|
||||
class Endpoint(BaseModel):
|
||||
enabled: Optional[bool] = Field(default=False)
|
||||
enabled: bool | None = Field(default=False)
|
||||
|
||||
class Storage(BaseModel):
|
||||
enabled: Optional[bool] = Field(default=False)
|
||||
enabled: bool | None = Field(default=False)
|
||||
size: int = Field(ge=1024, le=1073741824, default=1048576)
|
||||
|
||||
tool: Optional[Tool] = Field(default=None)
|
||||
model: Optional[Model] = Field(default=None)
|
||||
node: Optional[Node] = Field(default=None)
|
||||
endpoint: Optional[Endpoint] = Field(default=None)
|
||||
storage: Optional[Storage] = Field(default=None)
|
||||
tool: Tool | None = Field(default=None)
|
||||
model: Model | None = Field(default=None)
|
||||
node: Node | None = Field(default=None)
|
||||
endpoint: Endpoint | None = Field(default=None)
|
||||
storage: Storage | None = Field(default=None)
|
||||
|
||||
permission: Optional[Permission] = Field(default=None)
|
||||
permission: Permission | None = Field(default=None)
|
||||
|
||||
|
||||
class PluginCategory(enum.StrEnum):
|
||||
Tool = "tool"
|
||||
Model = "model"
|
||||
Extension = "extension"
|
||||
class PluginCategory(StrEnum):
|
||||
Tool = auto()
|
||||
Model = auto()
|
||||
Extension = auto()
|
||||
AgentStrategy = "agent-strategy"
|
||||
Datasource = "datasource"
|
||||
Trigger = "trigger"
|
||||
|
||||
|
||||
class PluginDeclaration(BaseModel):
|
||||
class Plugins(BaseModel):
|
||||
tools: Optional[list[str]] = Field(default_factory=list[str])
|
||||
models: Optional[list[str]] = Field(default_factory=list[str])
|
||||
endpoints: Optional[list[str]] = Field(default_factory=list[str])
|
||||
tools: list[str] | None = Field(default_factory=list[str])
|
||||
models: list[str] | None = Field(default_factory=list[str])
|
||||
endpoints: list[str] | None = Field(default_factory=list[str])
|
||||
datasources: list[str] | None = Field(default_factory=list[str])
|
||||
triggers: Optional[list[str]] = Field(default_factory=list[str])
|
||||
|
||||
class Meta(BaseModel):
|
||||
minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
|
||||
version: Optional[str] = Field(default=None)
|
||||
minimum_dify_version: str | None = Field(default=None)
|
||||
version: str | None = Field(default=None)
|
||||
|
||||
version: str = Field(..., pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
|
||||
author: Optional[str] = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$")
|
||||
@field_validator("minimum_dify_version")
|
||||
@classmethod
|
||||
def validate_minimum_dify_version(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return v
|
||||
try:
|
||||
Version(v)
|
||||
return v
|
||||
except InvalidVersion as e:
|
||||
raise ValueError(f"Invalid version format: {v}") from e
|
||||
|
||||
version: str = Field(...)
|
||||
author: str | None = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$")
|
||||
name: str = Field(..., pattern=r"^[a-z0-9_-]{1,128}$")
|
||||
description: I18nObject
|
||||
icon: str
|
||||
icon_dark: Optional[str] = Field(default=None)
|
||||
icon_dark: str | None = Field(default=None)
|
||||
label: I18nObject
|
||||
category: PluginCategory
|
||||
created_at: datetime.datetime
|
||||
resource: PluginResourceRequirements
|
||||
plugins: Plugins
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
repo: Optional[str] = Field(default=None)
|
||||
repo: str | None = Field(default=None)
|
||||
verified: bool = Field(default=False)
|
||||
tool: Optional[ToolProviderEntity] = None
|
||||
tool: ToolProviderEntity | None = None
|
||||
model: ProviderEntity | None = None
|
||||
endpoint: EndpointProviderDeclaration | None = None
|
||||
agent_strategy: AgentStrategyProviderEntity | None = None
|
||||
datasource: DatasourceProviderEntity | None = None
|
||||
trigger: Optional[TriggerProviderEntity] = None
|
||||
model: Optional[ProviderEntity] = None
|
||||
endpoint: Optional[EndpointProviderDeclaration] = None
|
||||
agent_strategy: Optional[AgentStrategyProviderEntity] = None
|
||||
meta: Meta
|
||||
|
||||
@field_validator("version")
|
||||
@classmethod
|
||||
def validate_version(cls, v: str) -> str:
|
||||
try:
|
||||
Version(v)
|
||||
return v
|
||||
except InvalidVersion as e:
|
||||
raise ValueError(f"Invalid version format: {v}") from e
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_category(cls, values: dict) -> dict:
|
||||
def validate_category(cls, values: dict):
|
||||
# auto detect category
|
||||
if values.get("tool"):
|
||||
values["category"] = PluginCategory.Tool
|
||||
elif values.get("model"):
|
||||
values["category"] = PluginCategory.Model
|
||||
elif values.get("datasource"):
|
||||
values["category"] = PluginCategory.Datasource
|
||||
elif values.get("agent_strategy"):
|
||||
values["category"] = PluginCategory.AgentStrategy
|
||||
elif values.get("trigger"):
|
||||
@ -141,64 +166,11 @@ class PluginEntity(PluginInstallation):
|
||||
return self
|
||||
|
||||
|
||||
class GenericProviderID:
|
||||
organization: str
|
||||
plugin_name: str
|
||||
provider_name: str
|
||||
is_hardcoded: bool
|
||||
|
||||
def to_string(self) -> str:
|
||||
return str(self)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.organization}/{self.plugin_name}/{self.provider_name}"
|
||||
|
||||
def __init__(self, value: str, is_hardcoded: bool = False) -> None:
|
||||
if not value:
|
||||
raise NotFound("plugin not found, please add plugin")
|
||||
# check if the value is a valid plugin id with format: $organization/$plugin_name/$provider_name
|
||||
if not re.match(r"^[a-z0-9_-]+\/[a-z0-9_-]+\/[a-z0-9_-]+$", value):
|
||||
# check if matches [a-z0-9_-]+, if yes, append with langgenius/$value/$value
|
||||
if re.match(r"^[a-z0-9_-]+$", value):
|
||||
value = f"langgenius/{value}/{value}"
|
||||
else:
|
||||
raise ValueError(f"Invalid plugin id {value}")
|
||||
|
||||
self.organization, self.plugin_name, self.provider_name = value.split("/")
|
||||
self.is_hardcoded = is_hardcoded
|
||||
|
||||
def is_langgenius(self) -> bool:
|
||||
return self.organization == "langgenius"
|
||||
|
||||
@property
|
||||
def plugin_id(self) -> str:
|
||||
return f"{self.organization}/{self.plugin_name}"
|
||||
|
||||
|
||||
class ModelProviderID(GenericProviderID):
|
||||
def __init__(self, value: str, is_hardcoded: bool = False) -> None:
|
||||
super().__init__(value, is_hardcoded)
|
||||
if self.organization == "langgenius" and self.provider_name == "google":
|
||||
self.plugin_name = "gemini"
|
||||
|
||||
|
||||
class ToolProviderID(GenericProviderID):
|
||||
def __init__(self, value: str, is_hardcoded: bool = False) -> None:
|
||||
super().__init__(value, is_hardcoded)
|
||||
if self.organization == "langgenius":
|
||||
if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]:
|
||||
self.plugin_name = f"{self.provider_name}_tool"
|
||||
|
||||
|
||||
class TriggerProviderID(GenericProviderID):
|
||||
pass
|
||||
|
||||
|
||||
class PluginDependency(BaseModel):
|
||||
class Type(enum.StrEnum):
|
||||
Github = PluginInstallationSource.Github.value
|
||||
Marketplace = PluginInstallationSource.Marketplace.value
|
||||
Package = PluginInstallationSource.Package.value
|
||||
class Type(StrEnum):
|
||||
Github = PluginInstallationSource.Github
|
||||
Marketplace = PluginInstallationSource.Marketplace
|
||||
Package = PluginInstallationSource.Package
|
||||
|
||||
class Github(BaseModel):
|
||||
repo: str
|
||||
@ -212,6 +184,7 @@ class PluginDependency(BaseModel):
|
||||
|
||||
class Marketplace(BaseModel):
|
||||
marketplace_plugin_unique_identifier: str
|
||||
version: str | None = None
|
||||
|
||||
@property
|
||||
def plugin_unique_identifier(self) -> str:
|
||||
@ -219,12 +192,13 @@ class PluginDependency(BaseModel):
|
||||
|
||||
class Package(BaseModel):
|
||||
plugin_unique_identifier: str
|
||||
version: str | None = None
|
||||
|
||||
type: Type
|
||||
value: Github | Marketplace | Package
|
||||
current_identifier: Optional[str] = None
|
||||
current_identifier: str | None = None
|
||||
|
||||
|
||||
class MissingPluginDependency(BaseModel):
|
||||
plugin_unique_identifier: str
|
||||
current_identifier: Optional[str] = None
|
||||
current_identifier: str | None = None
|
||||
|
||||
@ -2,11 +2,12 @@ import enum
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.agent.plugin_entities import AgentProviderEntityWithPlugin
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.plugin.entities.base import BasePluginEntity
|
||||
@ -26,7 +27,7 @@ class PluginDaemonBasicResponse(BaseModel, Generic[T]):
|
||||
|
||||
code: int
|
||||
message: str
|
||||
data: Optional[T]
|
||||
data: T | None = None
|
||||
|
||||
|
||||
class InstallPluginMessage(BaseModel):
|
||||
@ -50,6 +51,14 @@ class PluginToolProviderEntity(BaseModel):
|
||||
declaration: ToolProviderEntityWithPlugin
|
||||
|
||||
|
||||
class PluginDatasourceProviderEntity(BaseModel):
|
||||
provider: str
|
||||
plugin_unique_identifier: str
|
||||
plugin_id: str
|
||||
is_authorized: bool = False
|
||||
declaration: DatasourceProviderEntityWithPlugin
|
||||
|
||||
|
||||
class PluginAgentProviderEntity(BaseModel):
|
||||
provider: str
|
||||
plugin_unique_identifier: str
|
||||
@ -176,7 +185,7 @@ class PluginVerification(BaseModel):
|
||||
class PluginDecodeResponse(BaseModel):
|
||||
unique_identifier: str = Field(description="The unique identifier of the plugin.")
|
||||
manifest: PluginDeclaration
|
||||
verification: Optional[PluginVerification] = Field(default=None, description="Basic verification information")
|
||||
verification: PluginVerification | None = Field(default=None, description="Basic verification information")
|
||||
|
||||
|
||||
class PluginOAuthAuthorizationUrlResponse(BaseModel):
|
||||
@ -235,9 +244,9 @@ class CredentialType(enum.StrEnum):
|
||||
@classmethod
|
||||
def of(cls, credential_type: str) -> "CredentialType":
|
||||
type_name = credential_type.lower()
|
||||
if type_name == "api-key":
|
||||
if type_name in {"api-key", "api_key"}:
|
||||
return cls.API_KEY
|
||||
elif type_name == "oauth2":
|
||||
elif type_name in {"oauth2", "oauth"}:
|
||||
return cls.OAUTH2
|
||||
elif type_name == "unauthorized":
|
||||
return cls.UNAUTHORIZED
|
||||
|
||||
@ -37,7 +37,7 @@ class InvokeCredentials(BaseModel):
|
||||
|
||||
|
||||
class PluginInvokeContext(BaseModel):
|
||||
credentials: Optional[InvokeCredentials] = Field(
|
||||
credentials: InvokeCredentials | None = Field(
|
||||
default_factory=InvokeCredentials,
|
||||
description="Credentials context for the plugin invocation or backward invocation.",
|
||||
)
|
||||
@ -52,7 +52,7 @@ class RequestInvokeTool(BaseModel):
|
||||
provider: str
|
||||
tool: str
|
||||
tool_parameters: dict
|
||||
credential_id: Optional[str] = None
|
||||
credential_id: str | None = None
|
||||
|
||||
|
||||
class BaseRequestInvokeModel(BaseModel):
|
||||
@ -72,9 +72,9 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
|
||||
mode: str
|
||||
completion_params: dict[str, Any] = Field(default_factory=dict)
|
||||
prompt_messages: list[PromptMessage] = Field(default_factory=list)
|
||||
tools: Optional[list[PromptMessageTool]] = Field(default_factory=list[PromptMessageTool])
|
||||
stop: Optional[list[str]] = Field(default_factory=list[str])
|
||||
stream: Optional[bool] = False
|
||||
tools: list[PromptMessageTool] | None = Field(default_factory=list[PromptMessageTool])
|
||||
stop: list[str] | None = Field(default_factory=list[str])
|
||||
stream: bool | None = False
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@ -196,10 +196,10 @@ class RequestInvokeApp(BaseModel):
|
||||
|
||||
app_id: str
|
||||
inputs: dict[str, Any]
|
||||
query: Optional[str] = None
|
||||
query: str | None = None
|
||||
response_mode: Literal["blocking", "streaming"]
|
||||
conversation_id: Optional[str] = None
|
||||
user: Optional[str] = None
|
||||
conversation_id: str | None = None
|
||||
user: str | None = None
|
||||
files: list[dict] = Field(default_factory=list)
|
||||
|
||||
|
||||
|
||||
@ -1,14 +1,14 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.plugin.entities.plugin import GenericProviderID
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginAgentProviderEntity,
|
||||
)
|
||||
from core.plugin.entities.request import PluginInvokeContext
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.plugin.utils.chunk_merger import merge_blob_chunks
|
||||
from models.provider_ids import GenericProviderID
|
||||
|
||||
|
||||
class PluginAgentClient(BasePluginClient):
|
||||
@ -17,7 +17,7 @@ class PluginAgentClient(BasePluginClient):
|
||||
Fetch agent providers for the given tenant.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
def transformer(json_response: dict[str, Any]):
|
||||
for provider in json_response.get("data", []):
|
||||
declaration = provider.get("declaration", {}) or {}
|
||||
provider_name = declaration.get("identity", {}).get("name")
|
||||
@ -49,7 +49,7 @@ class PluginAgentClient(BasePluginClient):
|
||||
"""
|
||||
agent_provider_id = GenericProviderID(provider)
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
def transformer(json_response: dict[str, Any]):
|
||||
# skip if error occurs
|
||||
if json_response.get("data") is None or json_response.get("data", {}).get("declaration") is None:
|
||||
return json_response
|
||||
@ -82,10 +82,10 @@ class PluginAgentClient(BasePluginClient):
|
||||
agent_provider: str,
|
||||
agent_strategy: str,
|
||||
agent_params: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
context: Optional[PluginInvokeContext] = None,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
context: PluginInvokeContext | None = None,
|
||||
) -> Generator[AgentInvokeMessage, None, None]:
|
||||
"""
|
||||
Invoke the agent with the given tenant, user, plugin, provider, name and parameters.
|
||||
|
||||
@ -64,7 +64,7 @@ class BasePluginClient:
|
||||
response = requests.request(
|
||||
method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
except requests.ConnectionError:
|
||||
logger.exception("Request to Plugin Daemon Service failed")
|
||||
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
|
||||
|
||||
|
||||
372
api/core/plugin/impl/datasource.py
Normal file
372
api/core/plugin/impl/datasource.py
Normal file
@ -0,0 +1,372 @@
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceMessage,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
OnlineDocumentPagesMessage,
|
||||
OnlineDriveBrowseFilesRequest,
|
||||
OnlineDriveBrowseFilesResponse,
|
||||
OnlineDriveDownloadFileRequest,
|
||||
WebsiteCrawlMessage,
|
||||
)
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginBasicBooleanResponse,
|
||||
PluginDatasourceProviderEntity,
|
||||
)
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.schemas.resolver import resolve_dify_schema_refs
|
||||
from models.provider_ids import DatasourceProviderID, GenericProviderID
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
class PluginDatasourceManager(BasePluginClient):
|
||||
def fetch_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
|
||||
"""
|
||||
Fetch datasource providers for the given tenant.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
if json_response.get("data"):
|
||||
for provider in json_response.get("data", []):
|
||||
declaration = provider.get("declaration", {}) or {}
|
||||
provider_name = declaration.get("identity", {}).get("name")
|
||||
for datasource in declaration.get("datasources", []):
|
||||
datasource["identity"]["provider"] = provider_name
|
||||
# resolve refs
|
||||
if datasource.get("output_schema"):
|
||||
datasource["output_schema"] = resolve_dify_schema_refs(datasource["output_schema"])
|
||||
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/datasources",
|
||||
list[PluginDatasourceProviderEntity],
|
||||
params={"page": 1, "page_size": 256},
|
||||
transformer=transformer,
|
||||
)
|
||||
local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
|
||||
|
||||
for provider in response:
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
|
||||
all_response = [local_file_datasource_provider] + response
|
||||
|
||||
for provider in all_response:
|
||||
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for tool in provider.declaration.datasources:
|
||||
tool.identity.provider = provider.declaration.identity.name
|
||||
|
||||
return all_response
|
||||
|
||||
def fetch_installed_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
|
||||
"""
|
||||
Fetch datasource providers for the given tenant.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
if json_response.get("data"):
|
||||
for provider in json_response.get("data", []):
|
||||
declaration = provider.get("declaration", {}) or {}
|
||||
provider_name = declaration.get("identity", {}).get("name")
|
||||
for datasource in declaration.get("datasources", []):
|
||||
datasource["identity"]["provider"] = provider_name
|
||||
# resolve refs
|
||||
if datasource.get("output_schema"):
|
||||
datasource["output_schema"] = resolve_dify_schema_refs(datasource["output_schema"])
|
||||
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/datasources",
|
||||
list[PluginDatasourceProviderEntity],
|
||||
params={"page": 1, "page_size": 256},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
for provider in response:
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
|
||||
|
||||
for provider in response:
|
||||
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for tool in provider.declaration.datasources:
|
||||
tool.identity.provider = provider.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
def fetch_datasource_provider(self, tenant_id: str, provider_id: str) -> PluginDatasourceProviderEntity:
|
||||
"""
|
||||
Fetch datasource provider for the given tenant and plugin.
|
||||
"""
|
||||
if provider_id == "langgenius/file/file":
|
||||
return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
|
||||
|
||||
tool_provider_id = DatasourceProviderID(provider_id)
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
data = json_response.get("data")
|
||||
if data:
|
||||
for datasource in data.get("declaration", {}).get("datasources", []):
|
||||
datasource["identity"]["provider"] = tool_provider_id.provider_name
|
||||
if datasource.get("output_schema"):
|
||||
datasource["output_schema"] = resolve_dify_schema_refs(datasource["output_schema"])
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/datasource",
|
||||
PluginDatasourceProviderEntity,
|
||||
params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for datasource in response.declaration.datasources:
|
||||
datasource.identity.provider = response.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
def get_website_crawl(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
datasource_provider: str,
|
||||
datasource_name: str,
|
||||
credentials: dict[str, Any],
|
||||
datasource_parameters: Mapping[str, Any],
|
||||
provider_type: str,
|
||||
) -> Generator[WebsiteCrawlMessage, None, None]:
|
||||
"""
|
||||
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||
"""
|
||||
|
||||
datasource_provider_id = GenericProviderID(datasource_provider)
|
||||
|
||||
return self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/get_website_crawl",
|
||||
WebsiteCrawlMessage,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": datasource_provider_id.provider_name,
|
||||
"datasource": datasource_name,
|
||||
"credentials": credentials,
|
||||
"datasource_parameters": datasource_parameters,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": datasource_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def get_online_document_pages(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
datasource_provider: str,
|
||||
datasource_name: str,
|
||||
credentials: dict[str, Any],
|
||||
datasource_parameters: Mapping[str, Any],
|
||||
provider_type: str,
|
||||
) -> Generator[OnlineDocumentPagesMessage, None, None]:
|
||||
"""
|
||||
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||
"""
|
||||
|
||||
datasource_provider_id = GenericProviderID(datasource_provider)
|
||||
|
||||
return self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages",
|
||||
OnlineDocumentPagesMessage,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": datasource_provider_id.provider_name,
|
||||
"datasource": datasource_name,
|
||||
"credentials": credentials,
|
||||
"datasource_parameters": datasource_parameters,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": datasource_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def get_online_document_page_content(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
datasource_provider: str,
|
||||
datasource_name: str,
|
||||
credentials: dict[str, Any],
|
||||
datasource_parameters: GetOnlineDocumentPageContentRequest,
|
||||
provider_type: str,
|
||||
) -> Generator[DatasourceMessage, None, None]:
|
||||
"""
|
||||
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||
"""
|
||||
|
||||
datasource_provider_id = GenericProviderID(datasource_provider)
|
||||
|
||||
return self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/get_online_document_page_content",
|
||||
DatasourceMessage,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": datasource_provider_id.provider_name,
|
||||
"datasource": datasource_name,
|
||||
"credentials": credentials,
|
||||
"page": datasource_parameters.model_dump(),
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": datasource_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def online_drive_browse_files(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
datasource_provider: str,
|
||||
datasource_name: str,
|
||||
credentials: dict[str, Any],
|
||||
request: OnlineDriveBrowseFilesRequest,
|
||||
provider_type: str,
|
||||
) -> Generator[OnlineDriveBrowseFilesResponse, None, None]:
|
||||
"""
|
||||
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||
"""
|
||||
|
||||
datasource_provider_id = GenericProviderID(datasource_provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/online_drive_browse_files",
|
||||
OnlineDriveBrowseFilesResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": datasource_provider_id.provider_name,
|
||||
"datasource": datasource_name,
|
||||
"credentials": credentials,
|
||||
"request": request.model_dump(),
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": datasource_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
yield from response
|
||||
|
||||
def online_drive_download_file(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
datasource_provider: str,
|
||||
datasource_name: str,
|
||||
credentials: dict[str, Any],
|
||||
request: OnlineDriveDownloadFileRequest,
|
||||
provider_type: str,
|
||||
) -> Generator[DatasourceMessage, None, None]:
|
||||
"""
|
||||
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||
"""
|
||||
|
||||
datasource_provider_id = GenericProviderID(datasource_provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/online_drive_download_file",
|
||||
DatasourceMessage,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": datasource_provider_id.provider_name,
|
||||
"datasource": datasource_name,
|
||||
"credentials": credentials,
|
||||
"request": request.model_dump(),
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": datasource_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
yield from response
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, tenant_id: str, user_id: str, provider: str, plugin_id: str, credentials: dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
# datasource_provider_id = GenericProviderID(provider_id)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/validate_credentials",
|
||||
PluginBasicBooleanResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.result
|
||||
|
||||
return False
|
||||
|
||||
def _get_local_file_datasource_provider(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": "langgenius/file/file",
|
||||
"plugin_id": "langgenius/file",
|
||||
"provider": "file",
|
||||
"plugin_unique_identifier": "langgenius/file:0.0.1@dify",
|
||||
"declaration": {
|
||||
"identity": {
|
||||
"author": "langgenius",
|
||||
"name": "file",
|
||||
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
|
||||
"icon": "https://assets.dify.ai/images/File%20Upload.svg",
|
||||
"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
|
||||
},
|
||||
"credentials_schema": [],
|
||||
"provider_type": "local_file",
|
||||
"datasources": [
|
||||
{
|
||||
"identity": {
|
||||
"author": "langgenius",
|
||||
"name": "upload-file",
|
||||
"provider": "file",
|
||||
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
|
||||
},
|
||||
"parameters": [],
|
||||
"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
@ -1,9 +1,9 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID
|
||||
from core.plugin.entities.plugin_daemon import PluginDynamicSelectOptionsResponse
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from models.provider_ids import GenericProviderID
|
||||
|
||||
|
||||
class DynamicSelectClient(BasePluginClient):
|
||||
|
||||
@ -8,7 +8,7 @@ from extensions.ext_logging import get_request_id
|
||||
class PluginDaemonError(Exception):
|
||||
"""Base class for all plugin daemon errors."""
|
||||
|
||||
def __init__(self, description: str) -> None:
|
||||
def __init__(self, description: str):
|
||||
self.description = description
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import binascii
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import IO, Optional
|
||||
from typing import IO
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResultChunk
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
@ -151,9 +151,9 @@ class PluginModelClient(BasePluginClient):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: Optional[dict] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
model_parameters: dict | None = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
@ -200,7 +200,7 @@ class PluginModelClient(BasePluginClient):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for llm
|
||||
@ -325,8 +325,8 @@ class PluginModelClient(BasePluginClient):
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank
|
||||
@ -414,8 +414,8 @@ class PluginModelClient(BasePluginClient):
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
language: Optional[str] = None,
|
||||
) -> list[dict]:
|
||||
language: str | None = None,
|
||||
):
|
||||
"""
|
||||
Get tts model voices
|
||||
"""
|
||||
|
||||
@ -2,7 +2,6 @@ from collections.abc import Sequence
|
||||
|
||||
from core.plugin.entities.bundle import PluginBundleDependency
|
||||
from core.plugin.entities.plugin import (
|
||||
GenericProviderID,
|
||||
MissingPluginDependency,
|
||||
PluginDeclaration,
|
||||
PluginEntity,
|
||||
@ -16,6 +15,7 @@ from core.plugin.entities.plugin_daemon import (
|
||||
PluginListResponse,
|
||||
)
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from models.provider_ids import GenericProviderID
|
||||
|
||||
|
||||
class PluginInstaller(BasePluginClient):
|
||||
|
||||
@ -1,13 +1,17 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
|
||||
# from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType, PluginBasicBooleanResponse, PluginToolProviderEntity
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.plugin.utils.chunk_merger import merge_blob_chunks
|
||||
from core.schemas.resolver import resolve_dify_schema_refs
|
||||
|
||||
# from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from models.provider_ids import GenericProviderID, ToolProviderID
|
||||
|
||||
|
||||
class PluginToolManager(BasePluginClient):
|
||||
@ -16,12 +20,15 @@ class PluginToolManager(BasePluginClient):
|
||||
Fetch tool providers for the given tenant.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
def transformer(json_response: dict[str, Any]):
|
||||
for provider in json_response.get("data", []):
|
||||
declaration = provider.get("declaration", {}) or {}
|
||||
provider_name = declaration.get("identity", {}).get("name")
|
||||
for tool in declaration.get("tools", []):
|
||||
tool["identity"]["provider"] = provider_name
|
||||
# resolve refs
|
||||
if tool.get("output_schema"):
|
||||
tool["output_schema"] = resolve_dify_schema_refs(tool["output_schema"])
|
||||
|
||||
return json_response
|
||||
|
||||
@ -48,11 +55,14 @@ class PluginToolManager(BasePluginClient):
|
||||
"""
|
||||
tool_provider_id = ToolProviderID(provider)
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
def transformer(json_response: dict[str, Any]):
|
||||
data = json_response.get("data")
|
||||
if data:
|
||||
for tool in data.get("declaration", {}).get("tools", []):
|
||||
tool["identity"]["provider"] = tool_provider_id.provider_name
|
||||
# resolve refs
|
||||
if tool.get("output_schema"):
|
||||
tool["output_schema"] = resolve_dify_schema_refs(tool["output_schema"])
|
||||
|
||||
return json_response
|
||||
|
||||
@ -81,9 +91,9 @@ class PluginToolManager(BasePluginClient):
|
||||
credentials: dict[str, Any],
|
||||
credential_type: CredentialType,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
Invoke the tool with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||
@ -146,6 +156,36 @@ class PluginToolManager(BasePluginClient):
|
||||
|
||||
return False
|
||||
|
||||
def validate_datasource_credentials(
|
||||
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
validate the credentials of the datasource
|
||||
"""
|
||||
tool_provider_id = GenericProviderID(provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/datasource/validate_credentials",
|
||||
PluginBasicBooleanResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": tool_provider_id.provider_name,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": tool_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.result
|
||||
|
||||
return False
|
||||
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@ -153,9 +193,9 @@ class PluginToolManager(BasePluginClient):
|
||||
provider: str,
|
||||
credentials: dict[str, Any],
|
||||
tool: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> list[ToolParameter]:
|
||||
"""
|
||||
get the runtime parameters of the tool
|
||||
|
||||
@ -18,7 +18,7 @@ class FileChunk:
|
||||
bytes_written: int = field(default=0, init=False)
|
||||
data: bytearray = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
def __post_init__(self):
|
||||
self.data = bytearray(self.total_length)
|
||||
|
||||
|
||||
@ -82,7 +82,9 @@ def merge_blob_chunks(
|
||||
message_class = type(resp)
|
||||
merged_message = message_class(
|
||||
type=ToolInvokeMessage.MessageType.BLOB,
|
||||
message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data[: files[chunk_id].bytes_written]),
|
||||
message=ToolInvokeMessage.BlobMessage(
|
||||
blob=bytes(files[chunk_id].data[: files[chunk_id].bytes_written])
|
||||
),
|
||||
meta=resp.meta,
|
||||
)
|
||||
yield cast(MessageType, merged_message)
|
||||
|
||||
Reference in New Issue
Block a user