Compare commits

..

3 Commits

149 changed files with 3502 additions and 4433 deletions

2
.gitignore vendored
View File

@ -257,5 +257,5 @@ scripts/stress-test/reports/
# Code Agent Folder
.qoder/*
.context/
.context/*
.eslintcache

View File

@ -30,7 +30,7 @@ from clients.agent_backend.factory import create_agent_backend_run_client
from clients.agent_backend.fake_client import FakeAgentBackendRunClient, FakeAgentBackendScenario
from clients.agent_backend.request_builder import (
AGENT_SOUL_PROMPT_LAYER_ID,
DIFY_EXECUTION_CONTEXT_LAYER_ID,
DIFY_PLUGIN_CONTEXT_LAYER_ID,
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
WORKFLOW_USER_PROMPT_LAYER_ID,
AgentBackendModelConfig,
@ -42,7 +42,7 @@ from clients.agent_backend.request_builder import (
__all__ = [
"AGENT_SOUL_PROMPT_LAYER_ID",
"DIFY_EXECUTION_CONTEXT_LAYER_ID",
"DIFY_PLUGIN_CONTEXT_LAYER_ID",
"WORKFLOW_NODE_JOB_PROMPT_LAYER_ID",
"WORKFLOW_USER_PROMPT_LAYER_ID",
"AgentBackendError",

View File

@ -4,9 +4,7 @@ This module is intentionally an adapter, not a wire DTO package. The emitted
object is always ``dify_agent.protocol.CreateRunRequest`` so the Agent backend
protocol has a single owner. API-only context such as Agent Soul vs workflow job
prompt is preserved in layer names and metadata until the dedicated product
schemas land in later phases. Dify-owned execution identifiers are emitted as an
explicit ``dify.execution_context`` layer so the run request stays fully
composition-driven.
schemas land in later phases.
"""
from __future__ import annotations
@ -17,19 +15,18 @@ from agenton.compositor import CompositorSessionSnapshot
from agenton.layers import ExitIntent
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig
from dify_agent.layers.dify_plugin import (
DIFY_PLUGIN_LAYER_TYPE_ID,
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
DifyPluginCredentialValue,
DifyPluginLayerConfig,
DifyPluginLLMLayerConfig,
)
from dify_agent.layers.execution_context import (
DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
DifyExecutionContextLayerConfig,
)
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
from dify_agent.protocol import (
DIFY_AGENT_MODEL_LAYER_ID,
DIFY_AGENT_OUTPUT_LAYER_ID,
CreateRunRequest,
ExecutionContext,
LayerExitSignals,
RunComposition,
RunLayerSpec,
@ -40,15 +37,17 @@ from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator
AGENT_SOUL_PROMPT_LAYER_ID = "agent_soul_prompt"
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID = "workflow_node_job_prompt"
WORKFLOW_USER_PROMPT_LAYER_ID = "workflow_user_prompt"
DIFY_EXECUTION_CONTEXT_LAYER_ID = "execution_context"
DIFY_PLUGIN_CONTEXT_LAYER_ID = "plugin"
class AgentBackendModelConfig(BaseModel):
"""API-side model/plugin selection before it is converted to Dify Agent layers."""
tenant_id: str
plugin_id: str
model_provider: str
model: str
user_id: str | None = None
credentials: dict[str, DifyPluginCredentialValue] = Field(default_factory=dict)
model_settings: dict[str, JsonValue] = Field(default_factory=dict)
@ -74,7 +73,7 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
"""Inputs needed to build the first workflow-node-oriented Agent backend run request."""
model: AgentBackendModelConfig
execution_context: DifyExecutionContextLayerConfig
execution_context: ExecutionContext
workflow_node_job_prompt: str
user_prompt: str
agent_soul_prompt: str | None = None
@ -126,18 +125,21 @@ class AgentBackendRunRequestBuilder:
config=PromptLayerConfig(user=run_input.user_prompt),
),
RunLayerSpec(
name=DIFY_EXECUTION_CONTEXT_LAYER_ID,
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
name=DIFY_PLUGIN_CONTEXT_LAYER_ID,
type=DIFY_PLUGIN_LAYER_TYPE_ID,
metadata=run_input.metadata,
config=run_input.execution_context,
config=DifyPluginLayerConfig(
tenant_id=run_input.model.tenant_id,
plugin_id=run_input.model.plugin_id,
user_id=run_input.model.user_id,
),
),
RunLayerSpec(
name=DIFY_AGENT_MODEL_LAYER_ID,
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
deps={"plugin": DIFY_PLUGIN_CONTEXT_LAYER_ID},
metadata=run_input.metadata,
config=DifyPluginLLMLayerConfig(
plugin_id=run_input.model.plugin_id,
model_provider=run_input.model.model_provider,
model=run_input.model.model,
credentials=run_input.model.credentials,
@ -163,6 +165,7 @@ class AgentBackendRunRequestBuilder:
return CreateRunRequest(
composition=RunComposition(layers=layers),
execution_context=run_input.execution_context,
purpose=run_input.purpose,
idempotency_key=run_input.idempotency_key,
metadata=run_input.metadata,

View File

@ -22,6 +22,9 @@ from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolParameter,
)
from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
@ -147,9 +150,44 @@ class BaseAgentRunner(AppRunner):
message_tool = PromptMessageTool(
name=tool.tool_name,
description=tool_entity.entity.description.llm,
parameters=tool_entity.get_llm_parameters_json_schema(),
parameters={
"type": "object",
"properties": {},
"required": [],
},
)
parameters = tool_entity.get_merged_runtime_parameters()
for parameter in parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = parameter.type.as_normal_type()
if parameter.type in {
ToolParameter.ToolParameterType.SYSTEM_FILES,
ToolParameter.ToolParameterType.FILE,
ToolParameter.ToolParameterType.FILES,
}:
continue
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options] if parameter.options else []
message_tool.parameters["properties"][parameter.name] = (
{
"type": parameter_type,
"description": parameter.llm_description or "",
}
if parameter.input_schema is None
else parameter.input_schema
)
if len(enum) > 0:
message_tool.parameters["properties"][parameter.name]["enum"] = enum
if parameter.required:
message_tool.parameters["required"].append(parameter.name)
return message_tool, tool_entity
def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
@ -214,7 +252,40 @@ class BaseAgentRunner(AppRunner):
"""
update prompt message tool
"""
prompt_tool.parameters = tool.get_llm_parameters_json_schema()
# try to get tool runtime parameters
tool_runtime_parameters = tool.get_runtime_parameters()
for parameter in tool_runtime_parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = parameter.type.as_normal_type()
if parameter.type in {
ToolParameter.ToolParameterType.SYSTEM_FILES,
ToolParameter.ToolParameterType.FILE,
ToolParameter.ToolParameterType.FILES,
}:
continue
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options] if parameter.options else []
prompt_tool.parameters["properties"][parameter.name] = (
{
"type": parameter_type,
"description": parameter.llm_description or "",
}
if parameter.input_schema is None
else parameter.input_schema
)
if len(enum) > 0:
prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
if parameter.required:
if parameter.name not in prompt_tool.parameters["required"]:
prompt_tool.parameters["required"].append(parameter.name)
return prompt_tool
def create_agent_thought(

View File

@ -126,89 +126,34 @@ class Tool(ABC):
message_id: str | None = None,
) -> list[ToolParameter]:
"""
Get the effective parameter declarations for this tool.
Runtime parameters override declared parameters by name and append new
parameters, but the returned list is always detached from the tool's
cached declarations so callers can safely mutate it while building
downstream schemas.
get merged runtime parameters
:return: merged runtime parameters
"""
parameters = [deepcopy(parameter) for parameter in self.entity.parameters or []]
user_parameters = [
deepcopy(parameter)
for parameter in self.get_runtime_parameters(
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
)
or []
]
parameter_indexes = {parameter.name: index for index, parameter in enumerate(parameters)}
parameters = self.entity.parameters
parameters = parameters.copy()
user_parameters = self.get_runtime_parameters() or []
user_parameters = user_parameters.copy()
# override parameters
for parameter in user_parameters:
existing_index = parameter_indexes.get(parameter.name)
if existing_index is None:
parameter_indexes[parameter.name] = len(parameters)
# check if parameter in tool parameters
for tool_parameter in parameters:
if tool_parameter.name == parameter.name:
# override parameter
tool_parameter.type = parameter.type
tool_parameter.form = parameter.form
tool_parameter.required = parameter.required
tool_parameter.default = parameter.default
tool_parameter.options = parameter.options
tool_parameter.llm_description = parameter.llm_description
break
else:
# add new parameter
parameters.append(parameter)
continue
parameters[existing_index] = parameter
return parameters
def get_llm_parameters_json_schema(
self,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> dict[str, Any]:
"""Build the model-visible JSON schema from effective tool parameters.
Hidden/manual parameters stay available for invocation preparation on the
API side, but are intentionally omitted from the LLM-facing schema.
"""
schema: dict[str, Any] = {
"type": "object",
"properties": {},
"required": [],
}
for parameter in self.get_merged_runtime_parameters(
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
):
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
if parameter.type in {
ToolParameter.ToolParameterType.SYSTEM_FILES,
ToolParameter.ToolParameterType.FILE,
ToolParameter.ToolParameterType.FILES,
}:
continue
parameter_schema: dict[str, Any] = (
{
"type": parameter.type.as_normal_type(),
"description": parameter.llm_description or "",
}
if parameter.input_schema is None
else deepcopy(parameter.input_schema)
)
parameter_schema.setdefault("description", parameter.llm_description or "")
if parameter.type == ToolParameter.ToolParameterType.SELECT and parameter.options:
parameter_schema["enum"] = [option.value for option in parameter.options]
schema["properties"][parameter.name] = parameter_schema
if parameter.required:
schema["required"].append(parameter.name)
return schema
def create_image_message(
self,
image: str,

View File

@ -4,8 +4,7 @@ from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import Any, Literal, Protocol, cast
from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig
from dify_agent.protocol import CreateRunRequest
from dify_agent.protocol import CreateRunRequest, ExecutionContext
from clients.agent_backend import (
AgentBackendModelConfig,
@ -106,20 +105,16 @@ class WorkflowAgentRuntimeRequestBuilder:
request = self._request_builder.build_for_workflow_node(
AgentBackendWorkflowNodeRunInput(
model=AgentBackendModelConfig(
tenant_id=context.dify_context.tenant_id,
plugin_id=agent_soul.model.plugin_id,
model_provider=agent_soul.model.model_provider,
model=agent_soul.model.model,
user_id=context.dify_context.user_id,
credentials=self._normalize_credentials(credentials),
model_settings=cast(dict[str, Any], agent_soul.model.model_settings),
),
# The execution-context layer is now the only public protocol
# carrier for Dify tenant/user/run identifiers. ``user_id`` must
# be forwarded here because downstream plugin-daemon provider and
# tool clients read it from this layer rather than from any
# parallel top-level request field.
execution_context=DifyExecutionContextLayerConfig(
execution_context=ExecutionContext(
tenant_id=context.dify_context.tenant_id,
user_id=context.dify_context.user_id,
app_id=context.dify_context.app_id,
workflow_id=context.workflow_id,
workflow_run_id=context.workflow_run_id,

View File

@ -8,9 +8,17 @@ decisions, row updates, row deletes, and structured logging. Only some grouped t
also add cache cleanup; that includes `provider_models` and
`provider_model_credentials`. Provider-model-credential groups extend that flow by
rewriting credential references in provider models and load-balancing configs before
removing loser credential rows. `load_balancing_model_configs` is the intentional
exception: it does not group or merge rows, and instead reloads and canonicalizes each
legacy row independently with row-level cache cleanup.
removing loser credential rows. `load_balancing_model_configs` stays mostly row-level,
but it first deduplicates `name="__inherit__"` rows by business key before it
canonicalizes the remaining legacy rows independently with row-level cache cleanup.
Tenant scheduling has two modes. When callers provide an explicit tenant list, the
service preserves the original tenant-scoped execution model and runs all selected tables
for each tenant. When callers omit `tenant_ids`, the service discovers tenant
ids per table and then runs only that table for the discovered tenants. Most
tables keep the active `model_types` filter in the discovery query, while
`load_balancing_model_configs` deliberately uses a whole-table tenant scan so
that query stays easy to understand.
"""
from __future__ import annotations
@ -19,8 +27,9 @@ import io
import json
import sys
import threading
import traceback
import uuid
from collections.abc import Iterable, Iterator, Sequence
from collections.abc import Iterable, Sequence
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import asdict, dataclass
from datetime import datetime
@ -35,7 +44,7 @@ from sqlalchemy.sql import select
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from graphon.model_runtime.entities.model_entities import ModelType
from libs.datetime_utils import naive_utc_now
from models import LoadBalancingModelConfig, ProviderModel, ProviderModelSetting, Tenant, TenantDefaultModel
from models import LoadBalancingModelConfig, ProviderModel, ProviderModelSetting, TenantDefaultModel
from models.base import TypeBase
from models.provider import ProviderModelCredential
@ -95,6 +104,10 @@ def _normalize_log_payload(value: object) -> object:
return f"<{type(value).__module__}.{type(value).__qualname__}>"
def _format_exception_stacktrace(exc: BaseException) -> str:
return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
@dataclass(frozen=True, slots=True)
class _RowWithRawModelType[T: TypeBase]:
row: T
@ -176,6 +189,16 @@ class _ProviderModelSettingBusinessKey(_BusinessKey):
model_type: ModelType
@dataclass(frozen=True, slots=True)
class _LoadBalancingModelConfigInheritBusinessKey(_BusinessKey):
"""Business key for `name="__inherit__"` load-balancing configs."""
tenant_id: str
provider_name: str
model_name: str
model_type: ModelType
@dataclass(frozen=True, slots=True)
class _ProviderModelCredentialBusinessKey(_BusinessKey):
"""Although `ProviderModelCredential` does not have the unique index
@ -210,6 +233,13 @@ class _ProviderModelSettingGroupPlan:
loser_rows: list[_RowWithRawModelType[ProviderModelSetting]]
@dataclass(frozen=True, slots=True)
class _LoadBalancingModelConfigInheritGroupPlan:
group_row_ids: list[str]
winner: _RowWithRawModelType[LoadBalancingModelConfig] | None
loser_rows: list[_RowWithRawModelType[LoadBalancingModelConfig]]
@dataclass(frozen=True, slots=True)
class _ProviderModelReferenceRewritePlan:
row_id: str
@ -228,13 +258,6 @@ class _LoadBalancingCredentialRewritePlan:
new_encrypted_config: str | None
@dataclass(frozen=True, slots=True)
class _LoadBalancingCredentialDeletePlan:
row_id: str
old_credential_id: str | None
winner_credential_id: str
@dataclass(frozen=True, slots=True)
class _ProviderModelCredentialGroupPlan:
group_row_ids: list[str]
@ -242,7 +265,6 @@ class _ProviderModelCredentialGroupPlan:
loser_rows: list[_RowWithRawModelType[ProviderModelCredential]]
provider_model_rewrites: list[_ProviderModelReferenceRewritePlan]
load_balancing_rewrites: list[_LoadBalancingCredentialRewritePlan]
load_balancing_deletions: list[_LoadBalancingCredentialDeletePlan]
VALID_TABLE_NAMES: tuple[str, ...] = (
@ -277,6 +299,21 @@ _LOCK_TIMEOUT_FALLBACK_MESSAGES: tuple[str, ...] = (
_RAW_MODEL_TYPE_COLUMN = "_raw_model_type"
def _selected_legacy_values(model_types: Sequence[ModelType]) -> list[str]:
legacy_values: list[str] = []
for model_type in model_types:
legacy_values.extend(_CANONICAL_TO_LEGACY[model_type])
return legacy_values
def _selected_model_type_values(model_types: Sequence[ModelType]) -> list[str]:
model_type_values: list[str] = []
for model_type in model_types:
model_type_values.append(model_type.value)
model_type_values.extend(_CANONICAL_TO_LEGACY[model_type])
return list(dict.fromkeys(model_type_values))
def _session_factory(engine: sa.Engine) -> Session:
return Session(bind=engine, expire_on_commit=False)
@ -349,6 +386,12 @@ class LegacyModelTypeMigrationService:
`provider_model_credentials` is selected, that migration also rewrites references in
`provider_models` and `load_balancing_model_configs`. Tenant migrations can run in a
thread pool; JSONL output remains line-safe through a shared synchronized writer.
If `tenant_ids` is omitted, tenant discovery becomes table-scoped: each selected ORM
model loads its own tenant ids, then only that table is dispatched for those tenants.
Most tables keep the active model-type filter in discovery, while
`load_balancing_model_configs` intentionally uses the whole table so the tenant query
stays simple. This still avoids merging tenant ids across unrelated tables.
"""
_engine: sa.Engine
@ -412,22 +455,51 @@ class LegacyModelTypeMigrationService:
return tuple(ordered_models)
def migrate(self) -> None:
tenant_ids = tuple(self._iter_tenant_ids())
output = _ThreadSafeLineWriter(self._output)
if self._tenant_ids is not None:
self._migrate_explicit_tenants(output)
return
self._migrate_tables_with_discovered_tenants(output)
def _migrate_explicit_tenants(self, output: io.TextIOBase) -> None:
tenant_ids = self._tenant_ids
if not tenant_ids:
return
output = _ThreadSafeLineWriter(self._output)
self._run_migrations_for_tenants(tenant_ids, self._orm_models, output)
def _migrate_tables_with_discovered_tenants(self, output: io.TextIOBase) -> None:
for orm_model in self._orm_models:
tenant_ids = self._load_tenant_ids_for_model(orm_model)
if not tenant_ids:
continue
self._run_migrations_for_tenants(tenant_ids, (orm_model,), output)
def _run_migrations_for_tenants(
self,
tenant_ids: Sequence[str],
orm_models: Sequence[ORMModel],
output: io.TextIOBase,
) -> None:
if self._concurrency == 1 or len(tenant_ids) == 1:
for tenant_id in tenant_ids:
self._run_tenant_migration(tenant_id, output)
self._run_tenant_migration(tenant_id, orm_models, output)
return
with ThreadPoolExecutor(max_workers=min(self._concurrency, len(tenant_ids))) as executor:
futures = [executor.submit(self._run_tenant_migration, tenant_id, output) for tenant_id in tenant_ids]
futures = [
executor.submit(self._run_tenant_migration, tenant_id, orm_models, output) for tenant_id in tenant_ids
]
for future in as_completed(futures):
future.result()
def _run_tenant_migration(self, tenant_id: str, output: io.TextIOBase) -> None:
def _run_tenant_migration(
self,
tenant_id: str,
orm_models: Sequence[ORMModel],
output: io.TextIOBase,
) -> None:
"""
Execute one tenant migration with the shared, line-synchronized output stream.
"""
@ -438,18 +510,88 @@ class LegacyModelTypeMigrationService:
apply=self._apply,
output=output,
model_types=self._model_types,
orm_models=self._orm_models,
orm_models=orm_models,
).run()
def _iter_tenant_ids(self) -> Iterator[str]:
if self._tenant_ids is not None:
yield from self._tenant_ids
return
def _load_tenant_ids_for_model(self, orm_model: ORMModel) -> tuple[str, ...]:
"""
Discover only the tenants that have candidate rows for the current table.
In automatic tenant mode we keep discovery table-scoped so large shared tenant
populations do not force empty work for unrelated tables. Most table queries
still apply the active `model_types` filter before scheduling migrations, while
`load_balancing_model_configs` intentionally trades a wider tenant set for a
simpler discovery query.
"""
legacy_model_type_values = _selected_legacy_values(self._model_types)
with _session_factory(self._engine) as session:
tenant_ids = session.execute(select(Tenant.id).order_by(Tenant.id.asc())).scalars().all()
if orm_model is ProviderModel:
tenant_ids = (
session.execute(
select(ProviderModel.tenant_id)
.where(sa.type_coerce(ProviderModel.model_type, sa.String()).in_(legacy_model_type_values))
.distinct()
.order_by(ProviderModel.tenant_id.asc())
)
.scalars()
.all()
)
elif orm_model is TenantDefaultModel:
tenant_ids = (
session.execute(
select(TenantDefaultModel.tenant_id)
.where(sa.type_coerce(TenantDefaultModel.model_type, sa.String()).in_(legacy_model_type_values))
.distinct()
.order_by(TenantDefaultModel.tenant_id.asc())
)
.scalars()
.all()
)
elif orm_model is ProviderModelSetting:
tenant_ids = (
session.execute(
select(ProviderModelSetting.tenant_id)
.where(
sa.type_coerce(ProviderModelSetting.model_type, sa.String()).in_(legacy_model_type_values)
)
.distinct()
.order_by(ProviderModelSetting.tenant_id.asc())
)
.scalars()
.all()
)
elif orm_model is LoadBalancingModelConfig:
# Deliberately discover tenants from the whole table so the query stays
# easier to understand than the legacy/canonical mixed-row filter.
tenant_ids = (
session.execute(
select(LoadBalancingModelConfig.tenant_id)
.distinct()
.order_by(LoadBalancingModelConfig.tenant_id.asc())
)
.scalars()
.all()
)
elif orm_model is ProviderModelCredential:
tenant_ids = (
session.execute(
select(ProviderModelCredential.tenant_id)
.where(
sa.type_coerce(ProviderModelCredential.model_type, sa.String()).in_(
legacy_model_type_values
)
)
.distinct()
.order_by(ProviderModelCredential.tenant_id.asc())
)
.scalars()
.all()
)
else:
raise ValueError(f"unsupported orm model: {orm_model}")
yield from tenant_ids
return tuple(tenant_ids)
class Migration:
@ -518,14 +660,28 @@ class Migration:
)
def _selected_legacy_values(self) -> list[str]:
legacy_values: list[str] = []
for model_type in self._model_types:
legacy_values.extend(_CANONICAL_TO_LEGACY[model_type])
return legacy_values
return _selected_legacy_values(self._model_types)
def _selected_model_type_values(self) -> list[str]:
return _selected_model_type_values(self._model_types)
def _allowed_values_for_canonical_model_type(self, canonical_model_type: ModelType) -> tuple[str, ...]:
return (*_CANONICAL_TO_LEGACY[canonical_model_type], canonical_model_type.value)
def _normalize_selected_model_type(self, raw_model_type: str) -> ModelType | None:
canonical_model_type = _LEGACY_TO_CANONICAL.get(raw_model_type)
if canonical_model_type is not None:
return canonical_model_type
try:
parsed_model_type = ModelType(raw_model_type)
except ValueError:
return None
if parsed_model_type not in self._model_types:
return None
return parsed_model_type
def _has_legacy_rows[T: TypeBase](self, rows: Sequence[_RowWithRawModelType[T]]) -> bool:
return any(row.raw_model_type in _LEGACY_TO_CANONICAL for row in rows)
@ -1203,9 +1359,11 @@ class Migration:
"""
Migrate load-balancing configs row by row.
This table only needs model_type canonicalization. Unlike the grouped tables, it
must not merge rows by business key; each legacy candidate is reloaded and updated
independently so the migration remains a pure per-row rewrite plus cache cleanup.
This table first deduplicates `name="__inherit__"` rows per normalized
`(tenant_id, provider_name, model_name, model_type)` business key, then
canonicalizes the remaining legacy rows independently. The pre-pass must run
first so a legacy/canonical `__inherit__` pair keeps only the newest row before
the row-level canonicalization would collapse them onto the same canonical key.
"""
self._log_event(
"table_started",
@ -1217,6 +1375,7 @@ class Migration:
},
)
processed_inherit_groups = self._deduplicate_inherit_load_balancing_model_configs()
processed_rows = 0
last_id: str | None = None
@ -1237,10 +1396,217 @@ class Migration:
"tenant_id": self._tenant_id,
"apply": self._apply,
"table_name": LoadBalancingModelConfig.__tablename__,
"processed_inherit_groups": processed_inherit_groups,
"processed_rows": processed_rows,
},
)
def _deduplicate_inherit_load_balancing_model_configs(self) -> int:
seen_business_keys: dict[_LoadBalancingModelConfigInheritBusinessKey, list[str]] = {}
processed_groups = 0
last_id: str | None = None
while True:
candidates = self._load_load_balancing_inherit_candidates(last_id)
if not candidates:
break
for candidate in candidates:
last_id = str(candidate.row.id)
business_key = _LoadBalancingModelConfigInheritBusinessKey(
tenant_id=candidate.row.tenant_id,
provider_name=candidate.row.provider_name,
model_name=candidate.row.model_name,
model_type=candidate.canonical_model_type,
)
if business_key in seen_business_keys:
continue
seen_business_keys[business_key] = self._process_load_balancing_inherit_group(candidate, business_key)
processed_groups += 1
return processed_groups
def _load_load_balancing_inherit_candidates(
self, last_id: str | None
) -> list[_RowWithRawModelType[LoadBalancingModelConfig]]:
raw_model_type = sa.type_coerce(LoadBalancingModelConfig.model_type, sa.String()).label(_RAW_MODEL_TYPE_COLUMN)
with _session_factory(self._engine) as session:
stmt = (
select(LoadBalancingModelConfig, raw_model_type)
.where(
LoadBalancingModelConfig.tenant_id == self._tenant_id,
LoadBalancingModelConfig.name == "__inherit__",
sa.type_coerce(LoadBalancingModelConfig.model_type, sa.String()).in_(
self._selected_model_type_values()
),
)
.order_by(LoadBalancingModelConfig.id.asc())
.limit(self._batch_size)
)
if last_id is not None:
stmt = stmt.where(LoadBalancingModelConfig.id > last_id)
rows = session.execute(stmt).all()
wrapped_rows: list[_RowWithRawModelType[LoadBalancingModelConfig]] = []
for load_balancing_model_config, raw_value in rows:
raw_model_type_value = str(raw_value)
canonical_model_type = self._normalize_selected_model_type(raw_model_type_value)
if canonical_model_type is None:
self._log_event(
event="invalid_model_type",
message=f"invalid model type: {raw_value}",
attrs={
"id": load_balancing_model_config.id,
"table_name": load_balancing_model_config.__tablename__,
},
)
continue
wrapped_rows.append(
_RowWithRawModelType(
row=load_balancing_model_config,
raw_model_type=raw_model_type_value,
canonical_model_type=canonical_model_type,
)
)
return wrapped_rows
def _load_load_balancing_inherit_group(
self,
session: Session,
candidate: _RowWithRawModelType[LoadBalancingModelConfig],
*,
lock_rows: bool,
) -> list[_RowWithRawModelType[LoadBalancingModelConfig]]:
raw_model_type = sa.type_coerce(LoadBalancingModelConfig.model_type, sa.String()).label(_RAW_MODEL_TYPE_COLUMN)
stmt = (
select(LoadBalancingModelConfig, raw_model_type)
.where(
LoadBalancingModelConfig.tenant_id == candidate.row.tenant_id,
LoadBalancingModelConfig.provider_name == candidate.row.provider_name,
LoadBalancingModelConfig.model_name == candidate.row.model_name,
LoadBalancingModelConfig.name == "__inherit__",
sa.type_coerce(LoadBalancingModelConfig.model_type, sa.String()).in_(
self._allowed_values_for_canonical_model_type(candidate.canonical_model_type)
),
)
.order_by(LoadBalancingModelConfig.id.asc())
)
if lock_rows:
stmt = stmt.with_for_update()
rows = session.execute(stmt).all()
wrapped_rows: list[_RowWithRawModelType[LoadBalancingModelConfig]] = []
for load_balancing_model_config, raw_value in rows:
raw_model_type_value = str(raw_value)
canonical_model_type = self._normalize_selected_model_type(raw_model_type_value)
if canonical_model_type is None:
continue
wrapped_rows.append(
_RowWithRawModelType(
row=load_balancing_model_config,
raw_model_type=raw_model_type_value,
canonical_model_type=canonical_model_type,
)
)
return wrapped_rows
def _build_load_balancing_inherit_group_plan(
self,
session: Session,
candidate: _RowWithRawModelType[LoadBalancingModelConfig],
*,
lock_rows: bool,
) -> _LoadBalancingModelConfigInheritGroupPlan:
rows = self._load_load_balancing_inherit_group(session, candidate, lock_rows=lock_rows)
group_row_ids = [str(row.row.id) for row in rows]
if len(rows) <= 1:
return _LoadBalancingModelConfigInheritGroupPlan(group_row_ids=group_row_ids, winner=None, loser_rows=[])
winner = self._select_winner(rows)
return _LoadBalancingModelConfigInheritGroupPlan(
group_row_ids=group_row_ids,
winner=winner,
loser_rows=[row for row in rows if row.row.id != winner.row.id],
)
def _emit_load_balancing_inherit_group_plan(
self,
plan: _LoadBalancingModelConfigInheritGroupPlan,
*,
session: Session,
tx_id: str,
business_key: _LoadBalancingModelConfigInheritBusinessKey,
) -> None:
if plan.winner is None:
return
cache_plans: list[_CacheDeletePlan] = []
for loser in plan.loser_rows:
loser_row_id = str(loser.row.id)
if self._apply:
session.execute(sa.delete(LoadBalancingModelConfig).where(LoadBalancingModelConfig.id == loser_row_id))
self._log_row_deleted(
LoadBalancingModelConfig.__tablename__,
loser,
tx_id=tx_id,
business_key=business_key,
related_winner_id=str(plan.winner.row.id),
)
cache_plans.append(
_CacheDeletePlan(
tenant_id=self._tenant_id,
identity_id=loser_row_id,
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
table_name=LoadBalancingModelConfig.__tablename__,
row_id=loser_row_id,
tx_id=tx_id,
business_key=business_key,
)
)
self._log_cache_plans(cache_plans, apply=self._apply)
self._log_group_processed(
LoadBalancingModelConfig.__tablename__,
business_key,
plan.group_row_ids,
tx_id=tx_id,
)
def _process_load_balancing_inherit_group(
self,
candidate: _RowWithRawModelType[LoadBalancingModelConfig],
business_key: _LoadBalancingModelConfigInheritBusinessKey,
) -> list[str]:
tx_id = self._new_tx_id()
group_row_ids = [str(candidate.row.id)]
try:
with _session_factory(self._engine) as session, session.begin():
self._configure_lock_timeout(session)
plan = self._build_load_balancing_inherit_group_plan(session, candidate, lock_rows=True)
group_row_ids = plan.group_row_ids or group_row_ids
self._emit_load_balancing_inherit_group_plan(
plan,
session=session,
tx_id=tx_id,
business_key=business_key,
)
except OperationalError as exc:
if self._is_lock_timeout_error(exc):
self._log_lock_timeout(
LoadBalancingModelConfig.__tablename__,
str(candidate.row.id),
tx_id,
business_key,
exc,
)
return group_row_ids
raise
return group_row_ids
def _load_load_balancing_model_config_candidates(
self, last_id: str | None
) -> list[_RowWithRawModelType[LoadBalancingModelConfig]]:
@ -1344,10 +1710,11 @@ class Migration:
).delete()
self._log_event("cache_deleted", "Deleted related cache entry.", attrs)
except Exception as exc:
self._log_event(
self._log_exception_event(
"cache_delete_failed",
"Failed to delete related cache entry.",
{**attrs, "error": str(exc)},
attrs,
exc,
)
def _process_load_balancing_model_config_row(
@ -1534,17 +1901,10 @@ class Migration:
loser_rows=[],
provider_model_rewrites=[],
load_balancing_rewrites=[],
load_balancing_deletions=[],
)
winner = self._select_winner(rows)
loser_rows = [row for row in rows if row.row.id != winner.row.id]
load_balancing_rewrites, load_balancing_deletions = self._plan_load_balancing_reference_rewrites(
session,
winner,
loser_rows,
lock_rows=lock_rows,
)
return _ProviderModelCredentialGroupPlan(
group_row_ids=group_row_ids,
winner=winner,
@ -1555,8 +1915,12 @@ class Migration:
loser_rows,
lock_rows=lock_rows,
),
load_balancing_rewrites=load_balancing_rewrites,
load_balancing_deletions=load_balancing_deletions,
load_balancing_rewrites=self._plan_load_balancing_reference_rewrites(
session,
winner,
loser_rows,
lock_rows=lock_rows,
),
)
def _emit_provider_model_reference_rewrites(
@ -1661,52 +2025,6 @@ class Migration:
)
return cache_plans
def _emit_load_balancing_reference_deletions(
self,
session: Session,
deletions: Sequence[_LoadBalancingCredentialDeletePlan],
*,
winner_credential_id: str,
loser_credential_ids: Sequence[str],
tx_id: str,
business_key: _BusinessKey,
) -> list[_CacheDeletePlan]:
cache_plans: list[_CacheDeletePlan] = []
for deletion in deletions:
if self._apply:
session.execute(
sa.delete(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.id == deletion.row_id
)
)
self._log_event(
"row_deleted",
"Deleted duplicate load_balancing_model_config row (credential reference dedup).",
{
"table_name": LoadBalancingModelConfig.__tablename__,
"id": deletion.row_id,
"old_credential_id": deletion.old_credential_id,
"winner_credential_id": winner_credential_id,
"apply": self._apply,
"tx_id": tx_id,
"business_key": business_key,
"rewrite_kind": "credential_reference_dedup",
"loser_credential_ids": list(loser_credential_ids),
},
)
cache_plans.append(
_CacheDeletePlan(
tenant_id=self._tenant_id,
identity_id=deletion.row_id,
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
table_name=LoadBalancingModelConfig.__tablename__,
row_id=deletion.row_id,
tx_id=tx_id,
business_key=business_key,
)
)
return cache_plans
def _emit_provider_model_credential_group_plan(
self,
plan: _ProviderModelCredentialGroupPlan,
@ -1741,16 +2059,6 @@ class Migration:
business_key=business_key,
)
)
cache_plans.extend(
self._emit_load_balancing_reference_deletions(
session,
plan.load_balancing_deletions,
winner_credential_id=winner_credential_id,
loser_credential_ids=loser_credential_ids,
tx_id=tx_id,
business_key=business_key,
)
)
for loser in plan.loser_rows:
if self._apply:
@ -1864,16 +2172,12 @@ class Migration:
loser_rows: Sequence[_RowWithRawModelType[ProviderModelCredential]],
*,
lock_rows: bool,
) -> tuple[list[_LoadBalancingCredentialRewritePlan], list[_LoadBalancingCredentialDeletePlan]]:
) -> list[_LoadBalancingCredentialRewritePlan]:
loser_ids = [str(row.row.id) for row in loser_rows]
if not loser_ids:
return [], []
return []
winner_credential_id = str(winner.row.id)
winner_credential_name = winner.row.credential_name
winner_encrypted_config = winner.row.encrypted_config
stmt_loser = (
stmt = (
select(LoadBalancingModelConfig)
.where(
LoadBalancingModelConfig.tenant_id == self._tenant_id,
@ -1882,52 +2186,28 @@ class Migration:
.order_by(LoadBalancingModelConfig.id.asc())
)
if lock_rows:
stmt_loser = stmt_loser.with_for_update()
loser_lb_rows = session.execute(stmt_loser).scalars().all()
stmt = stmt.with_for_update()
if not loser_lb_rows:
return [], []
stmt_winner = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self._tenant_id,
LoadBalancingModelConfig.credential_id == winner_credential_id,
)
if lock_rows:
stmt_winner = stmt_winner.with_for_update()
winner_lb_rows = session.execute(stmt_winner).scalars().all()
winner_keys: set[tuple[str, str, str]] = {
(str(r.provider_name), str(r.model_name), str(r.model_type))
for r in winner_lb_rows
}
winner_credential = winner.row
winner_credential_id = str(winner_credential.id)
winner_credential_name = winner_credential.credential_name
winner_encrypted_config = winner_credential.encrypted_config
rewrite_plans: list[_LoadBalancingCredentialRewritePlan] = []
delete_plans: list[_LoadBalancingCredentialDeletePlan] = []
for lb_row in loser_lb_rows:
key = (str(lb_row.provider_name), str(lb_row.model_name), str(lb_row.model_type))
if key in winner_keys:
delete_plans.append(
_LoadBalancingCredentialDeletePlan(
row_id=str(lb_row.id),
old_credential_id=lb_row.credential_id,
winner_credential_id=winner_credential_id,
)
load_balancing_model_configs = session.execute(stmt).scalars().all()
for load_balancing_model_config in load_balancing_model_configs:
rewrite_plans.append(
_LoadBalancingCredentialRewritePlan(
row_id=str(load_balancing_model_config.id),
old_credential_id=load_balancing_model_config.credential_id,
old_name=load_balancing_model_config.name,
old_encrypted_config=load_balancing_model_config.encrypted_config,
new_credential_id=winner_credential_id,
new_name=winner_credential_name,
new_encrypted_config=winner_encrypted_config,
)
else:
rewrite_plans.append(
_LoadBalancingCredentialRewritePlan(
row_id=str(lb_row.id),
old_credential_id=lb_row.credential_id,
old_name=lb_row.name,
old_encrypted_config=lb_row.encrypted_config,
new_credential_id=winner_credential_id,
new_name=winner_credential_name,
new_encrypted_config=winner_encrypted_config,
)
)
return rewrite_plans, delete_plans
)
return rewrite_plans
def _configure_lock_timeout(self, session: Session) -> None:
dialect_name = session.get_bind().dialect.name
@ -1997,11 +2277,15 @@ class Migration:
"table_name": table_name,
"id": row_id,
"tx_id": tx_id,
"error": str(exc),
}
if business_key is not None:
attrs["business_key"] = self._business_key_to_dict(business_key)
self._log_event("lock_timeout_skipped", "Skipped transaction because row lock timed out.", attrs)
self._log_exception_event(
"lock_timeout_skipped",
"Skipped transaction because row lock timed out.",
attrs,
exc,
)
def _business_key_to_dict(self, business_key: _BusinessKey) -> dict[str, object]:
return cast(dict[str, object], asdict(business_key))
@ -2107,7 +2391,7 @@ class Migration:
},
)
except Exception as exc:
self._log_event(
self._log_exception_event(
"cache_delete_failed",
"Failed to delete related cache entry.",
{
@ -2118,8 +2402,8 @@ class Migration:
"cache_type": cache_plan.cache_type.value,
"tx_id": cache_plan.tx_id,
"business_key": self._business_key_to_dict(cache_plan.business_key),
"error": str(exc),
},
exc,
)
else:
self._log_event(
@ -2136,6 +2420,23 @@ class Migration:
},
)
def _log_exception_event(
self,
event: str,
message: str,
attrs: dict[str, object],
exc: BaseException,
) -> None:
self._log_event(
event,
message,
{
**attrs,
"error": str(exc),
"stacktrace": _format_exception_stacktrace(exc),
},
)
def _log_event(self, event: str, message: str, attrs: dict[str, object]) -> None:
record = {
"event": event,

View File

@ -44,7 +44,6 @@ class DirtyTenantFixture:
distinct_credential_id: str
provider_model_id: str
load_balancing_config_id: str
winner_load_balancing_config_id: str
provider_model_setting_id: str
tenant_default_model_id: str
embedding_provider_model_id: str
@ -318,7 +317,6 @@ def _seed_tenant(engine: Engine, *, tenant_id: str, provider_name: str) -> Dirty
"updated_at": now - timedelta(hours=3),
},
)
winner_load_balancing_config_id = str(uuid4())
conn.execute(
sa.text(
"""
@ -333,30 +331,20 @@ def _seed_tenant(engine: Engine, *, tenant_id: str, provider_name: str) -> Dirty
:load_balancing_config_id, :tenant_id, :provider_name, 'gpt-4o-mini', 'text-generation',
:lb_name, :loser_config, :loser_id, :credential_source_type,
:enabled, :created_at, :updated_at
),
(
:lb_winner_id, :tenant_id, :provider_name, 'gpt-4o-mini', 'text-generation',
:winner_name, :winner_config, :winner_cred_id, :credential_source_type,
:enabled, :created_at, :winner_updated_at
)
"""
),
{
"load_balancing_config_id": load_balancing_config_id,
"lb_winner_id": winner_load_balancing_config_id,
"tenant_id": tenant_id,
"provider_name": provider_name,
"lb_name": loser_credential_name,
"loser_config": loser_encrypted_config,
"loser_id": loser_credential_id,
"winner_name": f"{tenant_id}-winner-lb",
"winner_config": winner_encrypted_config,
"winner_cred_id": winner_credential_id,
"credential_source_type": CredentialSourceType.CUSTOM_MODEL.value,
"enabled": True,
"created_at": now - timedelta(days=2),
"updated_at": now - timedelta(hours=2),
"winner_updated_at": now - timedelta(hours=1),
},
)
@ -367,7 +355,6 @@ def _seed_tenant(engine: Engine, *, tenant_id: str, provider_name: str) -> Dirty
distinct_credential_id=distinct_credential_id,
provider_model_id=provider_model_id,
load_balancing_config_id=load_balancing_config_id,
winner_load_balancing_config_id=winner_load_balancing_config_id,
provider_model_setting_id=provider_model_setting_id,
tenant_default_model_id=tenant_default_model_id,
embedding_provider_model_id=embedding_provider_model_id,

View File

@ -2,7 +2,9 @@ from __future__ import annotations
import importlib
import io
import json
from collections.abc import Generator
from datetime import datetime, timedelta
import pytest
import sqlalchemy as sa
@ -15,6 +17,100 @@ from tests.helpers.legacy_model_type_migration import (
)
def _parse_json_lines(output: io.StringIO) -> list[dict[str, object]]:
return [json.loads(line) for line in output.getvalue().splitlines() if line.strip()]
def _json_key(value: object) -> str:
return json.dumps(value, sort_keys=True)
def _lb_processing_signatures(lines: list[dict[str, object]]) -> set[tuple[object, ...]]:
signatures: set[tuple[object, ...]] = set()
for line in lines:
attrs = line.get("attrs")
if not isinstance(attrs, dict):
continue
if attrs.get("table_name") != "load_balancing_model_configs":
continue
event = line.get("event")
if event == "row_updated":
signatures.add(
(
event,
attrs.get("id"),
_json_key(attrs.get("old_values")),
_json_key(attrs.get("new_values")),
)
)
elif event == "row_deleted":
signatures.add(
(
event,
attrs.get("id"),
attrs.get("merge_winner_id"),
)
)
elif event == "group_processed":
signatures.add(
(
event,
attrs.get("table_name"),
_json_key(attrs.get("business_key")),
tuple(attrs.get("group_row_ids", [])),
)
)
return signatures
def _insert_load_balancing_model_config(
engine: sa.Engine,
*,
row_id: str,
tenant_id: str,
provider_name: str,
model_name: str,
model_type: str,
name: str,
encrypted_config: str,
credential_id: str,
enabled: bool,
created_at: datetime,
updated_at: datetime,
) -> None:
with engine.begin() as conn:
conn.execute(
sa.text(
"""
INSERT INTO load_balancing_model_configs
(
id, tenant_id, provider_name, model_name, model_type, name,
encrypted_config, credential_id, credential_source_type, enabled, created_at, updated_at
)
VALUES
(
:id, :tenant_id, :provider_name, :model_name, :model_type, :name,
:encrypted_config, :credential_id, :credential_source_type, :enabled, :created_at, :updated_at
)
"""
),
{
"id": row_id,
"tenant_id": tenant_id,
"provider_name": provider_name,
"model_name": model_name,
"model_type": model_type,
"name": name,
"encrypted_config": encrypted_config,
"credential_id": credential_id,
"credential_source_type": "custom_model",
"enabled": enabled,
"created_at": created_at,
"updated_at": updated_at,
},
)
@pytest.fixture(scope="session")
def migration_module():
try:
@ -125,3 +221,188 @@ def test_legacy_model_type_migration_end_to_end_across_supported_backends(
for table_name in first_apply_state
}
assert second_apply_state == first_apply_state
def test_load_balancing_inherit_deduplication_is_applied_consistently_across_supported_backends(
migration_module,
container_engine: tuple[str, sa.Engine],
monkeypatch: pytest.MonkeyPatch,
) -> None:
_, engine = container_engine
helper_module = importlib.import_module("tests.helpers.legacy_model_type_migration")
helper_module.drop_minimal_legacy_model_type_schema(engine)
fixture = seed_legacy_model_type_dirty_data(engine)
tenant_id = fixture.primary.tenant_id
older_inherit_row_id = "00000000-0000-0000-0000-00000000ee01"
newer_inherit_row_id = "00000000-0000-0000-0000-00000000ee02"
canonical_non_inherit_row_id = "00000000-0000-0000-0000-00000000ee03"
created_at = datetime(2025, 1, 1, 8, 0, 0)
_insert_load_balancing_model_config(
engine,
row_id=older_inherit_row_id,
tenant_id=tenant_id,
provider_name="openai",
model_name="gpt-4o-mini",
model_type="llm",
name="__inherit__",
encrypted_config='{"api_key":"older-inherit"}',
credential_id=fixture.primary.winner_credential_id,
enabled=True,
created_at=created_at,
updated_at=created_at + timedelta(minutes=15),
)
_insert_load_balancing_model_config(
engine,
row_id=newer_inherit_row_id,
tenant_id=tenant_id,
provider_name="openai",
model_name="gpt-4o-mini",
model_type="text-generation",
name="__inherit__",
encrypted_config='{"api_key":"newer-inherit"}',
credential_id=fixture.primary.distinct_credential_id,
enabled=True,
created_at=created_at,
updated_at=created_at + timedelta(minutes=30),
)
_insert_load_balancing_model_config(
engine,
row_id=canonical_non_inherit_row_id,
tenant_id=tenant_id,
provider_name="openai",
model_name="gpt-4o-mini",
model_type="llm",
name=f"{tenant_id}-second-shared",
encrypted_config='{"api_key":"non-inherit-canonical"}',
credential_id=fixture.primary.distinct_credential_id,
enabled=True,
created_at=created_at,
updated_at=created_at + timedelta(minutes=45),
)
before_dry_run = fetch_table_rows(engine, "load_balancing_model_configs", tenant_id=tenant_id)
deleted_cache_keys: list[str] = []
def _record_delete(self) -> None:
deleted_cache_keys.append(self.cache_key)
monkeypatch.setattr(migration_module.ProviderCredentialsCache, "delete", _record_delete)
dry_run_output = io.StringIO()
migration_module.LegacyModelTypeMigrationService(
engine=engine,
apply=False,
output=dry_run_output,
tables=("load_balancing_model_configs",),
model_types=(migration_module.ModelType.LLM,),
tenant_ids=(tenant_id,),
).migrate()
after_dry_run = fetch_table_rows(engine, "load_balancing_model_configs", tenant_id=tenant_id)
dry_run_lines = _parse_json_lines(dry_run_output)
dry_run_cache_events = [line["event"] for line in dry_run_lines if str(line.get("event")).startswith("cache_")]
dry_run_row_updates = {
str(attrs["id"])
for line in dry_run_lines
if line.get("event") == "row_updated"
and isinstance((attrs := line.get("attrs")), dict)
and attrs.get("table_name") == "load_balancing_model_configs"
}
dry_run_row_deletes = {
str(attrs["id"])
for line in dry_run_lines
if line.get("event") == "row_deleted"
and isinstance((attrs := line.get("attrs")), dict)
and attrs.get("table_name") == "load_balancing_model_configs"
}
dry_run_group_processed = [
attrs
for line in dry_run_lines
if line.get("event") == "group_processed"
and isinstance((attrs := line.get("attrs")), dict)
and attrs.get("table_name") == "load_balancing_model_configs"
]
assert after_dry_run == before_dry_run
assert deleted_cache_keys == []
assert dry_run_row_deletes == {older_inherit_row_id}
assert dry_run_row_updates == {
fixture.primary.load_balancing_config_id,
newer_inherit_row_id,
}
assert canonical_non_inherit_row_id not in dry_run_row_updates
assert "cache_delete_planned" in dry_run_cache_events
assert "cache_deleted" not in dry_run_cache_events
assert len(dry_run_group_processed) == 1
assert dry_run_group_processed[0]["table_name"] == "load_balancing_model_configs"
assert dry_run_group_processed[0]["business_key"] == {
"tenant_id": tenant_id,
"provider_name": "openai",
"model_name": "gpt-4o-mini",
"model_type": "llm",
}
assert set(dry_run_group_processed[0]["group_row_ids"]) == {
older_inherit_row_id,
newer_inherit_row_id,
}
apply_output = io.StringIO()
migration_module.LegacyModelTypeMigrationService(
engine=engine,
apply=True,
output=apply_output,
tables=("load_balancing_model_configs",),
model_types=(migration_module.ModelType.LLM,),
tenant_ids=(tenant_id,),
).migrate()
apply_lines = _parse_json_lines(apply_output)
apply_cache_events = [line["event"] for line in apply_lines if str(line.get("event")).startswith("cache_")]
apply_group_processed = [
attrs
for line in apply_lines
if line.get("event") == "group_processed"
and isinstance((attrs := line.get("attrs")), dict)
and attrs.get("table_name") == "load_balancing_model_configs"
]
assert _lb_processing_signatures(apply_lines) == _lb_processing_signatures(dry_run_lines)
assert "cache_deleted" in apply_cache_events
assert deleted_cache_keys
assert len(apply_group_processed) == len(dry_run_group_processed)
assert [
(
attrs["table_name"],
_json_key(attrs["business_key"]),
tuple(attrs["group_row_ids"]),
)
for attrs in apply_group_processed
] == [
(
attrs["table_name"],
_json_key(attrs["business_key"]),
tuple(attrs["group_row_ids"]),
)
for attrs in dry_run_group_processed
]
lb_rows = fetch_table_rows(engine, "load_balancing_model_configs", tenant_id=tenant_id)
surviving_inherit_rows = [row for row in lb_rows if row["name"] == "__inherit__"]
surviving_non_inherit_rows = [row for row in lb_rows if row["name"] != "__inherit__"]
assert {str(row["id"]) for row in surviving_inherit_rows} == {newer_inherit_row_id}
assert surviving_inherit_rows[0]["model_type"] == "llm"
assert surviving_inherit_rows[0]["credential_id"] == fixture.primary.distinct_credential_id
assert {
str(row["id"])
for row in surviving_non_inherit_rows
if str(row["id"]) in {fixture.primary.load_balancing_config_id, canonical_non_inherit_row_id}
} == {fixture.primary.load_balancing_config_id, canonical_non_inherit_row_id}
assert all(
row["model_type"] == "llm"
for row in surviving_non_inherit_rows
if str(row["id"]) in {fixture.primary.load_balancing_config_id, canonical_non_inherit_row_id}
)
assert count_rows(engine, "load_balancing_model_configs", tenant_id=tenant_id) == len(before_dry_run) - 1

View File

@ -2,12 +2,12 @@ from collections.abc import Iterator
import pytest
from dify_agent.client import DifyAgentHTTPError, DifyAgentStreamError, DifyAgentTimeoutError, DifyAgentValidationError
from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig
from dify_agent.protocol import (
CancelRunRequest,
CancelRunResponse,
CreateRunRequest,
CreateRunResponse,
ExecutionContext,
RunEvent,
RunStartedEvent,
RunStatusResponse,
@ -29,11 +29,12 @@ def _request():
return AgentBackendRunRequestBuilder().build_for_workflow_node(
AgentBackendWorkflowNodeRunInput(
model=AgentBackendModelConfig(
tenant_id="tenant-1",
plugin_id="langgenius/openai",
model_provider="openai",
model="gpt-test",
),
execution_context=DifyExecutionContextLayerConfig(tenant_id="tenant-1", invoke_from="workflow_run"),
execution_context=ExecutionContext(tenant_id="tenant-1", invoke_from="workflow_run"),
workflow_node_job_prompt="Do the task.",
user_prompt="hello",
)

View File

@ -1,4 +1,4 @@
from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig
from dify_agent.protocol import ExecutionContext
from clients.agent_backend import (
AgentBackendModelConfig,
@ -13,11 +13,12 @@ def _request():
return AgentBackendRunRequestBuilder().build_for_workflow_node(
AgentBackendWorkflowNodeRunInput(
model=AgentBackendModelConfig(
tenant_id="tenant-1",
plugin_id="langgenius/openai",
model_provider="openai",
model="gpt-test",
),
execution_context=DifyExecutionContextLayerConfig(tenant_id="tenant-1", invoke_from="workflow_run"),
execution_context=ExecutionContext(tenant_id="tenant-1", invoke_from="workflow_run"),
workflow_node_job_prompt="Do the task.",
user_prompt="hello",
)

View File

@ -1,19 +1,18 @@
import pytest
from agenton.layers import ExitIntent
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID
from dify_agent.layers.dify_plugin import DIFY_PLUGIN_LLM_LAYER_TYPE_ID
from dify_agent.layers.execution_context import DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, DifyExecutionContextLayerConfig
from dify_agent.layers.dify_plugin import DIFY_PLUGIN_LAYER_TYPE_ID, DIFY_PLUGIN_LLM_LAYER_TYPE_ID
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID
from dify_agent.protocol import (
DIFY_AGENT_MODEL_LAYER_ID,
DIFY_AGENT_OUTPUT_LAYER_ID,
CreateRunRequest,
ExecutionContext,
)
from pydantic import ValidationError
from clients.agent_backend import (
AGENT_SOUL_PROMPT_LAYER_ID,
DIFY_EXECUTION_CONTEXT_LAYER_ID,
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
WORKFLOW_USER_PROMPT_LAYER_ID,
AgentBackendModelConfig,
@ -27,14 +26,15 @@ from clients.agent_backend import (
def _run_input() -> AgentBackendWorkflowNodeRunInput:
return AgentBackendWorkflowNodeRunInput(
model=AgentBackendModelConfig(
tenant_id="tenant-1",
plugin_id="langgenius/openai",
user_id="user-1",
model_provider="openai",
model="gpt-test",
credentials={"api_key": "secret-key"},
),
execution_context=DifyExecutionContextLayerConfig(
execution_context=ExecutionContext(
tenant_id="tenant-1",
user_id="user-1",
workflow_id="workflow-1",
workflow_run_id="workflow-run-1",
node_id="node-1",
@ -64,11 +64,13 @@ def test_request_builder_outputs_dify_agent_create_run_request():
AGENT_SOUL_PROMPT_LAYER_ID,
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
WORKFLOW_USER_PROMPT_LAYER_ID,
DIFY_EXECUTION_CONTEXT_LAYER_ID,
"plugin",
DIFY_AGENT_MODEL_LAYER_ID,
DIFY_AGENT_OUTPUT_LAYER_ID,
]
assert request.on_exit.default is ExitIntent.DELETE
assert request.execution_context is not None
assert request.execution_context.node_execution_id == "node-execution-1"
assert request.idempotency_key == "workflow-run-1:node-execution-1"
assert request.metadata == {"workflow_id": "workflow-1", "node_id": "node-1"}
@ -92,11 +94,9 @@ def test_request_builder_sets_model_and_output_layer_contract_ids():
request = AgentBackendRunRequestBuilder().build_for_workflow_node(_run_input())
layers = {layer.name: layer for layer in request.composition.layers}
assert layers[DIFY_EXECUTION_CONTEXT_LAYER_ID].type == DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID
assert layers[DIFY_EXECUTION_CONTEXT_LAYER_ID].config.user_id == "user-1"
assert layers["plugin"].type == DIFY_PLUGIN_LAYER_TYPE_ID
assert layers[DIFY_AGENT_MODEL_LAYER_ID].type == DIFY_PLUGIN_LLM_LAYER_TYPE_ID
assert layers[DIFY_AGENT_MODEL_LAYER_ID].config.plugin_id == "langgenius/openai"
assert layers[DIFY_AGENT_MODEL_LAYER_ID].deps == {"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID}
assert layers[DIFY_AGENT_MODEL_LAYER_ID].deps == {"plugin": "plugin"}
assert layers[DIFY_AGENT_OUTPUT_LAYER_ID].type == DIFY_OUTPUT_LAYER_TYPE_ID
@ -113,11 +113,12 @@ def test_request_builder_rejects_blank_prompts():
with pytest.raises(ValidationError):
AgentBackendWorkflowNodeRunInput(
model=AgentBackendModelConfig(
tenant_id="tenant-1",
plugin_id="langgenius/openai",
model_provider="openai",
model="gpt-test",
),
execution_context=DifyExecutionContextLayerConfig(tenant_id="tenant-1", invoke_from="workflow_run"),
execution_context=ExecutionContext(tenant_id="tenant-1", invoke_from="workflow_run"),
workflow_node_job_prompt=" ",
user_prompt="hello",
)

View File

@ -17,6 +17,7 @@ from click.testing import CliRunner
from sqlalchemy.exc import OperationalError
from graphon.model_runtime.entities.model_entities import ModelType
from models.account import Tenant
from models.enums import CredentialSourceType
from models.provider import ProviderModel
from tests.helpers.legacy_model_type_migration import (
@ -24,6 +25,7 @@ from tests.helpers.legacy_model_type_migration import (
LEGACY_TO_CANONICAL,
assert_tenant_rows_use_only_canonical_model_types,
count_rows,
create_minimal_legacy_model_type_schema,
fetch_table_rows,
seed_legacy_model_type_dirty_data,
snapshot_legacy_model_type_state,
@ -117,6 +119,28 @@ def _collect_processing_signatures(lines: list[dict[str, object]]) -> set[tuple[
return signatures
def _cache_event_row_ids(
lines: list[dict[str, object]],
*,
table_name: str,
row_ids: set[str],
event_name: str,
) -> set[str]:
matching_row_ids: set[str] = set()
for line in lines:
if line.get("event") != event_name:
continue
attrs = line.get("attrs")
if not isinstance(attrs, dict):
continue
if attrs.get("table_name") != table_name:
continue
row_id = str(attrs.get("id"))
if row_id in row_ids:
matching_row_ids.add(row_id)
return matching_row_ids
def _patch_batch_size(
monkeypatch: pytest.MonkeyPatch,
migration_module,
@ -174,6 +198,18 @@ def _insert_provider_model(
)
def _insert_tenant(engine: sa.Engine, *, tenant_id: str) -> None:
with engine.begin() as conn:
conn.execute(
Tenant.__table__.insert().values(
id=tenant_id,
name=f"Tenant {tenant_id}",
plan="basic",
status="normal",
)
)
def _insert_tenant_default_model(
engine: sa.Engine,
*,
@ -487,7 +523,7 @@ def test_service_migrate_batches_by_tenant_respects_selected_tables_without_reve
migration_module,
sqlite_engine: sa.Engine,
) -> None:
seen_runs: list[dict[str, object]] = []
seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = []
class FakeMigration:
def __init__(
@ -500,18 +536,12 @@ def test_service_migrate_batches_by_tenant_respects_selected_tables_without_reve
model_types: tuple[ModelType, ...],
orm_models: tuple[type[object], ...],
) -> None:
seen_runs.append(
{
"tenant_id": tenant_id,
"engine": engine,
"apply": apply,
"model_types": model_types,
"table_names": tuple(model.__table__.name for model in orm_models),
}
)
assert engine is sqlite_engine
assert apply is False
seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types))
def run(self) -> None:
seen_runs.append({"run": True})
return None
monkeypatch = pytest.MonkeyPatch()
try:
@ -520,7 +550,7 @@ def test_service_migrate_batches_by_tenant_respects_selected_tables_without_reve
engine=sqlite_engine,
apply=False,
concurrency=1,
tables=("provider_models",),
tables=("provider_models", "tenant_default_models"),
model_types=(ModelType.LLM,),
tenant_ids=("tenant-alpha", "tenant-beta"),
)
@ -529,11 +559,267 @@ def test_service_migrate_batches_by_tenant_respects_selected_tables_without_reve
finally:
monkeypatch.undo()
init_calls = [call for call in seen_runs if "tenant_id" in call]
assert [call["tenant_id"] for call in init_calls] == ["tenant-alpha", "tenant-beta"]
for call in init_calls:
assert tuple(cast(tuple[str, ...], call["table_names"])) == ("provider_models",)
assert call["model_types"] == (ModelType.LLM,)
assert seen_runs == [
("tenant-alpha", ("provider_models", "tenant_default_models"), (ModelType.LLM,)),
("tenant-beta", ("provider_models", "tenant_default_models"), (ModelType.LLM,)),
]
def test_service_migrate_without_tenant_ids_discovers_tenants_per_selected_table_without_querying_tenants(
migration_module,
sqlite_engine: sa.Engine,
monkeypatch: pytest.MonkeyPatch,
) -> None:
create_minimal_legacy_model_type_schema(sqlite_engine)
provider_tenant_id = "00000000-0000-0000-0000-000000000111"
default_tenant_id = "00000000-0000-0000-0000-000000000222"
empty_tenant_id = "00000000-0000-0000-0000-000000000333"
for tenant_id in (provider_tenant_id, default_tenant_id, empty_tenant_id):
_insert_tenant(sqlite_engine, tenant_id=tenant_id)
created_at = datetime(2025, 1, 1, 12, 0, 0)
updated_at = created_at + timedelta(minutes=1)
_insert_provider_model(
sqlite_engine,
row_id="10000000-0000-0000-0000-000000000111",
tenant_id=provider_tenant_id,
provider_name="openai",
model_name="gpt-4o-mini",
model_type="text-generation",
credential_id=None,
created_at=created_at,
updated_at=updated_at,
)
_insert_tenant_default_model(
sqlite_engine,
row_id="20000000-0000-0000-0000-000000000222",
tenant_id=default_tenant_id,
provider_name="openai",
model_name="gpt-4o-mini",
model_type="text-generation",
created_at=created_at,
updated_at=updated_at,
)
seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = []
executed_sql: list[str] = []
class FakeMigration:
def __init__(
self,
*,
tenant_id: str,
engine: sa.Engine,
apply: bool,
output: io.TextIOBase,
model_types: tuple[ModelType, ...],
orm_models: tuple[type[object], ...],
) -> None:
assert engine is sqlite_engine
assert apply is False
seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types))
def run(self) -> None:
return None
def _record_sql(
conn: sa.engine.Connection,
cursor: object,
statement: str,
parameters: object,
context: object,
executemany: bool,
) -> None:
del conn, cursor, parameters, context, executemany
executed_sql.append(statement)
sa.event.listen(sqlite_engine, "before_cursor_execute", _record_sql)
try:
monkeypatch.setattr(migration_module, "Migration", FakeMigration)
service = migration_module.LegacyModelTypeMigrationService(
engine=sqlite_engine,
apply=False,
tables=("provider_models", "tenant_default_models"),
model_types=(ModelType.LLM,),
)
service.migrate()
finally:
sa.event.remove(sqlite_engine, "before_cursor_execute", _record_sql)
assert seen_runs == [
(provider_tenant_id, ("provider_models",), (ModelType.LLM,)),
(default_tenant_id, ("tenant_default_models",), (ModelType.LLM,)),
]
normalized_statements = [" ".join(statement.lower().split()) for statement in executed_sql]
discovery_statements = [statement for statement in normalized_statements if statement.startswith("select")]
table_names = ("provider_models", "tenant_default_models")
table_discovery_statements = [
statement
for statement in discovery_statements
if any(f" from {table_name} " in f" {statement} " for table_name in table_names)
]
assert [statement for statement in discovery_statements if " from tenants " in f" {statement} "] == []
assert [statement for statement in discovery_statements if " union " in f" {statement} "] == []
assert [
next(table_name for table_name in table_names if f" from {table_name} " in f" {statement} ")
for statement in table_discovery_statements
] == list(table_names)
def test_service_migrate_without_tenant_ids_filters_provider_model_tenants_by_selected_model_types(
migration_module,
sqlite_engine: sa.Engine,
monkeypatch: pytest.MonkeyPatch,
) -> None:
create_minimal_legacy_model_type_schema(sqlite_engine)
llm_tenant_id = "00000000-0000-0000-0000-000000000411"
embedding_tenant_id = "00000000-0000-0000-0000-000000000422"
empty_tenant_id = "00000000-0000-0000-0000-000000000433"
for tenant_id in (llm_tenant_id, embedding_tenant_id, empty_tenant_id):
_insert_tenant(sqlite_engine, tenant_id=tenant_id)
created_at = datetime(2025, 1, 2, 12, 0, 0)
updated_at = created_at + timedelta(minutes=1)
_insert_provider_model(
sqlite_engine,
row_id="30000000-0000-0000-0000-000000000411",
tenant_id=llm_tenant_id,
provider_name="openai",
model_name="gpt-4o-mini",
model_type="text-generation",
credential_id=None,
created_at=created_at,
updated_at=updated_at,
)
_insert_provider_model(
sqlite_engine,
row_id="30000000-0000-0000-0000-000000000422",
tenant_id=embedding_tenant_id,
provider_name="openai",
model_name="text-embedding-3-large",
model_type="embeddings",
credential_id=None,
created_at=created_at,
updated_at=updated_at,
)
seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = []
class FakeMigration:
def __init__(
self,
*,
tenant_id: str,
engine: sa.Engine,
apply: bool,
output: io.TextIOBase,
model_types: tuple[ModelType, ...],
orm_models: tuple[type[object], ...],
) -> None:
assert engine is sqlite_engine
assert apply is False
seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types))
def run(self) -> None:
return None
monkeypatch.setattr(migration_module, "Migration", FakeMigration)
service = migration_module.LegacyModelTypeMigrationService(
engine=sqlite_engine,
apply=False,
tables=("provider_models",),
model_types=(ModelType.LLM,),
)
service.migrate()
assert seen_runs == [
(llm_tenant_id, ("provider_models",), (ModelType.LLM,)),
]
def test_service_migrate_without_tenant_ids_discovers_all_load_balancing_tenants_for_simpler_table_scoped_query(
migration_module,
sqlite_engine: sa.Engine,
monkeypatch: pytest.MonkeyPatch,
) -> None:
create_minimal_legacy_model_type_schema(sqlite_engine)
inherit_llm_tenant_id = "00000000-0000-0000-0000-000000000511"
inherit_embedding_tenant_id = "00000000-0000-0000-0000-000000000522"
empty_tenant_id = "00000000-0000-0000-0000-000000000533"
for tenant_id in (inherit_llm_tenant_id, inherit_embedding_tenant_id, empty_tenant_id):
_insert_tenant(sqlite_engine, tenant_id=tenant_id)
created_at = datetime(2025, 1, 3, 12, 0, 0)
updated_at = created_at + timedelta(minutes=1)
_insert_load_balancing_model_config(
sqlite_engine,
row_id="40000000-0000-0000-0000-000000000511",
tenant_id=inherit_llm_tenant_id,
provider_name="openai",
model_name="gpt-4o-mini",
model_type=ModelType.LLM.value,
name="__inherit__",
encrypted_config=json.dumps({"api_key": "inherit-llm"}),
credential_id="50000000-0000-0000-0000-000000000511",
enabled=True,
created_at=created_at,
updated_at=updated_at,
)
_insert_load_balancing_model_config(
sqlite_engine,
row_id="40000000-0000-0000-0000-000000000522",
tenant_id=inherit_embedding_tenant_id,
provider_name="openai",
model_name="text-embedding-3-large",
model_type=ModelType.TEXT_EMBEDDING.value,
name="__inherit__",
encrypted_config=json.dumps({"api_key": "inherit-embedding"}),
credential_id="50000000-0000-0000-0000-000000000522",
enabled=True,
created_at=created_at,
updated_at=updated_at,
)
seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = []
class FakeMigration:
def __init__(
self,
*,
tenant_id: str,
engine: sa.Engine,
apply: bool,
output: io.TextIOBase,
model_types: tuple[ModelType, ...],
orm_models: tuple[type[object], ...],
) -> None:
assert engine is sqlite_engine
assert apply is False
seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types))
def run(self) -> None:
return None
monkeypatch.setattr(migration_module, "Migration", FakeMigration)
# Load-balancing tenant discovery is a deliberate exception: it scans the
# whole table so the discovery query stays easy to understand, even when
# the scheduled tenant set is wider than the selected model types.
service = migration_module.LegacyModelTypeMigrationService(
engine=sqlite_engine,
apply=False,
tables=("load_balancing_model_configs",),
model_types=(ModelType.LLM,),
)
service.migrate()
assert seen_runs == [
(inherit_llm_tenant_id, ("load_balancing_model_configs",), (ModelType.LLM,)),
(inherit_embedding_tenant_id, ("load_balancing_model_configs",), (ModelType.LLM,)),
]
def test_service_migrate_with_concurrency_greater_than_one_runs_tenants_in_parallel_without_changing_migration_scope(
@ -942,6 +1228,71 @@ def test_is_lock_timeout_error_prefers_structured_backend_codes(
assert migration._is_lock_timeout_error(exc) is expected
def test_process_load_balancing_model_config_row_logs_stacktrace_for_lock_timeout(
migration_module,
sqlite_engine: sa.Engine,
monkeypatch: pytest.MonkeyPatch,
) -> None:
output = io.StringIO()
migration = migration_module.Migration(
tenant_id="tenant-1",
engine=sqlite_engine,
apply=True,
output=output,
model_types=(ModelType.LLM,),
orm_models=(migration_module.LoadBalancingModelConfig,),
)
candidate = migration_module._RowWithRawModelType(
row=SimpleNamespace(id="lb-row-1"),
raw_model_type="text-generation",
canonical_model_type=ModelType.LLM,
)
lock_timeout_exc = OperationalError("SELECT 1", {}, SimpleNamespace(pgcode="55P03"))
class _FakeBeginContext:
def __enter__(self) -> None:
return None
def __exit__(self, exc_type, exc, tb) -> bool:
return False
class _FakeSession:
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb) -> bool:
return False
def begin(self) -> _FakeBeginContext:
return _FakeBeginContext()
def _fake_session_factory(engine: sa.Engine) -> _FakeSession:
return _FakeSession()
def _fake_reload(self, session, original_candidate, *, lock_rows: bool):
raise lock_timeout_exc
monkeypatch.setattr(migration_module, "_session_factory", _fake_session_factory)
monkeypatch.setattr(migration_module.Migration, "_configure_lock_timeout", lambda self, session: None)
monkeypatch.setattr(
migration_module.Migration,
"_reload_load_balancing_model_config_candidate",
_fake_reload,
)
migration._process_load_balancing_model_config_row(candidate)
lines = _parse_json_lines(output)
assert len(lines) == 1
assert lines[0]["event"] == "lock_timeout_skipped"
attrs = cast(dict[str, object], lines[0]["attrs"])
assert attrs["table_name"] == "load_balancing_model_configs"
assert attrs["id"] == "lb-row-1"
assert attrs["error"] == str(lock_timeout_exc)
assert isinstance(attrs["stacktrace"], str)
assert "OperationalError" in attrs["stacktrace"]
def test_process_load_balancing_model_config_row_logs_update_after_sql_execution(
migration_module,
sqlite_engine: sa.Engine,
@ -1024,6 +1375,41 @@ def test_process_load_balancing_model_config_row_logs_update_after_sql_execution
]
def test_load_balancing_model_config_cache_delete_failure_logs_stacktrace(
migration_module,
sqlite_engine: sa.Engine,
dirty_fixture,
monkeypatch: pytest.MonkeyPatch,
) -> None:
def _raise_delete_failure(self) -> None:
raise RuntimeError("cache delete boom")
monkeypatch.setattr(migration_module.ProviderCredentialsCache, "delete", _raise_delete_failure)
output = io.StringIO()
migration_module.LegacyModelTypeMigrationService(
engine=sqlite_engine,
apply=True,
output=output,
tables=("load_balancing_model_configs",),
model_types=(ModelType.LLM,),
tenant_ids=(dirty_fixture.primary.tenant_id,),
).migrate()
failed_events = [
cast(dict[str, object], line["attrs"])
for line in _parse_json_lines(output)
if line.get("event") == "cache_delete_failed"
and isinstance(line.get("attrs"), dict)
and cast(dict[str, object], line["attrs"]).get("table_name") == "load_balancing_model_configs"
]
assert len(failed_events) == 1
assert failed_events[0]["error"] == "cache delete boom"
assert isinstance(failed_events[0]["stacktrace"], str)
assert "RuntimeError: cache delete boom" in cast(str, failed_events[0]["stacktrace"])
def test_group_completed_logs_exist_for_all_grouped_tables_and_use_canonical_model_type(
migration_module,
sqlite_engine: sa.Engine,
@ -1188,28 +1574,44 @@ def test_provider_model_settings_group_crossing_batches_is_completed_once_with_a
}
def test_load_balancing_model_configs_are_canonicalized_row_by_row_without_group_business_key_semantics(
def test_load_balancing_inherit_rows_are_deduplicated_by_normalized_model_type_before_canonicalization(
migration_module,
sqlite_engine: sa.Engine,
dirty_fixture,
monkeypatch: pytest.MonkeyPatch,
) -> None:
inserted_row_id = "00000000-0000-0000-0000-00000000dd01"
older_canonical_row_id = "00000000-0000-0000-0000-00000000dd01"
newer_legacy_row_id = "00000000-0000-0000-0000-00000000dd02"
created_at = datetime(2025, 1, 1, 8, 0, 0)
updated_at = created_at + timedelta(minutes=15)
older_updated_at = created_at + timedelta(minutes=15)
newer_updated_at = created_at + timedelta(minutes=30)
_insert_load_balancing_model_config(
sqlite_engine,
row_id=inserted_row_id,
row_id=older_canonical_row_id,
tenant_id=dirty_fixture.primary.tenant_id,
provider_name="openai",
model_name="gpt-4o-mini",
model_type=ModelType.LLM.value,
name="__inherit__",
encrypted_config='{"api_key":"older-inherit"}',
credential_id=dirty_fixture.primary.winner_credential_id,
enabled=True,
created_at=created_at,
updated_at=older_updated_at,
)
_insert_load_balancing_model_config(
sqlite_engine,
row_id=newer_legacy_row_id,
tenant_id=dirty_fixture.primary.tenant_id,
provider_name="openai",
model_name="gpt-4o-mini",
model_type="text-generation",
name=dirty_fixture.primary.loser_credential_name,
encrypted_config='{"api_key":"second-lb"}',
name="__inherit__",
encrypted_config='{"api_key":"newer-inherit"}',
credential_id=dirty_fixture.primary.distinct_credential_id,
enabled=True,
created_at=created_at,
updated_at=updated_at,
updated_at=newer_updated_at,
)
deleted_cache_keys: list[str] = []
@ -1221,11 +1623,7 @@ def test_load_balancing_model_configs_are_canonicalized_row_by_row_without_group
tenant_id = dirty_fixture.primary.tenant_id
table_name = "load_balancing_model_configs"
expected_row_ids = {
dirty_fixture.primary.load_balancing_config_id,
dirty_fixture.primary.winner_load_balancing_config_id,
inserted_row_id,
}
expected_row_ids = {older_canonical_row_id, newer_legacy_row_id}
dry_run_output = io.StringIO()
migration_module.LegacyModelTypeMigrationService(
@ -1238,44 +1636,79 @@ def test_load_balancing_model_configs_are_canonicalized_row_by_row_without_group
).migrate()
dry_run_lines = _parse_json_lines(dry_run_output)
dry_run_signatures = {
signature
for signature in _collect_processing_signatures(dry_run_lines)
if signature[1] == table_name and signature[2] in expected_row_ids
}
dry_run_row_updates = [
cast(dict[str, object], line["attrs"])
for line in dry_run_lines
if line.get("event") == "row_updated"
and isinstance(line.get("attrs"), dict)
and cast(dict[str, object], line["attrs"]).get("table_name") == table_name
and str(cast(dict[str, object], line["attrs"]).get("id")) in expected_row_ids
]
assert len(dry_run_row_updates) == 3
assert {str(attrs["id"]) for attrs in dry_run_row_updates} == expected_row_ids
assert all(attrs.get("old_values") == {"model_type": "text-generation"} for attrs in dry_run_row_updates)
assert all(attrs.get("new_values") == {"model_type": ModelType.LLM.value} for attrs in dry_run_row_updates)
assert len(dry_run_row_updates) == 1
assert str(dry_run_row_updates[0]["id"]) == newer_legacy_row_id
assert dry_run_row_updates[0]["old_values"] == {"model_type": "text-generation"}
assert dry_run_row_updates[0]["new_values"] == {"model_type": ModelType.LLM.value}
assert all("rewrite_source" not in attrs for attrs in dry_run_row_updates)
dry_run_group_processed = [
dry_run_row_deletes = [
cast(dict[str, object], line["attrs"])
for line in dry_run_lines
if line.get("event") == "group_processed"
if line.get("event") == "row_deleted"
and isinstance(line.get("attrs"), dict)
and cast(dict[str, object], line["attrs"]).get("table_name") == table_name
and str(cast(dict[str, object], line["attrs"]).get("id")) in expected_row_ids
]
assert dry_run_group_processed == []
assert len(dry_run_row_deletes) == 1
assert dry_run_row_deletes[0]["business_key"] == {
"tenant_id": tenant_id,
"provider_name": "openai",
"model_name": "gpt-4o-mini",
"model_type": ModelType.LLM.value,
}
assert dry_run_row_deletes[0]["merge_winner_id"] == newer_legacy_row_id
assert dry_run_row_deletes[0]["row"] == {
"id": older_canonical_row_id,
"tenant_id": tenant_id,
"provider_name": "openai",
"model_name": "gpt-4o-mini",
"model_type": ModelType.LLM.value,
"name": "__inherit__",
"encrypted_config": {"api_key": "older-inherit"},
"credential_id": dirty_fixture.primary.winner_credential_id,
"credential_source_type": CredentialSourceType.CUSTOM_MODEL.value,
"enabled": True,
"created_at": created_at.isoformat(),
"updated_at": older_updated_at.isoformat(),
}
dry_run_cache_plans = [
cast(dict[str, object], line["attrs"])
for line in dry_run_lines
if line.get("event") == "cache_delete_planned"
dry_run_deleted_index = next(
index
for index, line in enumerate(dry_run_lines)
if line.get("event") == "row_deleted"
and isinstance(line.get("attrs"), dict)
and cast(dict[str, object], line["attrs"]).get("table_name") == table_name
]
assert len(dry_run_cache_plans) == 3
assert {str(attrs["id"]) for attrs in dry_run_cache_plans} == expected_row_ids
and cast(dict[str, object], line["attrs"]).get("id") == older_canonical_row_id
)
dry_run_updated_index = next(
index
for index, line in enumerate(dry_run_lines)
if line.get("event") == "row_updated"
and isinstance(line.get("attrs"), dict)
and cast(dict[str, object], line["attrs"]).get("id") == newer_legacy_row_id
)
assert dry_run_deleted_index < dry_run_updated_index
dry_run_business_keys = [
_json_key(business_key)
for attrs in [*dry_run_row_updates, *dry_run_cache_plans]
if isinstance((business_key := attrs.get("business_key")), dict)
]
assert len(set(dry_run_business_keys)) == len(dry_run_business_keys)
dry_run_cache_plan_ids = _cache_event_row_ids(
dry_run_lines,
table_name=table_name,
row_ids=expected_row_ids,
event_name="cache_delete_planned",
)
assert newer_legacy_row_id in dry_run_cache_plan_ids
apply_output = io.StringIO()
migration_module.LegacyModelTypeMigrationService(
@ -1288,47 +1721,100 @@ def test_load_balancing_model_configs_are_canonicalized_row_by_row_without_group
).migrate()
apply_lines = _parse_json_lines(apply_output)
apply_signatures = {
signature
for signature in _collect_processing_signatures(apply_lines)
if signature[1] == table_name and signature[2] in expected_row_ids
}
apply_row_updates = [
cast(dict[str, object], line["attrs"])
for line in apply_lines
if line.get("event") == "row_updated"
and isinstance(line.get("attrs"), dict)
and cast(dict[str, object], line["attrs"]).get("table_name") == table_name
and str(cast(dict[str, object], line["attrs"]).get("id")) in expected_row_ids
]
assert len(apply_row_updates) == 3
assert {str(attrs["id"]) for attrs in apply_row_updates} == expected_row_ids
assert len(apply_row_updates) == 1
assert str(apply_row_updates[0]["id"]) == newer_legacy_row_id
assert apply_signatures == dry_run_signatures
apply_group_processed = [
cast(dict[str, object], line["attrs"])
for line in apply_lines
if line.get("event") == "group_processed"
and isinstance(line.get("attrs"), dict)
and cast(dict[str, object], line["attrs"]).get("table_name") == table_name
]
assert apply_group_processed == []
apply_cache_deletes = [
cast(dict[str, object], line["attrs"])
for line in apply_lines
if line.get("event") == "cache_deleted"
and isinstance(line.get("attrs"), dict)
and cast(dict[str, object], line["attrs"]).get("table_name") == table_name
]
assert len(apply_cache_deletes) == 3
assert {str(attrs["id"]) for attrs in apply_cache_deletes} == expected_row_ids
assert len(deleted_cache_keys) == 3
apply_business_keys = [
_json_key(business_key)
for attrs in [*apply_row_updates, *apply_cache_deletes]
if isinstance((business_key := attrs.get("business_key")), dict)
]
assert len(set(apply_business_keys)) == len(apply_business_keys)
apply_cache_delete_ids = _cache_event_row_ids(
apply_lines,
table_name=table_name,
row_ids=expected_row_ids,
event_name="cache_deleted",
)
assert apply_cache_delete_ids == dry_run_cache_plan_ids
assert deleted_cache_keys
lb_rows = fetch_table_rows(sqlite_engine, table_name, tenant_id=tenant_id)
migrated_rows = [row for row in lb_rows if str(row["id"]) in expected_row_ids]
assert len(migrated_rows) == 3
assert all(row["model_type"] == ModelType.LLM.value for row in migrated_rows)
surviving_rows = [row for row in lb_rows if str(row["id"]) in expected_row_ids]
assert len(surviving_rows) == 1
surviving_row = surviving_rows[0]
assert surviving_row["id"] == newer_legacy_row_id
assert surviving_row["tenant_id"] == tenant_id
assert surviving_row["provider_name"] == "openai"
assert surviving_row["model_name"] == "gpt-4o-mini"
assert surviving_row["model_type"] == ModelType.LLM.value
assert surviving_row["name"] == "__inherit__"
assert surviving_row["encrypted_config"] == '{"api_key":"newer-inherit"}'
assert surviving_row["credential_id"] == dirty_fixture.primary.distinct_credential_id
assert surviving_row["credential_source_type"] == CredentialSourceType.CUSTOM_MODEL.value
def test_load_balancing_non_inherit_rows_do_not_participate_in_normalized_model_type_deduplication(
migration_module,
sqlite_engine: sa.Engine,
dirty_fixture,
) -> None:
inserted_row_id = "00000000-0000-0000-0000-00000000dd03"
created_at = datetime(2025, 1, 1, 8, 0, 0)
updated_at = created_at + timedelta(minutes=15)
_insert_load_balancing_model_config(
sqlite_engine,
row_id=inserted_row_id,
tenant_id=dirty_fixture.primary.tenant_id,
provider_name="openai",
model_name="gpt-4o-mini",
model_type=ModelType.LLM.value,
name=dirty_fixture.primary.loser_credential_name,
encrypted_config='{"api_key":"second-lb"}',
credential_id=dirty_fixture.primary.distinct_credential_id,
enabled=True,
created_at=created_at,
updated_at=updated_at,
)
output = io.StringIO()
migration_module.LegacyModelTypeMigrationService(
engine=sqlite_engine,
apply=True,
output=output,
tables=("load_balancing_model_configs",),
model_types=(ModelType.LLM,),
tenant_ids=(dirty_fixture.primary.tenant_id,),
).migrate()
lines = _parse_json_lines(output)
row_deleted_events = [
cast(dict[str, object], line["attrs"])
for line in lines
if line.get("event") == "row_deleted"
and isinstance(line.get("attrs"), dict)
and cast(dict[str, object], line["attrs"]).get("table_name") == "load_balancing_model_configs"
]
assert row_deleted_events == []
lb_rows = fetch_table_rows(
sqlite_engine,
"load_balancing_model_configs",
tenant_id=dirty_fixture.primary.tenant_id,
)
matching_rows = [
row for row in lb_rows if str(row["id"]) in {dirty_fixture.primary.load_balancing_config_id, inserted_row_id}
]
assert len(matching_rows) == 2
assert all(row["model_type"] == ModelType.LLM.value for row in matching_rows)
def test_migration_apply_updates_all_five_tables_and_rewrites_credential_references(
@ -1364,12 +1850,10 @@ def test_migration_apply_updates_all_five_tables_and_rewrites_credential_referen
assert provider_model_row["credential_id"] == dirty_fixture.primary.winner_credential_id
lb_rows = fetch_table_rows(sqlite_engine, "load_balancing_model_configs", tenant_id=dirty_fixture.primary.tenant_id)
lb_ids = {str(r["id"]) for r in lb_rows}
# The loser LB row is deleted during credential dedup (winner already has a row for the same key).
assert str(dirty_fixture.primary.load_balancing_config_id) not in lb_ids
winner_lb_row = next(row for row in lb_rows if row["id"] == dirty_fixture.primary.winner_load_balancing_config_id)
assert winner_lb_row["model_type"] == LEGACY_TO_CANONICAL["text-generation"]
assert winner_lb_row["credential_id"] == dirty_fixture.primary.winner_credential_id
lb_row = next(row for row in lb_rows if row["id"] == dirty_fixture.primary.load_balancing_config_id)
assert lb_row["model_type"] == LEGACY_TO_CANONICAL["text-generation"]
assert lb_row["credential_id"] == dirty_fixture.primary.winner_credential_id
assert lb_row["encrypted_config"] == dirty_fixture.primary.winner_encrypted_config
credential_rows = fetch_table_rows(
sqlite_engine, "provider_model_credentials", tenant_id=dirty_fixture.primary.tenant_id
@ -1539,70 +2023,3 @@ def test_migration_is_idempotent_on_second_apply(
after_second = snapshot_legacy_model_type_state(sqlite_engine)
assert after_second == after_first
def test_lb_loser_row_deleted_when_winner_has_same_model(
sqlite_engine, dirty_fixture, migration_module
) -> None:
"""Loser LB row must be deleted when winner credential already has an LB row
for the same (provider_name, model_name, model_type)."""
LegacyModelTypeMigrationService = migration_module.LegacyModelTypeMigrationService
primary = dirty_fixture.primary
output = io.StringIO()
LegacyModelTypeMigrationService(
engine=sqlite_engine,
apply=True,
output=output,
).migrate()
lb_rows = fetch_table_rows(sqlite_engine, "load_balancing_model_configs", tenant_id=primary.tenant_id)
lb_ids = {str(r["id"]) for r in lb_rows}
assert str(primary.load_balancing_config_id) not in lb_ids, (
"loser LB row should have been deleted (winner already had matching row)"
)
assert str(primary.winner_load_balancing_config_id) in lb_ids, (
"winner LB row must survive"
)
llm_lb_rows = [r for r in lb_rows if str(r["model_type"]) == "llm" and str(r["model_name"]) == "gpt-4o-mini"]
assert len(llm_lb_rows) == 1, f"expected 1 LB row for llm gpt-4o-mini, got {llm_lb_rows}"
def test_lb_loser_deletion_logged_in_dry_run(
sqlite_engine, dirty_fixture, migration_module
) -> None:
"""Dry run must log a row_deleted event (not row_updated/rewrite) for the loser LB row."""
LegacyModelTypeMigrationService = migration_module.LegacyModelTypeMigrationService
primary = dirty_fixture.primary
output = io.StringIO()
LegacyModelTypeMigrationService(
engine=sqlite_engine,
apply=False,
output=output,
).migrate()
events = _parse_json_lines(output)
deletion_events = [
e for e in events
if e.get("event") == "row_deleted"
and isinstance(e.get("attrs"), dict)
and e["attrs"].get("table_name") == "load_balancing_model_configs"
and str(e["attrs"].get("id")) == str(primary.load_balancing_config_id)
]
assert deletion_events, "expected a row_deleted event for the loser LB row in dry-run"
credential_rewrite_events = [
e for e in events
if e.get("event") == "row_updated"
and isinstance(e.get("attrs"), dict)
and e["attrs"].get("table_name") == "load_balancing_model_configs"
and str(e["attrs"].get("id")) == str(primary.load_balancing_config_id)
and isinstance(e["attrs"].get("rewrite_source"), dict)
and e["attrs"]["rewrite_source"].get("rewrite_kind") == "credential_reference"
]
assert not credential_rewrite_events, (
"loser LB row must not be logged as a credential_reference rewrite when winner has matching row"
)

View File

@ -61,20 +61,79 @@ class TestRepack:
class TestUpdatePromptTool:
def test_replaces_prompt_tool_parameters_with_tool_schema(self, runner, mocker: MockerFixture):
def build_param(self, mocker: MockerFixture, **kwargs):
p = mocker.MagicMock()
p.form = kwargs.get("form")
mock_type = mocker.MagicMock()
mock_type.as_normal_type.return_value = "string"
p.type = mock_type
p.name = kwargs.get("name", "p1")
p.llm_description = "desc"
p.input_schema = kwargs.get("input_schema")
p.options = kwargs.get("options")
p.required = kwargs.get("required", False)
return p
def test_skip_non_llm(self, runner, mocker: MockerFixture):
tool = mocker.MagicMock()
schema = {
"type": "object",
"properties": {"p1": {"type": "string", "description": "desc"}},
"required": ["p1"],
}
tool.get_llm_parameters_json_schema.return_value = schema
param = self.build_param(mocker, form="NOT_LLM")
tool.get_runtime_parameters.return_value = [param]
prompt_tool = mocker.MagicMock()
prompt_tool.parameters = {"properties": {}, "required": []}
result = runner.update_prompt_message_tool(tool, prompt_tool)
assert result.parameters == schema
assert result.parameters["properties"] == {}
def test_enum_and_required(self, runner, mocker: MockerFixture):
option = mocker.MagicMock(value="opt1")
param = self.build_param(
mocker,
form=module.ToolParameter.ToolParameterForm.LLM,
options=[option],
required=True,
)
tool = mocker.MagicMock()
tool.get_runtime_parameters.return_value = [param]
prompt_tool = mocker.MagicMock()
prompt_tool.parameters = {"properties": {}, "required": []}
result = runner.update_prompt_message_tool(tool, prompt_tool)
assert "p1" in result.parameters["required"]
def test_skip_file_type_param(self, runner, mocker: MockerFixture):
tool = mocker.MagicMock()
param = self.build_param(mocker, form=module.ToolParameter.ToolParameterForm.LLM)
param.type = module.ToolParameter.ToolParameterType.FILE
tool.get_runtime_parameters.return_value = [param]
prompt_tool = mocker.MagicMock()
prompt_tool.parameters = {"properties": {}, "required": []}
result = runner.update_prompt_message_tool(tool, prompt_tool)
assert result.parameters["properties"] == {}
def test_duplicate_required_not_duplicated(self, runner, mocker: MockerFixture):
tool = mocker.MagicMock()
param = self.build_param(
mocker,
form=module.ToolParameter.ToolParameterForm.LLM,
required=True,
)
tool.get_runtime_parameters.return_value = [param]
prompt_tool = mocker.MagicMock()
prompt_tool.parameters = {"properties": {}, "required": ["p1"]}
result = runner.update_prompt_message_tool(tool, prompt_tool)
assert result.parameters["required"].count("p1") == 1
# ==========================================================
@ -324,21 +383,57 @@ class TestConvertToolToPromptMessageTool:
def test_basic_conversion(self, runner, mocker: MockerFixture):
tool = mocker.MagicMock(tool_name="tool1")
runtime_param = mocker.MagicMock()
runtime_param.form = module.ToolParameter.ToolParameterForm.LLM
runtime_param.name = "param1"
runtime_param.llm_description = "desc"
runtime_param.required = True
runtime_param.input_schema = None
runtime_param.options = None
mock_type = mocker.MagicMock()
mock_type.as_normal_type.return_value = "string"
runtime_param.type = mock_type
tool_entity = mocker.MagicMock()
tool_entity.entity.description.llm = "desc"
schema = {
"type": "object",
"properties": {"param1": {"type": "string", "description": "desc"}},
"required": ["param1"],
}
tool_entity.get_llm_parameters_json_schema.return_value = schema
tool_entity.get_merged_runtime_parameters.return_value = [runtime_param]
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool)
assert entity == tool_entity
assert prompt_tool.parameters == schema
def test_full_conversion_multiple_params(self, runner, mocker: MockerFixture):
tool = mocker.MagicMock(tool_name="tool1")
# LLM param with input_schema override
param1 = mocker.MagicMock()
param1.form = module.ToolParameter.ToolParameterForm.LLM
param1.name = "p1"
param1.llm_description = "desc"
param1.required = True
param1.input_schema = {"type": "integer"}
param1.options = None
param1.type = mocker.MagicMock()
# SYSTEM_FILES param should be skipped
param2 = mocker.MagicMock()
param2.form = module.ToolParameter.ToolParameterForm.LLM
param2.name = "file_param"
param2.type = module.ToolParameter.ToolParameterType.SYSTEM_FILES
tool_entity = mocker.MagicMock()
tool_entity.entity.description.llm = "desc"
tool_entity.get_merged_runtime_parameters.return_value = [param1, param2]
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool)
assert entity == tool_entity
# ==========================================================
@ -370,6 +465,29 @@ class TestInitPromptToolsExtended:
class TestAdditionalCoverage:
def test_update_prompt_with_input_schema(self, runner, mocker: MockerFixture):
tool = mocker.MagicMock()
param = mocker.MagicMock()
param.form = module.ToolParameter.ToolParameterForm.LLM
param.name = "p1"
param.required = False
param.llm_description = "desc"
param.options = None
param.input_schema = {"type": "number"}
mock_type = mocker.MagicMock()
mock_type.as_normal_type.return_value = "string"
param.type = mock_type
tool.get_runtime_parameters.return_value = [param]
prompt_tool = mocker.MagicMock()
prompt_tool.parameters = {"properties": {}, "required": []}
result = runner.update_prompt_message_tool(tool, prompt_tool)
assert result.parameters["properties"]["p1"]["type"] == "number"
def test_save_agent_thought_existing_labels(self, runner, mock_db_session, mocker: MockerFixture):
agent = mocker.MagicMock()
agent.tool = "tool1"
@ -453,6 +571,33 @@ class TestAdditionalCoverage:
result = runner.organize_agent_history([])
assert isinstance(result, list)
# ================= Additional Surgical Coverage =================
def test_convert_tool_select_enum_branch(self, runner, mocker: MockerFixture):
tool = mocker.MagicMock(tool_name="tool1")
param = mocker.MagicMock()
param.form = module.ToolParameter.ToolParameterForm.LLM
param.name = "select_param"
param.required = True
param.llm_description = "desc"
param.input_schema = None
option1 = mocker.MagicMock(value="A")
option2 = mocker.MagicMock(value="B")
param.options = [option1, option2]
param.type = module.ToolParameter.ToolParameterType.SELECT
tool_entity = mocker.MagicMock()
tool_entity.entity.description.llm = "desc"
tool_entity.get_merged_runtime_parameters.return_value = [param]
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
prompt_tool, _ = runner._convert_tool_to_prompt_message_tool(tool)
assert prompt_tool is not None
class TestConvertDatasetRetrieverTool:
def test_required_param_added(self, runner, mocker: MockerFixture):
@ -518,6 +663,24 @@ class TestBaseAgentRunnerInit:
class TestBaseAgentRunnerCoverage:
def test_convert_tool_skips_non_llm_param(self, runner, mocker: MockerFixture):
tool = mocker.MagicMock(tool_name="tool1")
param = mocker.MagicMock()
param.form = "NOT_LLM"
param.type = mocker.MagicMock()
tool_entity = mocker.MagicMock()
tool_entity.entity.description.llm = "desc"
tool_entity.get_merged_runtime_parameters.return_value = [param]
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
prompt_tool, _ = runner._convert_tool_to_prompt_message_tool(tool)
assert prompt_tool.parameters["properties"] == {}
def test_init_prompt_tools_adds_dataset_tools(self, runner, mocker: MockerFixture):
dataset_tool = mocker.MagicMock()
dataset_tool.entity.identity.name = "ds"
@ -530,6 +693,30 @@ class TestBaseAgentRunnerCoverage:
assert tools["ds"] == dataset_tool
assert len(prompt_tools) == 1
def test_update_prompt_message_tool_select_enum(self, runner, mocker: MockerFixture):
tool = mocker.MagicMock()
option1 = mocker.MagicMock(value="A")
option2 = mocker.MagicMock(value="B")
param = mocker.MagicMock()
param.form = module.ToolParameter.ToolParameterForm.LLM
param.name = "select_param"
param.required = False
param.llm_description = "desc"
param.input_schema = None
param.options = [option1, option2]
param.type = module.ToolParameter.ToolParameterType.SELECT
tool.get_runtime_parameters.return_value = [param]
prompt_tool = mocker.MagicMock()
prompt_tool.parameters = {"properties": {}, "required": []}
result = runner.update_prompt_message_tool(tool, prompt_tool)
assert result.parameters["properties"]["select_param"]["enum"] == ["A", "B"]
def test_save_agent_thought_json_dumps_fallbacks(self, runner, mock_db_session, mocker: MockerFixture):
agent = mocker.MagicMock()
agent.tool = "tool1"

View File

@ -8,13 +8,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ToolEntity,
ToolIdentity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage, ToolProviderType
class DummyCastType:
@ -31,7 +25,6 @@ class DummyParameter:
default: Any = None
options: list[Any] | None = None
llm_description: str | None = None
input_schema: dict[str, Any] | None = None
class DummyTool(Tool):
@ -156,27 +149,13 @@ def test_fork_tool_runtime_returns_new_tool_with_copied_entity():
def test_get_runtime_parameters_and_merge_runtime_parameters():
tool = _build_tool()
original = DummyParameter(
name="temperature",
type=DummyCastType(),
form="schema",
required=True,
default="0.7",
input_schema={"type": "string"},
)
original = DummyParameter(name="temperature", type=DummyCastType(), form="schema", required=True, default="0.7")
tool.entity.parameters = cast(Any, [original])
default_runtime_parameters = tool.get_runtime_parameters()
assert default_runtime_parameters == [original]
override = DummyParameter(
name="temperature",
type=DummyCastType(),
form="llm",
required=False,
default="0.5",
input_schema={"type": "object"},
)
override = DummyParameter(name="temperature", type=DummyCastType(), form="llm", required=False, default="0.5")
appended = DummyParameter(name="new_param", type=DummyCastType(), form="form", required=False, default="x")
tool.runtime_parameter_overrides = [override, appended]
@ -186,93 +165,7 @@ def test_get_runtime_parameters_and_merge_runtime_parameters():
assert merged[0].form == "llm"
assert merged[0].required is False
assert merged[0].default == "0.5"
assert merged[0].input_schema == {"type": "object"}
assert merged[1].name == "new_param"
assert merged[0] is not original
assert merged[1] is not appended
assert original.form == "schema"
assert original.required is True
assert original.default == "0.7"
assert original.input_schema == {"type": "string"}
def test_get_llm_parameters_json_schema_uses_effective_runtime_parameters():
tool = _build_tool()
query_parameter = ToolParameter.get_simple_instance(
name="query",
llm_description="Declared query",
typ=ToolParameter.ToolParameterType.STRING,
required=True,
)
region_parameter = ToolParameter.get_simple_instance(
name="region",
llm_description="Search region",
typ=ToolParameter.ToolParameterType.SELECT,
required=False,
options=["global", "cn"],
)
hidden_parameter = ToolParameter.get_simple_instance(
name="api_key",
llm_description="Hidden api key",
typ=ToolParameter.ToolParameterType.STRING,
required=True,
)
hidden_parameter.form = ToolParameter.ToolParameterForm.FORM
file_parameter = ToolParameter.get_simple_instance(
name="attachment",
llm_description="Attachment",
typ=ToolParameter.ToolParameterType.FILE,
required=False,
)
payload_parameter = ToolParameter(
name="payload",
label=I18nObject(en_US="payload", zh_Hans="payload"),
placeholder=None,
human_description=I18nObject(en_US="payload", zh_Hans="payload"),
type=ToolParameter.ToolParameterType.OBJECT,
form=ToolParameter.ToolParameterForm.LLM,
llm_description="Payload",
required=False,
input_schema={
"type": "object",
"properties": {"nested": {"type": "string"}},
},
)
tool.entity.parameters = [query_parameter, region_parameter, hidden_parameter, file_parameter, payload_parameter]
query_override = ToolParameter.get_simple_instance(
name="query",
llm_description="Runtime query",
typ=ToolParameter.ToolParameterType.STRING,
required=True,
)
tool.runtime_parameter_overrides = [query_override]
schema = tool.get_llm_parameters_json_schema()
assert schema == {
"type": "object",
"properties": {
"query": {"type": "string", "description": "Runtime query"},
"region": {
"type": "string",
"description": "Search region",
"enum": ["global", "cn"],
},
"payload": {
"type": "object",
"properties": {"nested": {"type": "string"}},
"description": "Payload",
},
},
"required": ["query"],
}
schema["properties"]["payload"]["properties"]["nested"]["type"] = "number"
assert payload_parameter.input_schema == {
"type": "object",
"properties": {"nested": {"type": "string"}},
}
def test_message_factory_helpers():

View File

@ -2,7 +2,6 @@ from dataclasses import replace
import pytest
from clients.agent_backend import DIFY_EXECUTION_CONTEXT_LAYER_ID
from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom
from core.workflow.nodes.agent_v2.runtime_request_builder import (
WorkflowAgentRuntimeBuildContext,
@ -94,10 +93,9 @@ def test_builds_create_run_request_from_agent_soul_and_node_job():
result = WorkflowAgentRuntimeRequestBuilder(credentials_provider=FakeCredentialsProvider()).build(_context())
dumped = result.request.model_dump(mode="json")
layers = {layer["name"]: layer for layer in dumped["composition"]["layers"]}
assert layers[DIFY_EXECUTION_CONTEXT_LAYER_ID]["config"]["agent_id"] == "agent-1"
assert layers[DIFY_EXECUTION_CONTEXT_LAYER_ID]["config"]["agent_config_version_id"] == "snapshot-1"
assert layers[DIFY_EXECUTION_CONTEXT_LAYER_ID]["config"]["invoke_from"] == "single_step"
assert dumped["execution_context"]["agent_id"] == "agent-1"
assert dumped["execution_context"]["agent_config_version_id"] == "snapshot-1"
assert dumped["execution_context"]["invoke_from"] == "single_step"
assert dumped["idempotency_key"] == "run-1:node-exec-1"
assert dumped["composition"]["layers"][0]["config"]["prefix"] == "You are careful."
assert dumped["composition"]["layers"][1]["config"]["prefix"] == "Use the previous output."
@ -147,8 +145,7 @@ def test_builds_workflow_run_request_with_file_output_schema_and_reserved_metada
result = WorkflowAgentRuntimeRequestBuilder(credentials_provider=FakeCredentialsProvider()).build(context)
dumped = result.request.model_dump(mode="json")
layers = {layer["name"]: layer for layer in dumped["composition"]["layers"]}
assert layers[DIFY_EXECUTION_CONTEXT_LAYER_ID]["config"]["invoke_from"] == "workflow_run"
assert dumped["execution_context"]["invoke_from"] == "workflow_run"
assert dumped["idempotency_key"] == "node-exec-1"
output_schema = dumped["composition"]["layers"][-1]["config"]["json_schema"]
assert output_schema["properties"]["report"]["properties"]["file_id"]["type"] == "string"

View File

@ -2,3 +2,5 @@
- [User guide](guide/index.md) explains how to compose layers, register config-backed
plugins, use system/user prompts, and snapshot sessions.
- [API reference](api/index.md) lists the public Agenton classes, methods, and extension
points.

View File

@ -111,10 +111,11 @@ import sys
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig
from dify_agent.client import Client
from dify_agent.layers.execution_context import DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, DifyExecutionContextLayerConfig
from dify_agent.layers.dify_plugin import (
DIFY_PLUGIN_LAYER_TYPE_ID,
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
DifyPluginLLMLayerConfig,
DifyPluginLayerConfig,
)
from dify_agent.protocol import DIFY_AGENT_MODEL_LAYER_ID, CreateRunRequest, RunComposition, RunLayerSpec
@ -146,20 +147,19 @@ def build_request() -> CreateRunRequest:
config=PromptLayerConfig(prefix=SYSTEM_PROMPT, user=USER_PROMPT),
),
RunLayerSpec(
name="execution_context",
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
config=DifyExecutionContextLayerConfig(
name="plugin",
type=DIFY_PLUGIN_LAYER_TYPE_ID,
config=DifyPluginLayerConfig(
tenant_id=TENANT_ID,
plugin_id=PLUGIN_ID,
user_id=USER_ID,
invoke_from="workflow_run",
),
),
RunLayerSpec(
name=DIFY_AGENT_MODEL_LAYER_ID,
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
deps={"execution_context": "execution_context"},
deps={"plugin": "plugin"},
config=DifyPluginLLMLayerConfig(
plugin_id=PLUGIN_ID,
model_provider=MODEL_PROVIDER,
model=MODEL_NAME,
credentials=MODEL_CREDENTIALS,

View File

@ -61,11 +61,9 @@ record TTL so active runs that keep producing events remain observable.
## Scheduling and shutdown semantics
`POST /runs` persists a `running` run record and starts an `asyncio` task in the
same process. There is no Redis job stream, consumer group, pending reclaim, or
automatic retry layer. Request-shaped runtime failures such as bad composition,
prompt, output, or snapshot inputs are reported later as failed runs rather than
rejected synchronously once the request DTO itself is accepted.
`POST /runs` validates the composition, persists a `running` run record, and starts
an `asyncio` task in the same process. There is no Redis job stream, consumer
group, pending reclaim, or automatic retry layer.
During FastAPI shutdown the scheduler rejects new runs, waits up to
`DIFY_AGENT_SHUTDOWN_GRACE_SECONDS` for active tasks, then cancels remaining tasks

View File

@ -4,4 +4,5 @@ Dify Agent hosts Agenton-composed Pydantic AI runs behind a FastAPI API. Its
source code stays under `src/dify_agent`, while framework-neutral Agenton code
stays under `src/agenton` and `src/agenton_collections`.
See the [operations guide](guide/index.md) for local server behavior.
See the [operations guide](guide/index.md) for local server behavior and the
[run API](api/index.md) for request and event schemas.

View File

@ -1,67 +0,0 @@
# Execution context layer
The execution-context layer carries shared Dify run identifiers plus the tenant
and optional user context needed for plugin-daemon calls. Server settings still
provide the plugin daemon URL and API key.
Use it together with a [plugin LLM layer](../plugin-llm-layer/index.md) and,
when the caller wants Dify tools exposed to the model, a
[plugin tool layer](../plugin-tool-layer/index.md). Both business layers depend
on this layer to reach the plugin daemon.
## Config fields
| Field | Type | Meaning |
| --- | --- | --- |
| `tenant_id` | `str` | Dify tenant/workspace id used when calling the plugin daemon. |
| `user_id` | `str \| None` | Optional end-user id passed through to the plugin daemon. |
| `invoke_from` | `Literal[...]` | Dify caller category recorded for observability and correlation. |
| `app_id` / `workflow_id` / `workflow_run_id` / `node_id` / `node_execution_id` / `conversation_id` / `agent_id` / `agent_config_version_id` / `trace_id` | `str \| None` | Optional Dify-owned execution identifiers forwarded with the run. |
The execution-context layer type id is `dify.execution_context`.
## Basic usage
```python {test="skip" lint="skip"}
from dify_agent.layers.execution_context import (
DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
DifyExecutionContextLayerConfig,
)
from dify_agent.protocol import RunLayerSpec
execution_context_layer = RunLayerSpec(
name="execution_context",
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
config=DifyExecutionContextLayerConfig(
tenant_id="replace-with-tenant-id",
user_id="replace-with-user-id",
invoke_from="workflow_run",
),
)
```
If you do not need a user id, omit `user_id` or pass `None`. Most optional
execution identifiers may also be omitted when they are not available.
## Server-side settings
The execution-context layer config does not include daemon transport settings.
Configure these on the Dify Agent server instead:
```env
DIFY_AGENT_PLUGIN_DAEMON_URL=http://localhost:5002
DIFY_AGENT_PLUGIN_DAEMON_API_KEY=replace-with-plugin-daemon-server-key
```
This keeps server credentials out of client-submitted layer config and out of
session snapshots.
## Notes
- The execution-context layer does not open, cache, close, or snapshot HTTP clients.
- Concrete `plugin_id` values belong to the business layer that invokes the
daemon: the plugin LLM layer for model calls and each plugin tool config for
tool calls.
- The conventional layer name is `execution_context`. If you use another name,
point the LLM and tool layer dependencies at that name.

View File

@ -0,0 +1,59 @@
# Plugin layer
The plugin layer carries Dify plugin daemon identity for a run. It identifies the
tenant, plugin, and optional user context; server settings provide the plugin
daemon URL and API key.
Use it together with a [plugin LLM layer](../plugin-llm-layer/index.md). The LLM
layer depends on this layer to reach the plugin daemon.
## Config fields
| Field | Type | Meaning |
| --- | --- | --- |
| `tenant_id` | `str` | Dify tenant/workspace id used when calling the plugin daemon. |
| `plugin_id` | `str` | Plugin id, for example `langgenius/openai`. |
| `user_id` | `str \| None` | Optional end-user id passed through to the plugin daemon. |
The plugin layer type id is `dify.plugin`.
## Basic usage
```python {test="skip" lint="skip"}
from dify_agent.layers.dify_plugin import DIFY_PLUGIN_LAYER_TYPE_ID, DifyPluginLayerConfig
from dify_agent.protocol import RunLayerSpec
plugin_layer = RunLayerSpec(
name="plugin",
type=DIFY_PLUGIN_LAYER_TYPE_ID,
config=DifyPluginLayerConfig(
tenant_id="replace-with-tenant-id",
plugin_id="langgenius/openai",
user_id="replace-with-user-id",
),
)
```
If you do not need a user id, omit `user_id` or pass `None`.
## Server-side settings
The plugin layer config does not include daemon transport settings. Configure
these on the Dify Agent server instead:
```env
DIFY_AGENT_PLUGIN_DAEMON_URL=http://localhost:5002
DIFY_AGENT_PLUGIN_DAEMON_API_KEY=replace-with-plugin-daemon-server-key
```
This keeps server credentials out of client-submitted layer config and out of
session snapshots.
## Notes
- The plugin layer does not open, cache, close, or snapshot HTTP clients.
- `plugin_id` selects the plugin package. The business model provider and model
name belong to the plugin LLM layer, not this layer.
- The conventional layer name is `plugin`. If you use another name, point the LLM
layer dependency at that name.

View File

@ -1,18 +1,17 @@
# Plugin LLM layer
The plugin LLM layer selects the plugin package, model provider, model name,
provider credentials, and optional model settings for the current run. Dify
Agent reads the model from the reserved layer name `llm`.
The plugin LLM layer selects the model provider, model name, provider credentials,
and optional model settings for the current run. Dify Agent reads the model from
the reserved layer name `llm`.
It must depend on an [execution context layer](../execution-context-layer/index.md),
because that layer supplies the daemon identity and transport context.
It must depend on a [plugin layer](../plugin-layer/index.md), because the plugin
layer supplies the daemon identity and transport context.
## Config fields
| Field | Type | Meaning |
| --- | --- | --- |
| `plugin_id` | `str` | Plugin package id, for example `langgenius/openai`. |
| `model_provider` | `str` | Provider name inside `plugin_id`. Use the value of `DIFY_AGENT_PROVIDER` from `dify-agent/.env`. |
| `model_provider` | `str` | Provider name inside the selected plugin. Use the value of `DIFY_AGENT_PROVIDER` from `dify-agent/.env`. |
| `model` | `str` | Model name. Use the value of `DIFY_AGENT_MODEL_NAME` from `dify-agent/.env`. |
| `credentials` | `dict[str, str \| int \| float \| bool \| None]` | Provider-specific credential object. |
| `model_settings` | `ModelSettings \| None` | Optional pydantic-ai model settings. |
@ -28,14 +27,12 @@ from dify_agent.protocol import DIFY_AGENT_MODEL_LAYER_ID, RunLayerSpec
MODEL_PROVIDER = "replace-with-provider-from-dify-agent-env"
MODEL_NAME = "replace-with-model-from-dify-agent-env"
PLUGIN_ID = "langgenius/openai"
llm_layer = RunLayerSpec(
name=DIFY_AGENT_MODEL_LAYER_ID,
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
deps={"execution_context": "execution_context"},
deps={"plugin": "plugin"},
config=DifyPluginLLMLayerConfig(
plugin_id=PLUGIN_ID,
model_provider=MODEL_PROVIDER,
model=MODEL_NAME,
credentials={"api_key": "replace-with-provider-key"},
@ -43,30 +40,29 @@ llm_layer = RunLayerSpec(
)
```
`deps={"execution_context": "execution_context"}` means: bind the LLM layer's
dependency field named `execution_context` to the composition layer named
`execution_context`.
`deps={"plugin": "plugin"}` means: bind the LLM layer's dependency field named
`plugin` to the composition layer named `plugin`.
Set `MODEL_PROVIDER` and `MODEL_NAME` to the same values as
`DIFY_AGENT_PROVIDER` and `DIFY_AGENT_MODEL_NAME` in `dify-agent/.env`.
## Complete minimal model composition
Most runs include a prompt, execution-context layer, and LLM layer:
Most runs include a prompt, plugin context, and LLM layer:
```python {test="skip" lint="skip"}
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig
from dify_agent.layers.execution_context import DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, DifyExecutionContextLayerConfig
from dify_agent.layers.dify_plugin import (
DIFY_PLUGIN_LAYER_TYPE_ID,
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
DifyPluginLLMLayerConfig,
DifyPluginLayerConfig,
)
from dify_agent.protocol import DIFY_AGENT_MODEL_LAYER_ID, RunComposition, RunLayerSpec
MODEL_PROVIDER = "replace-with-provider-from-dify-agent-env"
MODEL_NAME = "replace-with-model-from-dify-agent-env"
PLUGIN_ID = "langgenius/openai"
composition = RunComposition(
layers=[
@ -76,19 +72,18 @@ composition = RunComposition(
config=PromptLayerConfig(prefix="You are concise.", user="Say hello."),
),
RunLayerSpec(
name="execution_context",
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
config=DifyExecutionContextLayerConfig(
name="plugin",
type=DIFY_PLUGIN_LAYER_TYPE_ID,
config=DifyPluginLayerConfig(
tenant_id="replace-with-tenant-id",
invoke_from="workflow_run",
plugin_id="langgenius/openai",
),
),
RunLayerSpec(
name=DIFY_AGENT_MODEL_LAYER_ID,
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
deps={"execution_context": "execution_context"},
deps={"plugin": "plugin"},
config=DifyPluginLLMLayerConfig(
plugin_id=PLUGIN_ID,
model_provider=MODEL_PROVIDER,
model=MODEL_NAME,
credentials={"api_key": "replace-with-provider-key"},
@ -101,9 +96,6 @@ composition = RunComposition(
## Notes
- The model layer must use the reserved name `llm` (`DIFY_AGENT_MODEL_LAYER_ID`).
- `plugin_id` belongs here because model calls are plugin-specific business
calls. The shared execution-context layer only carries Dify run and
tenant/user daemon context.
- Credential shape depends on the selected plugin provider; the OpenAI-style
`api_key` field above is only an example.
- Client-submitted model credentials remain in the scheduled request memory and

View File

@ -1,130 +0,0 @@
# Plugin tool layer
The plugin tool layer exposes Dify plugin tools to the model. It is designed for
Dify API to build after it has resolved a user's tool selections, plugin daemon
declarations, credentials, and manual/runtime inputs.
Unlike the plugin LLM layer, this layer may contain tools from multiple plugin
packages. Each tool config carries its own `plugin_id`, while the shared
[execution context layer](../execution-context-layer/index.md) still carries
only tenant/user daemon context.
## Responsibilities
Dify API prepares the tool config before submitting the run request:
- resolve the selected provider and tool name;
- merge declared parameters with runtime parameters;
- produce the model-visible JSON schema;
- provide hidden/manual `runtime_parameters` and credentials;
- choose the daemon `credential_type` for invocation.
Dify Agent consumes that prepared config. At run time it validates required
hidden inputs, applies defaults, casts invocation values, calls plugin daemon,
and turns tool responses into model observations.
## Config fields
The plugin tools layer type id is `dify.plugin.tools`.
`DifyPluginToolsLayerConfig` contains a list of `DifyPluginToolConfig` objects:
| Field | Type | Meaning |
| --- | --- | --- |
| `tools` | `list[DifyPluginToolConfig]` | Prepared plugin tools to expose to the model. |
Each tool config has these fields:
| Field | Type | Meaning |
| --- | --- | --- |
| `plugin_id` | `str` | Plugin package id for this tool, for example `langgenius/wikipedia`. |
| `provider` | `str` | Tool provider name inside the plugin. |
| `tool_name` | `str` | Daemon tool name to invoke. |
| `credential_type` | `"api-key" \| "oauth2" \| "unauthorized"` | Credential mode sent to plugin daemon. |
| `name` | `str \| None` | Optional model-visible tool name. Defaults to `tool_name`. |
| `description` | `str \| None` | Optional model-visible description. Defaults to the tool name. |
| `credentials` | `dict[str, str \| int \| float \| bool \| None]` | Provider-specific tool credentials. |
| `runtime_parameters` | `dict[str, JsonValue]` | Hidden/manual values merged into daemon invocation but omitted from the model schema. |
| `parameters` | `list[DifyPluginToolParameter]` | API-prepared effective parameter declarations used for validation, defaults, and casting. |
| `parameters_json_schema` | `dict[str, JsonValue]` | API-prepared JSON schema shown to the model. |
## Example: Dify API prepared Wikipedia tool
```python {test="skip" lint="skip"}
from dify_agent.layers.execution_context import DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, DifyExecutionContextLayerConfig
from dify_agent.layers.dify_plugin import (
DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
DifyPluginToolConfig,
DifyPluginToolParameter,
DifyPluginToolsLayerConfig,
)
from dify_agent.protocol import RunComposition, RunLayerSpec
# Dify API side: resolve the selected tool into the API-side Tool runtime first,
# for example with ToolManager.get_agent_tool_runtime(...). Then adapt its
# effective ToolParameter objects at the protocol boundary. Dify Agent accepts
# both ToolParameter attribute objects and ToolParameter.model_dump(mode="json")
# dictionaries, ignoring API-only fields such as label and human_description.
tool_runtime = ToolManager.get_agent_tool_runtime(...)
effective_parameters = tool_runtime.get_merged_runtime_parameters()
prepared_parameters = [
DifyPluginToolParameter.model_validate(parameter)
# If the API serializes first, use:
# DifyPluginToolParameter.model_validate(parameter.model_dump(mode="json"))
for parameter in effective_parameters
]
parameters_json_schema = tool_runtime.get_llm_parameters_json_schema()
composition = RunComposition(
layers=[
RunLayerSpec(
name="execution_context",
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
config=DifyExecutionContextLayerConfig(
tenant_id="replace-with-tenant-id",
user_id="replace-with-user-id",
invoke_from="workflow_run",
),
),
RunLayerSpec(
name="tools",
type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
deps={"execution_context": "execution_context"},
config=DifyPluginToolsLayerConfig(
tools=[
DifyPluginToolConfig(
plugin_id="langgenius/wikipedia",
provider="wikipedia",
tool_name="wikipedia_search",
credential_type="unauthorized",
name="wikipedia_search",
description="Search Wikipedia for relevant pages.",
parameters=prepared_parameters,
runtime_parameters={"language": "en"},
parameters_json_schema=parameters_json_schema,
)
]
),
),
]
)
```
`deps={"execution_context": "execution_context"}` means: bind the tool layer's
dependency field named `execution_context` to the composition layer named
`execution_context`.
## Notes for Dify API callers
- Do not ask Dify Agent to discover tool declarations. Resolve and prepare them
in API before creating the run.
- `parameters` should include all effective parameters, including hidden/manual
ones needed for validation and default application.
- `parameters_json_schema` should include only model-visible parameters. Omit
hidden/manual parameters and file/system-file parameters unless they are truly
intended for model input.
- `runtime_parameters` should contain hidden/manual values selected by the user
or derived from workflow variables.
- Put each tool's `plugin_id` on the tool config. The shared execution-context
layer has no package-specific identity.

View File

@ -17,11 +17,12 @@ import asyncio
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig
from dify_agent.client import Client
from dify_agent.layers.execution_context import DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, DifyExecutionContextLayerConfig
from dify_agent.layers.dify_plugin import (
DIFY_PLUGIN_LAYER_TYPE_ID,
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
DifyPluginCredentialValue,
DifyPluginLLMLayerConfig,
DifyPluginLayerConfig,
)
from dify_agent.protocol import DIFY_AGENT_MODEL_LAYER_ID, CreateRunRequest, RunComposition, RunLayerSpec
@ -49,67 +50,20 @@ async def main() -> None:
),
),
RunLayerSpec(
name="execution_context",
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
config=DifyExecutionContextLayerConfig(
tenant_id=TENANT_ID,
invoke_from="workflow_run",
),
name="plugin",
type=DIFY_PLUGIN_LAYER_TYPE_ID,
config=DifyPluginLayerConfig(tenant_id=TENANT_ID, plugin_id=PLUGIN_ID),
),
RunLayerSpec(
name=DIFY_AGENT_MODEL_LAYER_ID,
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
deps={"execution_context": "execution_context"},
deps={"plugin": "plugin"},
config=DifyPluginLLMLayerConfig(
plugin_id=PLUGIN_ID,
model_provider=PLUGIN_PROVIDER,
model=MODEL_NAME,
credentials=MODEL_CREDENTIALS,
),
),
# Minimal plugin-tools example. API callers should pass
# prepared parameters + JSON schema instead of relying on
# dify-agent to fetch and merge daemon declarations.
# from dify_agent.layers.dify_plugin import (
# DifyPluginToolConfig,
# DifyPluginToolParameter,
# DifyPluginToolParameterForm,
# DifyPluginToolParameterType,
# DifyPluginToolsLayerConfig,
# )
# RunLayerSpec(
# name="tools",
# type="dify.plugin.tools",
# deps={"execution_context": "execution_context"},
# config=DifyPluginToolsLayerConfig(
# tools=[
# DifyPluginToolConfig(
# plugin_id="langgenius/search",
# provider="search",
# tool_name="web_search",
# credential_type="api-key",
# credentials={"api_key": "replace-with-tool-key"},
# runtime_parameters={"site": "docs.dify.ai"},
# parameters=[
# DifyPluginToolParameter(
# name="query",
# type=DifyPluginToolParameterType.STRING,
# form=DifyPluginToolParameterForm.LLM,
# required=True,
# llm_description="Search query",
# ),
# ],
# parameters_json_schema={
# "type": "object",
# "properties": {
# "query": {"type": "string", "description": "Search query"}
# },
# "required": ["query"],
# },
# )
# ]
# ),
# ),
],
),
)

View File

@ -10,11 +10,12 @@ assuming the original request was not accepted.
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig
from dify_agent.client import Client
from dify_agent.layers.execution_context import DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, DifyExecutionContextLayerConfig
from dify_agent.layers.dify_plugin import (
DIFY_PLUGIN_LAYER_TYPE_ID,
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
DifyPluginCredentialValue,
DifyPluginLLMLayerConfig,
DifyPluginLayerConfig,
)
from dify_agent.protocol import DIFY_AGENT_MODEL_LAYER_ID, CreateRunRequest, RunComposition, RunLayerSpec
@ -42,67 +43,20 @@ def main() -> None:
),
),
RunLayerSpec(
name="execution_context",
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
config=DifyExecutionContextLayerConfig(
tenant_id=TENANT_ID,
invoke_from="workflow_run",
),
name="plugin",
type=DIFY_PLUGIN_LAYER_TYPE_ID,
config=DifyPluginLayerConfig(tenant_id=TENANT_ID, plugin_id=PLUGIN_ID),
),
RunLayerSpec(
name=DIFY_AGENT_MODEL_LAYER_ID,
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
deps={"execution_context": "execution_context"},
deps={"plugin": "plugin"},
config=DifyPluginLLMLayerConfig(
plugin_id=PLUGIN_ID,
model_provider=PLUGIN_PROVIDER,
model=MODEL_NAME,
credentials=MODEL_CREDENTIALS,
),
),
# Minimal plugin-tools example. API callers should pass
# prepared parameters + JSON schema instead of relying on
# dify-agent to fetch and merge daemon declarations.
# from dify_agent.layers.dify_plugin import (
# DifyPluginToolConfig,
# DifyPluginToolParameter,
# DifyPluginToolParameterForm,
# DifyPluginToolParameterType,
# DifyPluginToolsLayerConfig,
# )
# RunLayerSpec(
# name="tools",
# type="dify.plugin.tools",
# deps={"execution_context": "execution_context"},
# config=DifyPluginToolsLayerConfig(
# tools=[
# DifyPluginToolConfig(
# plugin_id="langgenius/search",
# provider="search",
# tool_name="web_search",
# credential_type="api-key",
# credentials={"api_key": "replace-with-tool-key"},
# runtime_parameters={"site": "docs.dify.ai"},
# parameters=[
# DifyPluginToolParameter(
# name="query",
# type=DifyPluginToolParameterType.STRING,
# form=DifyPluginToolParameterForm.LLM,
# required=True,
# llm_description="Search query",
# ),
# ],
# parameters_json_schema={
# "type": "object",
# "properties": {
# "query": {"type": "string", "description": "Search query"}
# },
# "required": ["query"],
# },
# )
# ]
# ),
# ),
],
),
)

View File

@ -11,18 +11,19 @@ nav:
- Agenton:
- Overview: agenton/index.md
- Guide: agenton/guide/index.md
- API Reference: agenton/api/index.md
- Examples: agenton/examples/index.md
- Dify Agent:
- Overview: dify-agent/index.md
- User Manual:
- Get Started: dify-agent/get-started/index.md
- Prompt Layer: dify-agent/user-manual/prompt-layer/index.md
- Execution Context Layer: dify-agent/user-manual/execution-context-layer/index.md
- Plugin Layer: dify-agent/user-manual/plugin-layer/index.md
- Plugin LLM Layer: dify-agent/user-manual/plugin-llm-layer/index.md
- Plugin Tool Layer: dify-agent/user-manual/plugin-tool-layer/index.md
- History Layer: dify-agent/user-manual/history-layer/index.md
- Structured Output Layer: dify-agent/user-manual/structured-output-layer/index.md
- Operations Guide: dify-agent/guide/index.md
- Run API: dify-agent/api/index.md
- Examples: dify-agent/examples/index.md
theme:

View File

@ -8,6 +8,7 @@ this provider.
from __future__ import annotations
import json
from collections.abc import AsyncIterator, Callable, Mapping
from dataclasses import dataclass, field
from typing import NoReturn
@ -21,12 +22,6 @@ from typing_extensions import override
from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, UnexpectedModelBehavior, UserError
from pydantic_ai.providers import Provider
from dify_agent.plugin_daemon_transport import (
decode_plugin_daemon_error_payload,
to_plugin_daemon_jsonable,
unwrap_plugin_daemon_error,
)
_DEFAULT_DAEMON_TIMEOUT: float | httpx.Timeout | None = 600.0
@ -88,7 +83,7 @@ class DifyPluginDaemonLLMClient:
request_data: Mapping[str, object],
response_model: type[T],
) -> AsyncIterator[T]:
payload: dict[str, object] = {"data": to_plugin_daemon_jsonable(request_data)}
payload: dict[str, object] = {"data": _to_jsonable(request_data)}
if self.user_id is not None:
payload["user_id"] = self.user_id
@ -102,18 +97,14 @@ class DifyPluginDaemonLLMClient:
async with self.http_client.stream("POST", url, headers=headers, json=payload) as response:
if response.is_error:
body = (await response.aread()).decode("utf-8", errors="replace")
error = decode_plugin_daemon_error_payload(body)
error = _decode_plugin_daemon_error_payload(body)
if error is not None:
resolved_error = unwrap_plugin_daemon_error(
error_type=error["error_type"],
message=error["message"],
)
_raise_plugin_daemon_error(
model_name=model_name,
error_type=resolved_error["error_type"],
message=resolved_error["message"],
error_type=error["error_type"],
message=error["message"],
status_code=response.status_code,
body=resolved_error,
body=error,
)
raise ModelHTTPError(response.status_code, model_name, body or None)
@ -126,17 +117,13 @@ class DifyPluginDaemonLLMClient:
wrapped = PluginDaemonBasicResponse.model_validate_json(line)
if wrapped.code != 0:
error = decode_plugin_daemon_error_payload(wrapped.message)
error = _decode_plugin_daemon_error_payload(wrapped.message)
if error is not None:
resolved_error = unwrap_plugin_daemon_error(
error_type=error["error_type"],
message=error["message"],
)
_raise_plugin_daemon_error(
model_name=model_name,
error_type=resolved_error["error_type"],
message=resolved_error["message"],
body=resolved_error,
error_type=error["error_type"],
message=error["message"],
body=error,
)
raise ModelAPIError(
model_name,
@ -212,6 +199,32 @@ class DifyPluginDaemonProvider(Provider[DifyPluginDaemonLLMClient]):
return self._client
def _to_jsonable(value: object) -> object:
if isinstance(value, BaseModel):
return value.model_dump(mode="json")
if isinstance(value, dict):
return {key: _to_jsonable(item) for key, item in value.items()}
if isinstance(value, list | tuple):
return [_to_jsonable(item) for item in value]
return value
def _decode_plugin_daemon_error_payload(raw_message: str) -> dict[str, str] | None:
try:
parsed = json.loads(raw_message)
except json.JSONDecodeError:
return None
if not isinstance(parsed, dict):
return None
error_type = parsed.get("error_type")
message = parsed.get("message")
if not isinstance(error_type, str) or not isinstance(message, str):
return None
return {"error_type": error_type, "message": message}
def _raise_plugin_daemon_error(
*,
model_name: str,
@ -223,6 +236,17 @@ def _raise_plugin_daemon_error(
http_error_body = body or {"error_type": error_type, "message": message}
match error_type:
case "PluginInvokeError":
nested_error = _decode_plugin_daemon_error_payload(message)
if nested_error is not None:
_raise_plugin_daemon_error(
model_name=model_name,
error_type=nested_error["error_type"],
message=nested_error["message"],
status_code=status_code,
body=nested_error,
)
raise ModelAPIError(model_name, message)
case "PluginDaemonUnauthorizedError" | "InvokeAuthorizationError":
raise ModelHTTPError(status_code or 401, model_name, http_error_body)
case "PluginPermissionDeniedError":

View File

@ -1,35 +1,21 @@
"""Client-safe exports for Dify plugin business-layer DTOs and type ids.
"""Client-safe exports for Dify plugin DTOs and public layer type ids.
Implementation layers live in sibling modules and require server-side runtime
dependencies. Keep this package root import-safe for client-only installs.
"""
from dify_agent.layers.dify_plugin.configs import (
DIFY_PLUGIN_LAYER_TYPE_ID,
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
DifyPluginCredentialValue,
DifyPluginLLMLayerConfig,
DifyPluginToolCredentialType,
DifyPluginToolConfig,
DifyPluginToolOption,
DifyPluginToolParameter,
DifyPluginToolParameterForm,
DifyPluginToolParameterType,
DifyPluginToolsLayerConfig,
DifyPluginToolValue,
DifyPluginLayerConfig,
)
__all__ = [
"DIFY_PLUGIN_LAYER_TYPE_ID",
"DIFY_PLUGIN_LLM_LAYER_TYPE_ID",
"DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID",
"DifyPluginCredentialValue",
"DifyPluginLLMLayerConfig",
"DifyPluginToolCredentialType",
"DifyPluginToolConfig",
"DifyPluginToolOption",
"DifyPluginToolParameter",
"DifyPluginToolParameterForm",
"DifyPluginToolParameterType",
"DifyPluginToolsLayerConfig",
"DifyPluginToolValue",
"DifyPluginLayerConfig",
]

View File

@ -1,111 +1,38 @@
"""Client-safe DTOs for Dify plugin-backed Agenton business layers.
"""Client-safe DTOs for Dify plugin-backed Agenton layers.
This module intentionally contains only public config schemas and scalar type
aliases plus stable plugin business-layer type identifiers. Runtime objects
such as HTTP clients, server settings, and adapter implementations live in
sibling implementation modules so clients can build run requests without
importing server-only dependencies.
Shared tenant/user/run context now lives in the sibling
``dify_agent.layers.execution_context`` package. This module only covers the
plugin-backed LLM and tools layers that invoke daemon features with concrete
``plugin_id`` values. Tool configs also carry the API-side prepared parameter
declarations and model-visible JSON schema so the agent runtime does not have to
re-fetch and re-merge tool declarations at execution time.
aliases plus stable layer type identifiers. Runtime objects such as HTTP
clients, server settings, and adapter implementations live in sibling
implementation modules so clients can build run requests without importing
server-only dependencies.
"""
from enum import StrEnum
from typing import ClassVar, Final, Literal, TypeAlias
from typing import ClassVar, Final, TypeAlias
from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator
from pydantic import ConfigDict, Field
from pydantic_ai.settings import ModelSettings
from agenton.layers import LayerConfig
DifyPluginCredentialValue: TypeAlias = str | int | float | bool | None
DifyPluginToolCredentialType: TypeAlias = Literal["api-key", "oauth2", "unauthorized"]
DifyPluginToolValue: TypeAlias = JsonValue
DIFY_PLUGIN_LAYER_TYPE_ID: Final[str] = "dify.plugin"
DIFY_PLUGIN_LLM_LAYER_TYPE_ID: Final[str] = "dify.plugin.llm"
DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID: Final[str] = "dify.plugin.tools"
class DifyPluginToolOption(BaseModel):
"""Selectable tool option value exposed to the model.
class DifyPluginLayerConfig(LayerConfig):
"""Public config for the plugin daemon tenant/plugin context layer."""
The DTO also accepts API-side option dumps and attribute objects. Fields
such as ``label`` or ``icon`` are intentionally ignored because Dify Agent
only preserves the normalized option ``value`` for tool invocation and
model-visible schema generation.
"""
tenant_id: str
plugin_id: str
user_id: str | None = None
value: str
model_config: ClassVar[ConfigDict] = ConfigDict(extra="ignore", from_attributes=True)
@field_validator("value", mode="before")
@classmethod
def stringify_value(cls, value: object) -> str:
return value if isinstance(value, str) else str(value)
class DifyPluginToolParameterType(StrEnum):
STRING = "string"
NUMBER = "number"
BOOLEAN = "boolean"
SELECT = "select"
SECRET_INPUT = "secret-input"
FILE = "file"
FILES = "files"
APP_SELECTOR = "app-selector"
MODEL_SELECTOR = "model-selector"
ANY = "any"
DYNAMIC_SELECT = "dynamic-select"
CHECKBOX = "checkbox"
SYSTEM_FILES = "system-files"
ARRAY = "array"
OBJECT = "object"
def as_normal_type(self) -> str:
if self in {
DifyPluginToolParameterType.SECRET_INPUT,
DifyPluginToolParameterType.SELECT,
DifyPluginToolParameterType.CHECKBOX,
}:
return "string"
return self.value
class DifyPluginToolParameterForm(StrEnum):
SCHEMA = "schema"
FORM = "form"
LLM = "llm"
class DifyPluginToolParameter(BaseModel):
"""Prepared tool parameter declaration supplied by the API side.
The DTO intentionally accepts both API-side ``ToolParameter`` dumps and
attribute objects so callers can adapt existing tool runtime declarations
without coupling Dify Agent to API-internal model classes.
"""
name: str
type: DifyPluginToolParameterType
form: DifyPluginToolParameterForm
required: bool = False
default: DifyPluginToolValue = None
llm_description: str | None = None
input_schema: dict[str, JsonValue] | None = None
options: list[DifyPluginToolOption] = Field(default_factory=list)
model_config: ClassVar[ConfigDict] = ConfigDict(extra="ignore", from_attributes=True)
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
class DifyPluginLLMLayerConfig(LayerConfig):
"""Public config for selecting a plugin-backed business provider/model."""
"""Public config for selecting a business provider/model from a plugin."""
plugin_id: str
model_provider: str
model: str
credentials: dict[str, DifyPluginCredentialValue] = Field(default_factory=dict)
@ -114,64 +41,10 @@ class DifyPluginLLMLayerConfig(LayerConfig):
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
class DifyPluginToolConfig(LayerConfig):
"""Public config for exposing one plugin tool to the agent model.
``credential_type`` is an explicit caller-supplied daemon transport choice,
not an auto-discovered property. It must match the actual credential mode of
``credentials`` for the configured plugin tool, for example ``"api-key"``
versus ``"oauth2"``. A wrong value can make invocation fail at runtime even
when the config itself validates successfully.
``runtime_parameters`` mirrors Dify's agent-node hidden/manual tool inputs:
those values are merged into the actual daemon invocation but omitted from
the tool schema shown to the model.
``parameters`` and ``parameters_json_schema`` are API-side prepared tool
declaration artifacts. They let the agent runtime validate hidden/default
inputs and expose the correct LLM-facing schema without re-fetching or
re-merging daemon declarations at run time.
"""
plugin_id: str
provider: str
tool_name: str
credential_type: DifyPluginToolCredentialType
name: str | None = None
description: str | None = None
credentials: dict[str, DifyPluginCredentialValue] = Field(default_factory=dict)
runtime_parameters: dict[str, DifyPluginToolValue] = Field(default_factory=dict)
parameters: list[DifyPluginToolParameter] = Field(default_factory=list)
parameters_json_schema: dict[str, JsonValue] = Field(
default_factory=lambda: {"type": "object", "properties": {}, "required": []}
)
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
class DifyPluginToolsLayerConfig(LayerConfig):
"""Public config for the Dify plugin tools layer.
Callers configure the tools layer with this wrapper object and supply one
or more prepared ``DifyPluginToolConfig`` entries in ``tools``.
"""
tools: list[DifyPluginToolConfig] = Field(default_factory=list)
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
__all__ = [
"DIFY_PLUGIN_LAYER_TYPE_ID",
"DIFY_PLUGIN_LLM_LAYER_TYPE_ID",
"DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID",
"DifyPluginCredentialValue",
"DifyPluginLLMLayerConfig",
"DifyPluginToolCredentialType",
"DifyPluginToolConfig",
"DifyPluginToolOption",
"DifyPluginToolParameter",
"DifyPluginToolParameterForm",
"DifyPluginToolParameterType",
"DifyPluginToolsLayerConfig",
"DifyPluginToolValue",
"DifyPluginLayerConfig",
]

View File

@ -1,17 +1,15 @@
"""Dify plugin LLM model layer.
This layer owns model capability resolution for Dify plugin-backed LLMs. It
depends on ``DifyExecutionContextLayer`` for shared daemon settings through
Agenton's direct dependency binding and returns a Pydantic AI model adapter
configured from the public LLM layer DTO. Runtime code supplies the FastAPI
lifespan-owned shared HTTP client to ``get_model``; the layer does not own or
discover live resources. The daemon provider carries plugin transport identity,
while the DTO's ``model_provider`` is passed to the adapter as request-level
model identity.
depends on ``DifyPluginLayer`` for daemon identity through Agenton's direct
dependency binding and returns a Pydantic AI model adapter configured from the
public LLM layer DTO. Runtime code supplies the FastAPI lifespan-owned shared
HTTP client to ``get_model``; the layer does not own or discover live resources.
The daemon provider carries plugin transport identity, while the DTO's
``model_provider`` is passed to the adapter as request-level model identity.
"""
from dataclasses import dataclass
from typing import ClassVar
import httpx
from typing_extensions import Self, override
@ -19,20 +17,20 @@ from typing_extensions import Self, override
from agenton.layers import LayerDeps, PlainLayer
from dify_agent.adapters.llm import DifyLLMAdapterModel
from dify_agent.layers.dify_plugin.configs import DIFY_PLUGIN_LLM_LAYER_TYPE_ID, DifyPluginLLMLayerConfig
from dify_agent.layers.execution_context.layer import DifyExecutionContextLayer
from dify_agent.layers.dify_plugin.plugin_layer import DifyPluginLayer
class DifyPluginLLMDeps(LayerDeps):
"""Dependencies required by ``DifyPluginLLMLayer``."""
execution_context: DifyExecutionContextLayer # pyright: ignore[reportUninitializedInstanceVariable]
plugin: DifyPluginLayer # pyright: ignore[reportUninitializedInstanceVariable]
@dataclass(slots=True)
class DifyPluginLLMLayer(PlainLayer[DifyPluginLLMDeps, DifyPluginLLMLayerConfig]):
"""Layer that creates the Dify plugin-daemon Pydantic AI model."""
type_id: ClassVar[str] = DIFY_PLUGIN_LLM_LAYER_TYPE_ID
type_id = DIFY_PLUGIN_LLM_LAYER_TYPE_ID
config: DifyPluginLLMLayerConfig
@ -43,11 +41,8 @@ class DifyPluginLLMLayer(PlainLayer[DifyPluginLLMDeps, DifyPluginLLMLayerConfig]
return cls(config=config)
def get_model(self, *, http_client: httpx.AsyncClient) -> DifyLLMAdapterModel:
"""Return the configured model using the directly bound execution context."""
provider = self.deps.execution_context.create_daemon_provider(
plugin_id=self.config.plugin_id,
http_client=http_client,
)
"""Return the configured model using the directly bound plugin dependency."""
provider = self.deps.plugin.create_daemon_provider(http_client=http_client)
return DifyLLMAdapterModel(
model=self.config.model,
daemon_provider=provider,

View File

@ -0,0 +1,69 @@
"""Runtime Dify plugin context layer.
The public config identifies tenant/plugin/user context only. Plugin daemon URL
and API key are server-side settings injected by the provider factory. The layer
is intentionally config/settings-only under Agenton's state-only core: it does
not open, cache, close, or snapshot HTTP clients, and its lifecycle hooks remain
the inherited no-op hooks. Runtime code passes the FastAPI lifespan-owned shared
``httpx.AsyncClient`` into ``create_daemon_provider`` for each model adapter.
Business model-provider names belong to the LLM layer/model request, not this
daemon context layer.
"""
from dataclasses import dataclass
import httpx
from typing_extensions import Self, override
from agenton.layers import EmptyRuntimeState, NoLayerDeps, PlainLayer
from dify_agent.adapters.llm import DifyPluginDaemonProvider
from dify_agent.layers.dify_plugin.configs import DIFY_PLUGIN_LAYER_TYPE_ID, DifyPluginLayerConfig
@dataclass(slots=True)
class DifyPluginLayer(PlainLayer[NoLayerDeps, DifyPluginLayerConfig, EmptyRuntimeState]):
"""Layer that carries plugin daemon identity without owning live resources."""
type_id = DIFY_PLUGIN_LAYER_TYPE_ID
config: DifyPluginLayerConfig
daemon_url: str
daemon_api_key: str
@classmethod
@override
def from_config(cls, config: DifyPluginLayerConfig) -> Self:
"""Reject construction without server-injected daemon settings."""
del config
raise TypeError("DifyPluginLayer requires server-side daemon settings and must use a provider factory.")
@classmethod
def from_config_with_settings(
cls,
config: DifyPluginLayerConfig,
*,
daemon_url: str,
daemon_api_key: str,
) -> Self:
"""Create a plugin layer from public config plus server-only daemon settings."""
return cls(config=config, daemon_url=daemon_url, daemon_api_key=daemon_api_key)
def create_daemon_provider(self, *, http_client: httpx.AsyncClient) -> DifyPluginDaemonProvider:
"""Return a daemon provider backed by the shared plugin daemon client.
Raises:
RuntimeError: if ``http_client`` has already been closed.
"""
if http_client.is_closed:
raise RuntimeError("DifyPluginLayer.create_daemon_provider() requires an open shared HTTP client.")
return DifyPluginDaemonProvider(
tenant_id=self.config.tenant_id,
plugin_id=self.config.plugin_id,
plugin_daemon_url=self.daemon_url,
plugin_daemon_api_key=self.daemon_api_key,
user_id=self.config.user_id,
http_client=http_client,
)
__all__ = ["DifyPluginLayer"]

View File

@ -1,333 +0,0 @@
"""Async plugin-daemon client for Dify plugin tool invocation.
The agent runtime talks to the plugin daemon rather than importing provider SDKs
directly. The tools layer now consumes API-prepared declarations from config, so
this module only keeps the invoke-time boundary:
- POST ``/plugin/{tenant_id}/dispatch/tool/invoke``
- request headers ``X-Api-Key``, ``X-Plugin-ID``, and ``Content-Type``
- top-level ``user_id`` forwarding when shared execution context includes one
- stream decoding and blob-chunk merging for agent observations
The shared execution-context layer still owns tenant/user daemon context, while
each tool's own ``plugin_id`` determines the transport identity placed in
``X-Plugin-ID``.
"""
from __future__ import annotations
import base64
from collections.abc import AsyncIterator, Mapping
from dataclasses import dataclass, field
from enum import StrEnum
import httpx
from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator
from dify_agent.layers.dify_plugin.configs import DifyPluginToolCredentialType
from dify_agent.plugin_daemon_transport import (
decode_plugin_daemon_error_payload,
to_plugin_daemon_jsonable,
unwrap_plugin_daemon_error,
)
class PluginDaemonBasicResponse(BaseModel):
"""Common plugin-daemon stream and JSON wrapper."""
code: int
message: str
data: object | None = None
@dataclass(slots=True)
class FileChunk:
"""Buffer for accumulating streamed blob chunks."""
total_length: int
bytes_written: int = field(default=0, init=False)
data: bytearray = field(init=False)
def __post_init__(self) -> None:
self.data = bytearray(self.total_length)
class DifyPluginToolInvokeMessage(BaseModel):
"""Subset of Dify tool stream messages needed for agent observations."""
class TextMessage(BaseModel):
text: str
class JsonMessage(BaseModel):
json_object: dict[str, object] | list[object]
suppress_output: bool = False
class BlobMessage(BaseModel):
blob: bytes
class BlobChunkMessage(BaseModel):
id: str
sequence: int
total_length: int
blob: bytes
end: bool
class FileMessage(BaseModel):
file_marker: str = "file_marker"
@model_validator(mode="before")
@classmethod
def validate_file_marker(cls, values: object) -> object:
if isinstance(values, dict) and "file_marker" not in values:
raise ValueError("Invalid FileMessage: missing file_marker")
return values
class VariableMessage(BaseModel):
variable_name: str
variable_value: object
stream: bool = False
class LogMessage(BaseModel):
id: str
label: str
parent_id: str | None = None
error: str | None = None
status: str
data: Mapping[str, object] = Field(default_factory=dict)
metadata: Mapping[str, object] = Field(default_factory=dict)
class MessageType(StrEnum):
TEXT = "text"
IMAGE = "image"
LINK = "link"
BLOB = "blob"
JSON = "json"
IMAGE_LINK = "image_link"
BINARY_LINK = "binary_link"
VARIABLE = "variable"
FILE = "file"
LOG = "log"
BLOB_CHUNK = "blob_chunk"
type: MessageType = MessageType.TEXT
message: (
TextMessage | JsonMessage | BlobChunkMessage | BlobMessage | LogMessage | FileMessage | VariableMessage | None
)
meta: dict[str, object] | None = None
@field_validator("message", mode="before")
@classmethod
def decode_message(cls, value: object, info: ValidationInfo) -> object:
if isinstance(value, dict) and "blob" in value:
try:
value = {**value, "blob": base64.b64decode(value["blob"])}
except Exception:
return value
msg_type = info.data.get("type") if isinstance(info.data, dict) else None
if msg_type == cls.MessageType.JSON and isinstance(value, dict) and "json_object" not in value:
return {"json_object": value}
if msg_type == cls.MessageType.FILE and isinstance(value, dict):
return {"file_marker": value.get("file_marker", "file_marker")}
return value
class DifyPluginToolClientError(Exception):
"""Raised when the plugin daemon rejects a tool-layer request."""
error_type: str | None
status_code: int | None
def __init__(self, message: str, *, error_type: str | None = None, status_code: int | None = None) -> None:
super().__init__(message)
self.error_type = error_type
self.status_code = status_code
@dataclass(slots=True)
class DifyPluginDaemonToolClient:
"""HTTP wrapper for the invoke-only plugin-daemon tool boundary.
Callers provide business-level provider/tool/credential data per invocation,
while this client supplies daemon transport identity from shared runtime
context: tenant path segment, daemon API key, plugin-specific ``X-Plugin-ID``
header, and optional top-level ``user_id``.
"""
plugin_daemon_url: str
plugin_daemon_api_key: str
tenant_id: str
plugin_id: str
user_id: str | None
http_client: httpx.AsyncClient = field(repr=False)
def __post_init__(self) -> None:
self.plugin_daemon_url = self.plugin_daemon_url.rstrip("/")
async def invoke(
self,
*,
provider: str,
tool_name: str,
credential_type: DifyPluginToolCredentialType,
credentials: dict[str, object],
tool_parameters: Mapping[str, object],
) -> list[DifyPluginToolInvokeMessage]:
"""Invoke a plugin tool and collect its observation stream."""
raw_messages = [
item
async for item in self._iter_stream_response(
path=f"plugin/{self.tenant_id}/dispatch/tool/invoke",
request_data={
"provider": provider,
"tool": tool_name,
"credentials": credentials,
"credential_type": credential_type,
"tool_parameters": dict(tool_parameters),
},
response_model=DifyPluginToolInvokeMessage,
)
]
return merge_blob_chunks(raw_messages)
async def _iter_stream_response[T: BaseModel](
self,
*,
path: str,
request_data: Mapping[str, object],
response_model: type[T],
) -> AsyncIterator[T]:
"""Send one daemon stream request and yield typed items.
The daemon expects the actual invoke payload nested under ``data``. When
the shared plugin context included ``user_id``, it is forwarded as a
top-level peer to ``data`` so daemon-side auditing and credential logic
can attribute the request to the end user.
"""
payload: dict[str, object] = {"data": to_plugin_daemon_jsonable(dict(request_data))}
if self.user_id is not None:
payload["user_id"] = self.user_id
url = f"{self.plugin_daemon_url}/{path}"
async with self.http_client.stream("POST", url, headers=self._headers(), json=payload) as response:
if response.is_error:
body = (await response.aread()).decode("utf-8", errors="replace")
error = decode_plugin_daemon_error_payload(body)
if error is not None:
resolved_error = unwrap_plugin_daemon_error(
error_type=error["error_type"],
message=error["message"],
)
_raise_tool_daemon_error(
error_type=resolved_error["error_type"],
message=resolved_error["message"],
status_code=response.status_code,
)
raise DifyPluginToolClientError(
body or "Plugin daemon stream request failed.", status_code=response.status_code
)
async for raw_line in response.aiter_lines():
line = raw_line.strip()
if not line:
continue
if line.startswith("data:"):
line = line[5:].strip()
wrapped = PluginDaemonBasicResponse.model_validate_json(line)
if wrapped.code != 0:
error = decode_plugin_daemon_error_payload(wrapped.message)
if error is not None:
resolved_error = unwrap_plugin_daemon_error(
error_type=error["error_type"],
message=error["message"],
)
_raise_tool_daemon_error(
error_type=resolved_error["error_type"],
message=resolved_error["message"],
)
raise DifyPluginToolClientError(wrapped.message or "Plugin daemon returned an error stream item.")
if wrapped.data is None:
raise DifyPluginToolClientError("Plugin daemon returned an empty stream item.")
yield response_model.model_validate(wrapped.data)
def _headers(self) -> dict[str, str]:
"""Build required plugin-daemon transport headers for tool invocation."""
return {
"X-Api-Key": self.plugin_daemon_api_key,
"X-Plugin-ID": self.plugin_id,
"Content-Type": "application/json",
}
def merge_blob_chunks(
response: list[DifyPluginToolInvokeMessage],
*,
max_file_size: int = 30 * 1024 * 1024,
max_chunk_size: int = 8192,
) -> list[DifyPluginToolInvokeMessage]:
"""Merge streamed blob chunks into complete blob messages.
This mirrors Dify API's plugin-daemon chunk-merging behavior before the
higher-level observation conversion logic sees tool stream messages.
"""
files: dict[str, FileChunk] = {}
merged_messages: list[DifyPluginToolInvokeMessage] = []
for resp in response:
if resp.type is DifyPluginToolInvokeMessage.MessageType.BLOB_CHUNK:
if not isinstance(resp.message, DifyPluginToolInvokeMessage.BlobChunkMessage):
raise TypeError("Blob chunk responses must carry BlobChunkMessage payloads.")
chunk_id = resp.message.id
total_length = resp.message.total_length
blob_data = resp.message.blob
is_end = resp.message.end
if chunk_id not in files:
files[chunk_id] = FileChunk(total_length)
if files[chunk_id].bytes_written + len(blob_data) > max_file_size:
del files[chunk_id]
raise ValueError(f"File is too large which reached the limit of {max_file_size / 1024 / 1024}MB")
if len(blob_data) > max_chunk_size:
raise ValueError(f"File chunk is too large which reached the limit of {max_chunk_size / 1024}KB")
files[chunk_id].data[files[chunk_id].bytes_written : files[chunk_id].bytes_written + len(blob_data)] = (
blob_data
)
files[chunk_id].bytes_written += len(blob_data)
if is_end:
merged_messages.append(
DifyPluginToolInvokeMessage(
type=DifyPluginToolInvokeMessage.MessageType.BLOB,
message=DifyPluginToolInvokeMessage.BlobMessage(
blob=bytes(files[chunk_id].data[: files[chunk_id].bytes_written])
),
meta=resp.meta,
)
)
del files[chunk_id]
else:
merged_messages.append(resp)
return merged_messages
def _raise_tool_daemon_error(
*,
error_type: str,
message: str,
status_code: int | None = None,
) -> None:
raise DifyPluginToolClientError(message, error_type=error_type, status_code=status_code)
__all__ = [
"DifyPluginDaemonToolClient",
"DifyPluginToolClientError",
"DifyPluginToolCredentialType",
"DifyPluginToolInvokeMessage",
"merge_blob_chunks",
]

View File

@ -1,341 +0,0 @@
"""Dify plugin tools layer for agent-accessible plugin tools.
This layer consumes API-prepared plugin tool declarations. The API side is
responsible for resolving daemon declarations, applying runtime-parameter
overrides, and producing the clean LLM-facing JSON schema. At run time the layer
only validates hidden/manual inputs, prepares invocation arguments, and maps
daemon responses into agent-friendly observations.
Like the LLM layer, this layer never owns live HTTP clients. The runtime passes
the FastAPI lifespan-owned shared client into ``get_tools`` so the layer can
build Pydantic AI tool adapters on demand.
"""
from __future__ import annotations
from copy import deepcopy
import json
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import ClassVar
import httpx
from pydantic_ai import RunContext, Tool
from pydantic_ai.tools import ToolDefinition
from typing_extensions import Self, override
from agenton.layers import LayerDeps, PlainLayer
from dify_agent.layers.dify_plugin.configs import (
DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
DifyPluginToolConfig,
DifyPluginToolParameter,
DifyPluginToolParameterForm,
DifyPluginToolParameterType,
DifyPluginToolsLayerConfig,
)
from dify_agent.layers.dify_plugin.tool_client import (
DifyPluginDaemonToolClient,
DifyPluginToolClientError,
DifyPluginToolInvokeMessage,
)
from dify_agent.layers.execution_context.layer import DifyExecutionContextLayer
# Plugin tools intentionally do not expose a per-tool strictness override in the
# public config. The API supplies already-prepared schemas, but Dify Agent always
# registers those tools in loose mode so daemon tool invocation stays tolerant of
# plugin schema differences and older API-prepared payloads.
PLUGIN_TOOL_STRICT = False
class DifyPluginToolsDeps(LayerDeps):
"""Dependencies required by ``DifyPluginToolsLayer``."""
execution_context: DifyExecutionContextLayer # pyright: ignore[reportUninitializedInstanceVariable]
@dataclass(slots=True)
class DifyPluginToolsLayer(PlainLayer[DifyPluginToolsDeps, DifyPluginToolsLayerConfig]):
"""Layer that resolves Dify plugin tools into Pydantic AI tools."""
type_id: ClassVar[str] = DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID
config: DifyPluginToolsLayerConfig
@classmethod
@override
def from_config(cls, config: DifyPluginToolsLayerConfig) -> Self:
"""Create the tools layer from validated public config."""
return cls(config=DifyPluginToolsLayerConfig.model_validate(config))
async def get_tools(self, *, http_client: httpx.AsyncClient) -> list[Tool[object]]:
"""Build Pydantic AI tool adapters from prepared plugin tool config."""
tool_clients: dict[str, DifyPluginDaemonToolClient] = {}
tools: list[Tool[object]] = []
for tool_config in self.config.tools:
client = tool_clients.get(tool_config.plugin_id)
if client is None:
client = self.deps.execution_context.create_tool_client(
plugin_id=tool_config.plugin_id,
http_client=http_client,
)
tool_clients[tool_config.plugin_id] = client
effective_parameters = [parameter.model_copy(deep=True) for parameter in tool_config.parameters]
_validate_required_hidden_parameters(tool_config, effective_parameters)
tools.append(
_build_pydantic_ai_tool(
client=client,
tool_config=tool_config,
effective_parameters=effective_parameters,
)
)
return tools
def _validate_required_hidden_parameters(
tool_config: DifyPluginToolConfig,
effective_parameters: Sequence[DifyPluginToolParameter],
) -> None:
missing_names = [
parameter.name
for parameter in effective_parameters
if parameter.form is not DifyPluginToolParameterForm.LLM
and parameter.required
and parameter.default is None
and parameter.name not in tool_config.runtime_parameters
]
if missing_names:
names = ", ".join(sorted(missing_names))
raise ValueError(f"Tool '{tool_config.tool_name}' requires non-LLM runtime_parameters for: {names}.")
def _build_pydantic_ai_tool(
*,
client: DifyPluginDaemonToolClient,
tool_config: DifyPluginToolConfig,
effective_parameters: Sequence[DifyPluginToolParameter],
) -> Tool[object]:
tool_name = tool_config.name or tool_config.tool_name
tool_description = tool_config.description or tool_name
tool_schema = deepcopy(tool_config.parameters_json_schema)
async def invoke_tool(_ctx: RunContext[object], **tool_arguments: object) -> str:
try:
merged_arguments = _prepare_tool_arguments(effective_parameters, tool_config, tool_arguments)
messages = await client.invoke(
provider=tool_config.provider,
tool_name=tool_config.tool_name,
credential_type=tool_config.credential_type,
credentials=dict(tool_config.credentials),
tool_parameters=merged_arguments,
)
return _convert_tool_response_to_text(messages)
except DifyPluginToolClientError as exc:
return _tool_error_text(tool_name=tool_name, error=exc)
except ValueError as exc:
return f"tool parameters validation error: {exc}, please check your tool parameters"
async def prepare_tool_definition(_ctx: RunContext[object], tool_def: ToolDefinition) -> ToolDefinition:
return ToolDefinition(
name=tool_def.name,
description=tool_def.description,
parameters_json_schema=tool_schema,
strict=PLUGIN_TOOL_STRICT,
sequential=tool_def.sequential,
metadata=tool_def.metadata,
timeout=tool_def.timeout,
defer_loading=tool_def.defer_loading,
kind=tool_def.kind,
return_schema=tool_def.return_schema,
include_return_schema=tool_def.include_return_schema,
)
return Tool(
invoke_tool,
takes_ctx=True,
name=tool_name,
description=tool_description,
prepare=prepare_tool_definition,
)
def _prepare_tool_arguments(
effective_parameters: Sequence[DifyPluginToolParameter],
tool_config: DifyPluginToolConfig,
tool_arguments: Mapping[str, object],
) -> dict[str, object]:
"""Build the daemon invocation payload from prepared config + model args.
Argument precedence intentionally mirrors the old Dify tool runtime contract:
1. start from config-supplied ``runtime_parameters`` for hidden/manual inputs;
2. let model-supplied tool arguments override same-named entries;
3. if neither provided a value, fall back to the prepared parameter default;
4. if a required parameter still has no value, raise validation error.
Only parameters declared in ``effective_parameters`` are type-cast here;
extra merged keys are passed through unchanged for forward compatibility with
prepared config that may contain additional daemon inputs.
"""
merged_arguments: dict[str, object] = dict(tool_config.runtime_parameters)
merged_arguments.update(tool_arguments)
prepared_arguments: dict[str, object] = {}
for parameter in effective_parameters:
if parameter.name in merged_arguments:
value = merged_arguments[parameter.name]
elif parameter.default is not None:
value = parameter.default
elif parameter.required:
raise ValueError(f"tool parameter {parameter.name} not found in tool config")
else:
continue
prepared_arguments[parameter.name] = _cast_tool_parameter_value(parameter.type, value)
for key, value in merged_arguments.items():
prepared_arguments.setdefault(key, value)
return prepared_arguments
def _cast_tool_parameter_value(parameter_type: DifyPluginToolParameterType, value: object) -> object:
"""Cast prepared tool argument values into daemon-facing wire shapes.
The API side prepares declaration metadata, but the actual invocation payload
still needs to match Dify plugin-daemon expectations. This helper keeps the
runtime-side coercion rules for common scalar, collection, file, and selector
parameter types so model-supplied JSON values and config-supplied hidden
inputs are normalized before transport.
"""
match parameter_type:
case (
DifyPluginToolParameterType.STRING
| DifyPluginToolParameterType.SECRET_INPUT
| DifyPluginToolParameterType.SELECT
| DifyPluginToolParameterType.CHECKBOX
| DifyPluginToolParameterType.DYNAMIC_SELECT
):
return "" if value is None else value if isinstance(value, str) else str(value)
case DifyPluginToolParameterType.BOOLEAN:
if value is None:
return False
if isinstance(value, str):
lowered = value.lower()
if lowered in {"true", "yes", "y", "1"}:
return True
if lowered in {"false", "no", "n", "0"}:
return False
return value if isinstance(value, bool) else bool(value)
case DifyPluginToolParameterType.NUMBER:
if isinstance(value, int | float):
return value
if isinstance(value, str) and value:
return float(value) if "." in value else int(value)
return value
case DifyPluginToolParameterType.SYSTEM_FILES | DifyPluginToolParameterType.FILES:
return value if isinstance(value, list) else [value]
case DifyPluginToolParameterType.FILE:
if isinstance(value, list):
if len(value) != 1:
raise ValueError("This parameter only accepts one file but got multiple files while invoking.")
return value[0]
return value
case DifyPluginToolParameterType.MODEL_SELECTOR | DifyPluginToolParameterType.APP_SELECTOR:
if not isinstance(value, dict):
raise ValueError("The selector must be a dictionary.")
return value
case DifyPluginToolParameterType.ANY:
if value is not None and not isinstance(value, dict | list | str | int | float | bool):
raise ValueError("The var selector must be a string, dictionary, list or number.")
return value
case DifyPluginToolParameterType.ARRAY:
if isinstance(value, list):
return value
if isinstance(value, str):
try:
parsed_value = json.loads(value)
except json.JSONDecodeError:
return [value]
if isinstance(parsed_value, list):
return parsed_value
return [value]
case DifyPluginToolParameterType.OBJECT:
if isinstance(value, dict):
return value
if isinstance(value, str):
try:
parsed_value = json.loads(value)
except json.JSONDecodeError:
return {}
if isinstance(parsed_value, dict):
return parsed_value
return {}
raise AssertionError(f"Unsupported tool parameter type: {parameter_type}")
def _tool_error_text(*, tool_name: str, error: DifyPluginToolClientError) -> str:
"""Map expected daemon/tool failures into agent-visible observation text.
Only known plugin-daemon rejection categories should be softened into tool
observations. Unexpected local bugs are intentionally not handled here and
should propagate so tests and callers notice the regression.
"""
error_type = error.error_type or ""
if any(token in error_type for token in ("Credential", "Authorization", "Unauthorized")):
return "Please check your tool provider credentials"
if any(token in error_type for token in ("ToolNotFound", "ProviderNotFound")):
return f"there is not a tool named {tool_name}"
if error.status_code == 400 or any(token in error_type for token in ("BadRequest", "Validate", "Validation")):
return f"tool parameters validation error: {error}, please check your tool parameters"
return f"tool invoke error: {error}"
def _convert_tool_response_to_text(tool_response: Sequence[DifyPluginToolInvokeMessage]) -> str:
"""Convert daemon stream messages into the plain-text tool observation.
This preserves the user-facing semantics Dify's agent tool runtime relies on:
text is appended directly, links/images become user-check instructions, JSON
output is included unless explicitly suppressed, variable messages stay
internal, and everything else falls back to ``str(message)``. JSON fragments
are deduplicated against existing text so mixed text/JSON streams do not
repeat the same content unnecessarily.
"""
parts: list[str] = []
json_parts: list[str] = []
for response in tool_response:
if response.type is DifyPluginToolInvokeMessage.MessageType.TEXT:
text_message = response.message
if isinstance(text_message, DifyPluginToolInvokeMessage.TextMessage):
parts.append(text_message.text)
elif response.type is DifyPluginToolInvokeMessage.MessageType.LINK:
link_message = response.message
if isinstance(link_message, DifyPluginToolInvokeMessage.TextMessage):
parts.append(f"result link: {link_message.text}. please tell user to check it.")
elif response.type in {
DifyPluginToolInvokeMessage.MessageType.IMAGE_LINK,
DifyPluginToolInvokeMessage.MessageType.IMAGE,
}:
parts.append(
"image has been created and sent to user already, "
"you do not need to create it, just tell the user to check it now."
)
elif response.type is DifyPluginToolInvokeMessage.MessageType.JSON:
json_message = response.message
if isinstance(json_message, DifyPluginToolInvokeMessage.JsonMessage) and not json_message.suppress_output:
json_parts.append(json.dumps(json_message.json_object, ensure_ascii=False, default=str))
elif response.type is DifyPluginToolInvokeMessage.MessageType.VARIABLE:
continue
else:
parts.append(str(response.message))
if json_parts:
existing_parts = set(parts)
parts.extend(part for part in json_parts if part not in existing_parts)
return "".join(parts)
__all__ = ["DifyPluginToolsDeps", "DifyPluginToolsLayer"]

View File

@ -1,18 +0,0 @@
"""Client-safe exports for the Dify execution-context layer DTOs.
Implementation layers live in sibling modules and require server-side runtime
dependencies. Keep this package root import-safe for client code that only
needs to build run requests.
"""
from dify_agent.layers.execution_context.configs import (
DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
DifyExecutionContextInvokeFrom,
DifyExecutionContextLayerConfig,
)
__all__ = [
"DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID",
"DifyExecutionContextInvokeFrom",
"DifyExecutionContextLayerConfig",
]

View File

@ -1,50 +0,0 @@
"""Client-safe DTOs for the Dify execution-context Agenton layer.
This layer carries Dify-owned execution identifiers plus the tenant/user daemon
transport context shared by plugin-backed business layers. The identifiers are
for observability and product correlation only; callers must not treat them as
authorization proof. Server-only plugin-daemon settings are injected by the
runtime provider factory and therefore do not appear in this public schema.
"""
from typing import ClassVar, Final, Literal, TypeAlias
from pydantic import ConfigDict
from agenton.layers import LayerConfig
DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID: Final[str] = "dify.execution_context"
DifyExecutionContextInvokeFrom: TypeAlias = Literal[
"workflow_run",
"single_step",
"agent_app",
"babysit",
"fasten",
]
class DifyExecutionContextLayerConfig(LayerConfig):
"""Public config for Dify execution identity and daemon transport context."""
tenant_id: str
user_id: str | None = None
app_id: str | None = None
workflow_id: str | None = None
workflow_run_id: str | None = None
node_id: str | None = None
node_execution_id: str | None = None
conversation_id: str | None = None
agent_id: str | None = None
agent_config_version_id: str | None = None
invoke_from: DifyExecutionContextInvokeFrom
trace_id: str | None = None
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
__all__ = [
"DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID",
"DifyExecutionContextInvokeFrom",
"DifyExecutionContextLayerConfig",
]

View File

@ -1,95 +0,0 @@
"""Runtime Dify execution-context layer.
The public config carries Dify-owned execution identifiers plus the tenant/user
daemon context needed by plugin-backed business layers. Server-only daemon URL
and API key are injected by the provider factory. The layer is intentionally
config/settings-only under Agenton's state-only core: it does not open, cache,
close, or snapshot HTTP clients, and its lifecycle hooks remain the inherited
no-op hooks. Runtime code passes the FastAPI lifespan-owned shared
``httpx.AsyncClient`` into ``create_daemon_provider`` or ``create_tool_client``
for each invocation.
"""
from dataclasses import dataclass
from typing import ClassVar
import httpx
from typing_extensions import Self, override
from agenton.layers import EmptyRuntimeState, NoLayerDeps, PlainLayer
from dify_agent.adapters.llm import DifyPluginDaemonProvider
from dify_agent.layers.dify_plugin.tool_client import DifyPluginDaemonToolClient
from dify_agent.layers.execution_context.configs import (
DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
DifyExecutionContextLayerConfig,
)
@dataclass(slots=True)
class DifyExecutionContextLayer(PlainLayer[NoLayerDeps, DifyExecutionContextLayerConfig, EmptyRuntimeState]):
"""Layer that carries Dify execution context without owning live resources."""
type_id: ClassVar[str] = DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID
config: DifyExecutionContextLayerConfig
daemon_url: str
daemon_api_key: str
@classmethod
@override
def from_config(cls, config: DifyExecutionContextLayerConfig) -> Self:
"""Reject construction without server-injected daemon settings."""
del config
raise TypeError(
"DifyExecutionContextLayer requires server-side daemon settings and must use a provider factory."
)
@classmethod
def from_config_with_settings(
cls,
config: DifyExecutionContextLayerConfig,
*,
daemon_url: str,
daemon_api_key: str,
) -> Self:
"""Create the layer from public config plus server-only daemon settings."""
return cls(config=config, daemon_url=daemon_url, daemon_api_key=daemon_api_key)
def create_daemon_provider(self, *, plugin_id: str, http_client: httpx.AsyncClient) -> DifyPluginDaemonProvider:
"""Return a daemon provider backed by the shared plugin daemon client.
Raises:
RuntimeError: if ``http_client`` has already been closed.
"""
if http_client.is_closed:
raise RuntimeError(
"DifyExecutionContextLayer.create_daemon_provider() requires an open shared HTTP client."
)
return DifyPluginDaemonProvider(
tenant_id=self.config.tenant_id,
plugin_id=plugin_id,
plugin_daemon_url=self.daemon_url,
plugin_daemon_api_key=self.daemon_api_key,
user_id=self.config.user_id,
http_client=http_client,
)
def create_tool_client(self, *, plugin_id: str, http_client: httpx.AsyncClient) -> DifyPluginDaemonToolClient:
"""Return a plugin-daemon tool client backed by the shared HTTP client.
Raises:
RuntimeError: if ``http_client`` has already been closed.
"""
if http_client.is_closed:
raise RuntimeError("DifyExecutionContextLayer.create_tool_client() requires an open shared HTTP client.")
return DifyPluginDaemonToolClient(
tenant_id=self.config.tenant_id,
plugin_id=plugin_id,
plugin_daemon_url=self.daemon_url,
plugin_daemon_api_key=self.daemon_api_key,
user_id=self.config.user_id,
http_client=http_client,
)
__all__ = ["DifyExecutionContextLayer"]

View File

@ -1,72 +0,0 @@
"""Shared plugin-daemon transport helpers.
These helpers define the common request-payload and nested-error semantics used
by Dify Agent's LLM and tools daemon clients so the two transport adapters do
not drift when the daemon protocol evolves.
"""
from __future__ import annotations
import json
from typing import TypedDict
from pydantic import BaseModel
class PluginDaemonErrorPayload(TypedDict):
"""Decoded plugin-daemon error payload."""
error_type: str
message: str
def to_plugin_daemon_jsonable(value: object) -> object:
"""Convert nested request data into JSON-safe daemon payload values."""
if isinstance(value, BaseModel):
return value.model_dump(mode="json")
if isinstance(value, dict):
return {key: to_plugin_daemon_jsonable(item) for key, item in value.items()}
if isinstance(value, list | tuple):
return [to_plugin_daemon_jsonable(item) for item in value]
return value
def decode_plugin_daemon_error_payload(raw_message: str) -> PluginDaemonErrorPayload | None:
"""Decode one plugin-daemon JSON error payload if present."""
try:
parsed = json.loads(raw_message)
except json.JSONDecodeError:
return None
if not isinstance(parsed, dict):
return None
error_type = parsed.get("error_type")
message = parsed.get("message")
if not isinstance(error_type, str) or not isinstance(message, str):
return None
return {"error_type": error_type, "message": message}
def unwrap_plugin_daemon_error(
*,
error_type: str,
message: str,
) -> PluginDaemonErrorPayload:
"""Unwrap nested ``PluginInvokeError`` payloads to their effective error."""
if error_type == "PluginInvokeError":
nested_error = decode_plugin_daemon_error_payload(message)
if nested_error is not None:
return unwrap_plugin_daemon_error(
error_type=nested_error["error_type"],
message=nested_error["message"],
)
return {"error_type": error_type, "message": message}
__all__ = [
"PluginDaemonErrorPayload",
"decode_plugin_daemon_error_payload",
"to_plugin_daemon_jsonable",
"unwrap_plugin_daemon_error",
]

View File

@ -11,6 +11,8 @@ from .schemas import (
CreateRunRequest,
CreateRunResponse,
EmptyRunEventData,
ExecutionContext,
InvokeFrom,
LayerExitSignals,
PydanticAIStreamRunEvent,
RunCancelledEvent,
@ -44,6 +46,8 @@ __all__ = [
"DIFY_AGENT_MODEL_LAYER_ID",
"DIFY_AGENT_OUTPUT_LAYER_ID",
"EmptyRunEventData",
"ExecutionContext",
"InvokeFrom",
"LayerExitSignals",
"PydanticAIStreamRunEvent",
"RUN_EVENT_ADAPTER",

View File

@ -47,6 +47,7 @@ DIFY_AGENT_HISTORY_LAYER_ID: Final[str] = "history"
DIFY_AGENT_OUTPUT_LAYER_ID: Final[str] = "output"
RunStatus = Literal["running", "paused", "succeeded", "failed", "cancelled"]
RunPurpose = Literal["workflow_node", "single_step", "agent_app", "babysit", "fasten_preview"]
InvokeFrom = Literal["workflow_run", "single_step", "agent_app", "babysit", "fasten"]
RunEventType = Literal[
"run_started",
"pydantic_ai_event",
@ -105,6 +106,29 @@ class RunComposition(BaseModel):
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
class ExecutionContext(BaseModel):
"""Dify-owned execution identifiers attached to one Agent backend run.
The Agent backend stores and replays this context for observability and
product correlation only. It must not use these identifiers as authorization
proof; API backend remains responsible for tenant and user access checks.
"""
tenant_id: str
app_id: str | None = None
workflow_id: str | None = None
workflow_run_id: str | None = None
node_id: str | None = None
node_execution_id: str | None = None
conversation_id: str | None = None
agent_id: str | None = None
agent_config_version_id: str | None = None
invoke_from: InvokeFrom
trace_id: str | None = None
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
class CreateRunRequest(BaseModel):
"""Request body for creating one async agent run.
@ -118,13 +142,11 @@ class CreateRunRequest(BaseModel):
explicitly request delete for one or more layers. Session snapshots do not
preserve output-layer config, so resume requests that rely on structured
output must include the same ``output`` layer in ``composition.layers[]`` to
keep snapshot compatibility and rebuild the output schema. Dify tenant,
user, and run-correlation identifiers must be submitted through a
``dify.execution_context`` entry in ``composition.layers[]``; there is no
parallel top-level ``execution_context`` request field.
keep snapshot compatibility and rebuild the output schema.
"""
composition: RunComposition
execution_context: ExecutionContext | None = None
purpose: RunPurpose = "workflow_node"
idempotency_key: str | None = None
metadata: dict[str, JsonValue] = Field(default_factory=dict)
@ -334,6 +356,8 @@ __all__ = [
"DIFY_AGENT_MODEL_LAYER_ID",
"DIFY_AGENT_OUTPUT_LAYER_ID",
"EmptyRunEventData",
"ExecutionContext",
"InvokeFrom",
"LayerExitSignals",
"PydanticAIStreamRunEvent",
"RUN_EVENT_ADAPTER",

View File

@ -2,18 +2,12 @@
Only explicitly allowed provider type ids are constructible here. The default
provider set contains prompt layers, the optional pydantic-ai history layer, the
state-free Dify structured output layer, the Dify execution-context layer, and
the Dify plugin business-layer family:
- ``dify.execution_context`` for shared tenant/user/run daemon context,
- ``dify.plugin.llm`` for plugin-backed model selection, and
- ``dify.plugin.tools`` for prepared plugin tool exposure.
Public DTOs provide Dify context plus plugin/model/tool data, while server-only
plugin daemon settings are injected through the provider factory for
``DifyExecutionContextLayer``. The resulting ``Compositor`` remains Agenton
state-only: live resources such as the plugin daemon HTTP client are supplied
later by the runtime and never enter providers, layers, or session snapshots.
state-free Dify structured output layer, plus Dify plugin LLM layers. Public
DTOs provide tenant/plugin/model data, while server-only plugin daemon settings
are injected through the provider factory for ``DifyPluginLayer``. The resulting
``Compositor`` remains Agenton state-only: live resources such as the plugin
daemon HTTP client are supplied later by the runtime and never enter providers,
layers, or session snapshots.
"""
from collections.abc import Mapping, Sequence
@ -26,10 +20,9 @@ from agenton.layers.types import AllPromptTypes, AllToolTypes, AllUserPromptType
from agenton_collections.layers.pydantic_ai import PydanticAIHistoryLayer
from agenton_collections.layers.plain.basic import PromptLayer
from agenton_collections.transformers.pydantic_ai import PYDANTIC_AI_TRANSFORMERS
from dify_agent.layers.dify_plugin.configs import DifyPluginLayerConfig
from dify_agent.layers.dify_plugin.llm_layer import DifyPluginLLMLayer
from dify_agent.layers.dify_plugin.tools_layer import DifyPluginToolsLayer
from dify_agent.layers.execution_context.configs import DifyExecutionContextLayerConfig
from dify_agent.layers.execution_context.layer import DifyExecutionContextLayer
from dify_agent.layers.dify_plugin.plugin_layer import DifyPluginLayer
from dify_agent.layers.output.output_layer import DifyOutputLayer
@ -47,15 +40,14 @@ def create_default_layer_providers(
LayerProvider.from_layer_type(PydanticAIHistoryLayer),
LayerProvider.from_layer_type(DifyOutputLayer),
LayerProvider.from_factory(
layer_type=DifyExecutionContextLayer,
create=lambda config: DifyExecutionContextLayer.from_config_with_settings(
DifyExecutionContextLayerConfig.model_validate(config),
layer_type=DifyPluginLayer,
create=lambda config: DifyPluginLayer.from_config_with_settings(
DifyPluginLayerConfig.model_validate(config),
daemon_url=plugin_daemon_url,
daemon_api_key=plugin_daemon_api_key,
),
),
LayerProvider.from_layer_type(DifyPluginLLMLayer),
LayerProvider.from_layer_type(DifyPluginToolsLayer),
)

View File

@ -5,11 +5,12 @@ The scheduler is intentionally process-local: it persists a run record, starts a
task registry. Redis remains the durable source for status and event streams, but
there is no Redis job queue or cross-process handoff. If the process crashes,
currently active runs are lost until an external operator marks or retries them.
Create-run requests are accepted once the scheduler is not stopping and storage
can persist the run record. Request-shaped execution failures are left to
``AgentRunRunner`` so bad compositions, ``on_exit`` policies, prompts,
structured-output schemas, or session snapshots become asynchronous
``run_failed`` outcomes instead of synchronous HTTP rejections.
Create-run validation enters a lightweight Agenton run before persistence so the
same transformed user prompts, temporary system-prompt history assembly,
optional structured output contract, and top-level ``on_exit`` policy used by
execution are checked without relying on removed session/control APIs; Dify's
default layers keep lifecycle hooks side-effect free so this validation does not
open plugin daemon clients.
"""
import asyncio
@ -20,10 +21,15 @@ from typing import Protocol
import httpx
from agenton.compositor import LayerProviderInput
from dify_agent.protocol.schemas import CreateRunRequest
from dify_agent.runtime.compositor_factory import create_default_layer_providers
from dify_agent.protocol.schemas import CreateRunRequest, normalize_composition
from dify_agent.runtime.agenton_validation import is_agenton_enter_validation_runtime_error
from dify_agent.runtime.compositor_factory import build_pydantic_ai_compositor, create_default_layer_providers
from dify_agent.runtime.event_sink import RunEventSink, emit_run_failed
from dify_agent.runtime.history import build_run_message_history, get_history_layer, validate_history_layer_composition
from dify_agent.runtime.layer_exit_signals import apply_layer_exit_signals, validate_layer_exit_signals
from dify_agent.runtime.output_type import resolve_run_output_contract, validate_output_layer_composition
from dify_agent.runtime.runner import AgentRunRunner
from dify_agent.runtime.user_prompt_validation import EMPTY_USER_PROMPTS_ERROR, has_non_blank_user_prompt
from dify_agent.server.schemas import RunRecord
logger = logging.getLogger(__name__)
@ -33,6 +39,10 @@ class SchedulerStoppingError(RuntimeError):
"""Raised when a create-run request arrives after shutdown has started."""
class RunRequestValidationError(ValueError):
"""Raised when a create-run request cannot produce an executable Agenton run."""
class RunStore(RunEventSink, Protocol):
"""Persistence boundary needed by the scheduler."""
@ -58,8 +68,9 @@ class RunScheduler:
``active_tasks`` is mutated only on the event loop that calls ``create_run``
and ``shutdown``. The task registry is not durable; it exists so the lifespan
hook can wait for in-flight work and mark cancelled runs failed before Redis is
closed. A lock guards the stopping flag, run persistence, and task
registration so shutdown cannot begin after a request is admitted.
closed. A lock guards the stopping flag, lightweight request validation, run
persistence, and task registration so shutdown cannot begin after a request is
admitted and no validation runs once stopping has been set.
"""
store: RunStore
@ -90,16 +101,15 @@ class RunScheduler:
self._lifecycle_lock = asyncio.Lock()
async def create_run(self, request: CreateRunRequest) -> RunRecord:
"""Persist and schedule one run in the current process.
"""Validate, persist, and schedule one run in the current process.
The returned record is already ``running``. The background task is removed
from ``active_tasks`` when it finishes, regardless of success or failure.
Request-shaped runtime failures are intentionally deferred to the runner so
callers can observe them through the normal event/status stream.
"""
async with self._lifecycle_lock:
if self.stopping:
raise SchedulerStoppingError("run scheduler is shutting down")
await validate_run_request(request, layer_providers=self.layer_providers)
record = await self.store.create_run()
task = asyncio.create_task(self._run_record(record, request), name=f"dify-agent-run-{record.run_id}")
self.active_tasks[record.run_id] = task
@ -154,4 +164,52 @@ class RunScheduler:
logger.exception("failed to mark cancelled run failed", extra={"run_id": run_id})
__all__ = ["RunScheduler", "SchedulerStoppingError"]
async def validate_run_request(
request: CreateRunRequest,
*,
layer_providers: tuple[LayerProviderInput, ...] | None = None,
) -> None:
"""Validate create-run semantics that require an entered Agenton run.
This boundary rejects unsupported output/history-layer graph shapes, unknown
``on_exit`` layer ids, effectively empty transformed user prompts, and known
enter-time snapshot lifecycle errors before the scheduler persists a run
record. It also exercises provider config validation, temporary
system-prompt history assembly, structured output contract construction, and
snapshot hydration without touching external services because Dify plugin
daemon clients are owned by the FastAPI lifespan, not Agenton lifecycle
hooks.
"""
resolved_layer_providers = layer_providers if layer_providers is not None else create_default_layer_providers()
entered_run = False
try:
validate_output_layer_composition(request.composition)
validate_history_layer_composition(request.composition)
graph_config, layer_configs = normalize_composition(request.composition)
compositor = build_pydantic_ai_compositor(
graph_config,
providers=resolved_layer_providers,
)
validate_layer_exit_signals(compositor, request.on_exit)
async with compositor.enter(configs=layer_configs, session_snapshot=request.session_snapshot) as run:
entered_run = True
apply_layer_exit_signals(run, request.on_exit)
history_layer = get_history_layer(run)
_ = await build_run_message_history(
system_prompts=run.prompts,
stored_history=history_layer.message_history if history_layer is not None else (),
)
if not has_non_blank_user_prompt(run.user_prompts):
raise RunRequestValidationError(EMPTY_USER_PROMPTS_ERROR)
_ = resolve_run_output_contract(run)
except RunRequestValidationError:
raise
except RuntimeError as exc:
if not entered_run and is_agenton_enter_validation_runtime_error(exc):
raise RunRequestValidationError(str(exc)) from exc
raise
except (KeyError, TypeError, ValueError) as exc:
raise RunRequestValidationError(str(exc)) from exc
__all__ = ["RunRequestValidationError", "RunScheduler", "SchedulerStoppingError", "validate_run_request"]

View File

@ -21,17 +21,14 @@ snapshot; there are no separate output or snapshot events to correlate.
"""
from collections.abc import AsyncIterable
from collections import Counter
from typing import Any, cast
from typing import cast
import httpx
from pydantic import JsonValue, TypeAdapter
from pydantic_ai.messages import AgentStreamEvent
from agenton.compositor import CompositorSessionSnapshot, LayerProviderInput
from agenton.layers.types import PydanticAITool
from dify_agent.layers.dify_plugin.llm_layer import DifyPluginLLMLayer
from dify_agent.layers.dify_plugin.tools_layer import DifyPluginToolsLayer
from dify_agent.protocol.schemas import DIFY_AGENT_MODEL_LAYER_ID, CreateRunRequest, normalize_composition
from dify_agent.runtime.agent_factory import create_agent, normalize_user_input
from dify_agent.runtime.agenton_validation import is_agenton_enter_validation_runtime_error
@ -152,13 +149,12 @@ class AgentRunRunner:
)
llm_layer = run.get_layer(DIFY_AGENT_MODEL_LAYER_ID, DifyPluginLLMLayer)
model = llm_layer.get_model(http_client=self.plugin_daemon_http_client)
tools = await _resolve_run_tools(run, http_client=self.plugin_daemon_http_client)
except (KeyError, TypeError, RuntimeError, ValueError) as exc:
raise AgentRunValidationError(str(exc)) from exc
agent = create_agent(
model,
tools=tools,
tools=run.tools,
output_type=output_contract.output_type,
)
result = await agent.run(
@ -184,27 +180,4 @@ def _serialize_agent_output(output: object) -> JsonValue:
return cast(JsonValue, _AGENT_OUTPUT_ADAPTER.dump_python(output, mode="json"))
async def _resolve_run_tools(
run: Any,
*,
http_client: httpx.AsyncClient,
) -> list[PydanticAITool[object]]:
"""Return the static compositor tools plus any Dify plugin runtime tools."""
resolved_tools = list(cast(list[PydanticAITool[object]], run.tools))
for slot in run.slots.values():
layer = slot.layer
if isinstance(layer, DifyPluginToolsLayer):
resolved_tools.extend(await layer.get_tools(http_client=http_client))
_validate_unique_tool_names(resolved_tools)
return resolved_tools
def _validate_unique_tool_names(tools: list[PydanticAITool[object]]) -> None:
"""Reject duplicate tool names across static and dynamic tool sources."""
duplicate_names = sorted(name for name, count in Counter(tool.name for tool in tools).items() if count > 1)
if duplicate_names:
names = ", ".join(duplicate_names)
raise ValueError(f"Agent run requires unique tool names across all layers, got duplicates: {names}.")
__all__ = ["AgentRunRunner", "AgentRunValidationError"]

View File

@ -1,8 +1,8 @@
"""Validation for effective user prompts produced by Agenton runs.
Validation happens after safe compositor construction and run entry so runtime
execution uses the same transformed prompts as the actual pydantic-ai input.
Blank string fragments do not count as meaningful input; non-string
Validation happens after safe compositor construction and run entry so scheduler
and runner paths use the same transformed prompts as the actual pydantic-ai
input. Blank string fragments do not count as meaningful input; non-string
``UserContent`` is treated as intentional content because rich media/message
parts do not have a universal whitespace representation.
"""

View File

@ -1,13 +1,10 @@
"""FastAPI routes for asynchronous agent runs.
Controllers translate shutdown errors into HTTP status codes. Runtime request
failures are intentionally not pre-mapped here: once a request passes DTO
validation it is accepted for background execution, and bad compositions or
snapshots fail later through normal run events/status. Unexpected scheduler or
storage failures are intentionally left for FastAPI's server-error handling so
infrastructure problems are not reported as client input errors. Created runs
are scheduled in the current process and observed through status polling or SSE
replay backed by Redis event streams.
Controllers translate known validation and shutdown errors into HTTP status codes.
Unexpected scheduler or storage failures are intentionally left for FastAPI's
server-error handling so infrastructure problems are not reported as client input
errors. Created runs are scheduled in the current process and observed through
status polling or SSE replay backed by Redis event streams.
"""
from collections.abc import Callable
@ -24,7 +21,7 @@ from dify_agent.protocol.schemas import (
RunEventsResponse,
RunStatusResponse,
)
from dify_agent.runtime.run_scheduler import RunScheduler, SchedulerStoppingError
from dify_agent.runtime.run_scheduler import RunRequestValidationError, RunScheduler, SchedulerStoppingError
from dify_agent.server.sse import sse_event_stream
from dify_agent.storage.redis_run_store import RedisRunStore, RunNotFoundError
@ -49,6 +46,8 @@ def create_runs_router(
) -> CreateRunResponse:
try:
record = await scheduler.create_run(request)
except RunRequestValidationError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
except SchedulerStoppingError as exc:
raise HTTPException(status_code=503, detail="run scheduler is shutting down") from exc
return CreateRunResponse(run_id=record.run_id, status=record.status)

View File

@ -1,5 +1,3 @@
from types import SimpleNamespace
import pytest
from pydantic import ValidationError
@ -7,54 +5,55 @@ import dify_agent.layers.dify_plugin as dify_plugin_exports
from dify_agent.layers.dify_plugin import (
DifyPluginCredentialValue,
DifyPluginLLMLayerConfig,
DifyPluginToolCredentialType,
DifyPluginToolConfig,
DifyPluginToolParameter,
DifyPluginToolParameterForm,
DifyPluginToolParameterType,
DifyPluginToolsLayerConfig,
DifyPluginToolValue,
DifyPluginLayerConfig,
)
def test_dify_plugin_package_exports_client_safe_config_symbols_only() -> None:
assert dify_plugin_exports.__all__ == [
"DIFY_PLUGIN_LAYER_TYPE_ID",
"DIFY_PLUGIN_LLM_LAYER_TYPE_ID",
"DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID",
"DifyPluginCredentialValue",
"DifyPluginLLMLayerConfig",
"DifyPluginToolCredentialType",
"DifyPluginToolConfig",
"DifyPluginToolOption",
"DifyPluginToolParameter",
"DifyPluginToolParameterForm",
"DifyPluginToolParameterType",
"DifyPluginToolsLayerConfig",
"DifyPluginToolValue",
"DifyPluginLayerConfig",
]
assert dify_plugin_exports.DIFY_PLUGIN_LAYER_TYPE_ID == "dify.plugin"
assert dify_plugin_exports.DIFY_PLUGIN_LLM_LAYER_TYPE_ID == "dify.plugin.llm"
assert dify_plugin_exports.DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID == "dify.plugin.tools"
assert not hasattr(dify_plugin_exports, "DifyPluginLayer")
assert not hasattr(dify_plugin_exports, "DifyPluginLLMLayer")
def test_dify_plugin_layer_config_forbids_runtime_settings() -> None:
config = DifyPluginLayerConfig(tenant_id="tenant-1", plugin_id="plugin-1", user_id="user-1")
assert config.tenant_id == "tenant-1"
assert config.plugin_id == "plugin-1"
assert config.user_id == "user-1"
with pytest.raises(ValidationError):
_ = DifyPluginLayerConfig.model_validate(
{
"tenant_id": "tenant-1",
"plugin_id": "plugin-1",
"daemon_url": "http://daemon",
}
)
def test_dify_plugin_llm_config_accepts_scalar_credentials_and_model_settings() -> None:
credential: DifyPluginCredentialValue = "secret"
config = DifyPluginLLMLayerConfig(
plugin_id="langgenius/openai",
model_provider="openai",
model="gpt-4o-mini",
credentials={"api_key": credential, "enabled": True, "retries": 2, "ratio": 0.5, "empty": None},
model_settings={"temperature": 0.2, "max_tokens": 64},
)
assert config.plugin_id == "langgenius/openai"
assert config.model_provider == "openai"
assert config.credentials == {"api_key": "secret", "enabled": True, "retries": 2, "ratio": 0.5, "empty": None}
assert config.model_settings == {"temperature": 0.2, "max_tokens": 64}
with pytest.raises(ValidationError):
_ = DifyPluginLLMLayerConfig.model_validate(
{
"plugin_id": "langgenius/openai",
"model_provider": "openai",
"model": "gpt-4o-mini",
"credentials": {"nested": {"not": "allowed"}},
@ -67,154 +66,6 @@ def test_dify_plugin_llm_config_rejects_old_provider_field() -> None:
_ = DifyPluginLLMLayerConfig.model_validate(
{
"provider": "openai",
"plugin_id": "langgenius/openai",
"model": "gpt-4o-mini",
}
)
def test_dify_plugin_tools_layer_config_accepts_prepared_parameters_and_schema() -> None:
runtime_value: DifyPluginToolValue = {"locale": "en-US", "max_results": 5}
credential_type: DifyPluginToolCredentialType = "api-key"
config = DifyPluginToolsLayerConfig(
tools=[
DifyPluginToolConfig(
plugin_id="langgenius/tools",
provider="search",
tool_name="web_search",
credential_type=credential_type,
name="search_web",
description="Search the web.",
credentials={"api_key": "secret"},
runtime_parameters={"settings": runtime_value},
parameters=[
DifyPluginToolParameter(
name="query",
type=DifyPluginToolParameterType.STRING,
form=DifyPluginToolParameterForm.LLM,
required=True,
llm_description="Search query",
)
],
parameters_json_schema={
"type": "object",
"properties": {
"query": {"type": "string", "description": "Search query"},
},
"required": ["query"],
},
)
]
)
assert config.tools[0].plugin_id == "langgenius/tools"
assert config.tools[0].provider == "search"
assert config.tools[0].tool_name == "web_search"
assert config.tools[0].credential_type == "api-key"
assert config.tools[0].name == "search_web"
assert config.tools[0].runtime_parameters == {"settings": {"locale": "en-US", "max_results": 5}}
assert config.tools[0].parameters[0].name == "query"
assert config.tools[0].parameters_json_schema["required"] == ["query"]
def test_dify_plugin_tool_parameter_accepts_api_tool_parameter_dump_shape() -> None:
parameter = DifyPluginToolParameter.model_validate(
{
"name": "query",
"label": {"en_US": "Query"},
"placeholder": None,
"human_description": {"en_US": "Visible in UI"},
"type": "select",
"form": "llm",
"required": True,
"default": "dify",
"llm_description": "Search query",
"input_schema": {"type": "string"},
"options": [
{
"value": "dify",
"label": {"en_US": "Dify"},
}
],
}
)
assert parameter.name == "query"
assert parameter.type is DifyPluginToolParameterType.SELECT
assert parameter.form is DifyPluginToolParameterForm.LLM
assert parameter.required is True
assert parameter.default == "dify"
assert parameter.input_schema == {"type": "string"}
assert [option.value for option in parameter.options] == ["dify"]
def test_dify_plugin_tool_parameter_accepts_api_tool_parameter_attributes() -> None:
parameter = DifyPluginToolParameter.model_validate(
SimpleNamespace(
name="language",
label=SimpleNamespace(en_US="Language"),
type="string",
form="form",
required=False,
default="en",
llm_description=None,
input_schema=None,
options=[SimpleNamespace(value="en", label=SimpleNamespace(en_US="English"))],
)
)
assert parameter.name == "language"
assert parameter.type is DifyPluginToolParameterType.STRING
assert parameter.form is DifyPluginToolParameterForm.FORM
assert parameter.default == "en"
assert [option.value for option in parameter.options] == ["en"]
def test_dify_plugin_tool_config_rejects_non_json_runtime_parameters() -> None:
with pytest.raises(ValidationError):
_ = DifyPluginToolConfig.model_validate(
{
"plugin_id": "langgenius/tools",
"provider": "search",
"tool_name": "web_search",
"credential_type": "api-key",
"runtime_parameters": {"bad": object()},
}
)
def test_dify_plugin_tool_config_rejects_non_json_schema_values() -> None:
with pytest.raises(ValidationError):
_ = DifyPluginToolConfig.model_validate(
{
"plugin_id": "langgenius/tools",
"provider": "search",
"tool_name": "web_search",
"credential_type": "api-key",
"parameters_json_schema": {"type": object()},
}
)
def test_dify_plugin_tool_config_rejects_strict_flag() -> None:
with pytest.raises(ValidationError):
_ = DifyPluginToolConfig.model_validate(
{
"plugin_id": "langgenius/tools",
"provider": "search",
"tool_name": "web_search",
"credential_type": "api-key",
"strict": True,
}
)
def test_dify_plugin_tool_config_requires_explicit_credential_type() -> None:
with pytest.raises(ValidationError):
_ = DifyPluginToolConfig.model_validate(
{
"plugin_id": "langgenius/tools",
"provider": "search",
"tool_name": "web_search",
}
)

View File

@ -1,36 +1,26 @@
import asyncio
import json
import httpx
import pytest
from pydantic import JsonValue
from agenton.compositor import Compositor, LayerNode, LayerProvider
from dify_agent.adapters.llm import DifyLLMAdapterModel
from dify_agent.layers.dify_plugin.configs import (
DIFY_PLUGIN_LAYER_TYPE_ID,
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
DifyPluginLLMLayerConfig,
DifyPluginToolConfig,
DifyPluginToolOption,
DifyPluginToolParameter,
DifyPluginToolParameterForm,
DifyPluginToolParameterType,
DifyPluginToolsLayerConfig,
DifyPluginLayerConfig,
)
from dify_agent.layers.dify_plugin.llm_layer import DifyPluginLLMLayer
from dify_agent.layers.dify_plugin.tools_layer import DifyPluginToolsLayer
from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig
from dify_agent.layers.execution_context.layer import DifyExecutionContextLayer
from dify_agent.layers.dify_plugin.plugin_layer import DifyPluginLayer
def _execution_context_config() -> DifyExecutionContextLayerConfig:
return DifyExecutionContextLayerConfig(tenant_id="tenant-1", user_id="user-1", invoke_from="workflow_run")
def _plugin_config() -> DifyPluginLayerConfig:
return DifyPluginLayerConfig(tenant_id="tenant-1", plugin_id="langgenius/openai", user_id="user-1")
def _llm_config() -> DifyPluginLLMLayerConfig:
return DifyPluginLLMLayerConfig(
plugin_id="langgenius/openai",
model_provider="openai",
model="demo-model",
credentials={"api_key": "secret"},
@ -38,192 +28,82 @@ def _llm_config() -> DifyPluginLLMLayerConfig:
)
def _tools_config() -> DifyPluginToolsLayerConfig:
return DifyPluginToolsLayerConfig(
tools=[
DifyPluginToolConfig(
plugin_id="langgenius/tools",
provider="search",
tool_name="web_search",
credential_type="api-key",
description="Search the web.",
credentials={"api_key": "secret"},
runtime_parameters={"api_version": "2026-01", "auth_scope": "workspace"},
parameters=_prepared_tool_parameters(),
parameters_json_schema=_prepared_tool_schema(),
)
]
def _plugin_layer() -> DifyPluginLayer:
return DifyPluginLayer.from_config_with_settings(
_plugin_config(),
daemon_url="http://plugin-daemon",
daemon_api_key="daemon-secret",
)
def _missing_hidden_parameter_tools_config() -> DifyPluginToolsLayerConfig:
return DifyPluginToolsLayerConfig(
tools=[
DifyPluginToolConfig(
plugin_id="langgenius/tools",
provider="search",
tool_name="web_search",
credential_type="api-key",
description="Search the web.",
credentials={"api_key": "secret"},
runtime_parameters={"api_version": "2026-01"},
parameters=_prepared_tool_parameters(),
parameters_json_schema=_prepared_tool_schema(),
)
]
)
def _execution_context_provider() -> LayerProvider[DifyExecutionContextLayer]:
def _plugin_provider() -> LayerProvider[DifyPluginLayer]:
return LayerProvider.from_factory(
layer_type=DifyExecutionContextLayer,
create=lambda config: DifyExecutionContextLayer.from_config_with_settings(
DifyExecutionContextLayerConfig.model_validate(config),
layer_type=DifyPluginLayer,
create=lambda config: DifyPluginLayer.from_config_with_settings(
DifyPluginLayerConfig.model_validate(config),
daemon_url="http://plugin-daemon",
daemon_api_key="daemon-secret",
),
)
def _prepared_tool_parameters() -> list[DifyPluginToolParameter]:
return [
DifyPluginToolParameter(
name="query",
type=DifyPluginToolParameterType.STRING,
form=DifyPluginToolParameterForm.LLM,
required=True,
llm_description="Search query",
),
DifyPluginToolParameter(
name="region",
type=DifyPluginToolParameterType.SELECT,
form=DifyPluginToolParameterForm.LLM,
required=False,
llm_description="Search region",
options=[DifyPluginToolOption(value="global"), DifyPluginToolOption(value="cn")],
),
DifyPluginToolParameter(
name="api_version",
type=DifyPluginToolParameterType.STRING,
form=DifyPluginToolParameterForm.FORM,
required=True,
llm_description="Hidden API version",
),
DifyPluginToolParameter(
name="auth_scope",
type=DifyPluginToolParameterType.STRING,
form=DifyPluginToolParameterForm.FORM,
required=True,
llm_description="Hidden auth scope",
),
]
def _prepared_tool_schema() -> dict[str, JsonValue]:
return {
"type": "object",
"properties": {
"query": {"type": "string", "description": "Search query"},
"region": {
"type": "string",
"description": "Search region",
"enum": ["global", "cn"],
},
},
"required": ["query"],
}
def _llm_only_parameter(*, name: str, description: str, default: JsonValue = None) -> DifyPluginToolParameter:
return DifyPluginToolParameter(
name=name,
type=DifyPluginToolParameterType.STRING,
form=DifyPluginToolParameterForm.LLM,
required=default is None,
default=default,
llm_description=description,
)
def _invoke_stream_response(
*,
error_payload: dict[str, object] | None = None,
chunked_blob: bool = False,
) -> httpx.Response:
if error_payload is not None:
return httpx.Response(400, json=error_payload)
if chunked_blob:
stream_payload = "\n".join(
[
f"data: {json.dumps({'code': 0, 'message': 'ok', 'data': {'type': 'blob_chunk', 'message': {'id': 'blob-1', 'sequence': 0, 'total_length': 11, 'blob': 'aGVsbG8g', 'end': False}}})}",
f"data: {json.dumps({'code': 0, 'message': 'ok', 'data': {'type': 'blob_chunk', 'message': {'id': 'blob-1', 'sequence': 1, 'total_length': 11, 'blob': 'd29ybGQ=', 'end': True}}})}",
"",
]
)
return httpx.Response(200, text=stream_payload)
stream_payload = "\n".join(
[
f"data: {json.dumps({'code': 0, 'message': 'ok', 'data': {'type': 'text', 'message': {'text': 'found '}}})}",
f"data: {json.dumps({'code': 0, 'message': 'ok', 'data': {'type': 'json', 'message': {'json_object': {'count': 1}}}})}",
"",
]
)
return httpx.Response(200, text=stream_payload)
def _tool_transport(
*,
invoke_error_payload: dict[str, object] | None = None,
chunked_blob: bool = False,
) -> httpx.MockTransport:
def handler(request: httpx.Request) -> httpx.Response:
if request.url.path.endswith("/dispatch/tool/invoke"):
payload = json.loads(request.content.decode("utf-8"))
assert payload["user_id"] == "user-1"
assert payload["data"]["provider"] == "search"
assert payload["data"]["tool"] == "web_search"
assert payload["data"]["credential_type"] == "api-key"
assert payload["data"]["tool_parameters"] == {
"query": "dify",
"region": "global",
"api_version": "2026-01",
"auth_scope": "workspace",
}
return _invoke_stream_response(error_payload=invoke_error_payload, chunked_blob=chunked_blob)
raise AssertionError(f"Unexpected request path: {request.url.path}")
return httpx.MockTransport(handler)
def test_dify_plugin_type_id_constants_match_implementation_classes() -> None:
assert DIFY_PLUGIN_LAYER_TYPE_ID == DifyPluginLayer.type_id
assert DIFY_PLUGIN_LLM_LAYER_TYPE_ID == DifyPluginLLMLayer.type_id
assert DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID == DifyPluginToolsLayer.type_id
def test_dify_plugin_layer_creates_daemon_provider_from_shared_http_client() -> None:
async def scenario() -> None:
plugin = _plugin_layer()
async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _request: httpx.Response(200))) as client:
provider = plugin.create_daemon_provider(http_client=client)
assert provider.name == "DifyPlugin/langgenius/openai"
assert provider.client.http_client is client
assert provider.client.tenant_id == "tenant-1"
assert provider.client.plugin_id == "langgenius/openai"
assert provider.client.user_id == "user-1"
async with provider:
pass
assert client.is_closed is False
asyncio.run(scenario())
def test_dify_plugin_layer_rejects_closed_shared_http_client() -> None:
async def scenario() -> None:
plugin = _plugin_layer()
client = httpx.AsyncClient()
await client.aclose()
with pytest.raises(RuntimeError, match="open shared HTTP client"):
_ = plugin.create_daemon_provider(http_client=client)
asyncio.run(scenario())
def test_dify_plugin_llm_layer_builds_adapter_model_from_direct_dependency() -> None:
async def scenario() -> None:
compositor = Compositor(
[
LayerNode("renamed-execution-context", _execution_context_provider()),
LayerNode("llm", DifyPluginLLMLayer, deps={"execution_context": "renamed-execution-context"}),
LayerNode("renamed-plugin", _plugin_provider()),
LayerNode("llm", DifyPluginLLMLayer, deps={"plugin": "renamed-plugin"}),
]
)
async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _request: httpx.Response(200))) as client:
async with compositor.enter(
configs={
"renamed-execution-context": _execution_context_config(),
"renamed-plugin": _plugin_config(),
"llm": _llm_config(),
}
) as run:
execution_context = run.get_layer("renamed-execution-context", DifyExecutionContextLayer)
plugin = run.get_layer("renamed-plugin", DifyPluginLayer)
llm = run.get_layer("llm", DifyPluginLLMLayer)
model = llm.get_model(http_client=client)
assert llm.deps.execution_context is execution_context
assert llm.deps.plugin is plugin
assert isinstance(model, DifyLLMAdapterModel)
assert model.model_name == "demo-model"
assert model.model_provider == "openai"
@ -234,436 +114,17 @@ def test_dify_plugin_llm_layer_builds_adapter_model_from_direct_dependency() ->
asyncio.run(scenario())
def test_dify_plugin_tools_layer_uses_prepared_tool_definition_and_invokes_daemon() -> None:
def test_dify_plugin_layer_lifecycle_does_not_manage_http_client() -> None:
async def scenario() -> None:
compositor = Compositor(
[
LayerNode("execution_context", _execution_context_provider()),
LayerNode("tools", DifyPluginToolsLayer, deps={"execution_context": "execution_context"}),
]
)
async with httpx.AsyncClient(transport=_tool_transport()) as client:
async with compositor.enter(
configs={"execution_context": _execution_context_config(), "tools": _tools_config()}
) as run:
tools_layer = run.get_layer("tools", DifyPluginToolsLayer)
tool = (await tools_layer.get_tools(http_client=client))[0]
compositor = Compositor([LayerNode("plugin", _plugin_provider())])
async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _request: httpx.Response(200))) as client:
async with compositor.enter(configs={"plugin": _plugin_config()}) as run:
plugin = run.get_layer("plugin", DifyPluginLayer)
provider = plugin.create_daemon_provider(http_client=client)
run.suspend_layer_on_exit("plugin")
tool_def = await tool.prepare_tool_def(None) # pyright: ignore[reportArgumentType]
result = await tool.function_schema.call(
{"query": "dify", "region": "global"},
None, # pyright: ignore[reportArgumentType]
)
assert tool.name == "web_search"
assert tool.description == "Search the web."
assert tool_def is not None
assert tool_def.parameters_json_schema == _prepared_tool_schema()
assert tool_def.strict is False
assert result == 'found {"count": 1}'
asyncio.run(scenario())
def test_dify_plugin_tools_layer_uses_each_tool_plugin_id_for_transport() -> None:
async def scenario() -> None:
seen_requests: list[tuple[str, str, str, str]] = []
def handler(request: httpx.Request) -> httpx.Response:
if request.url.path.endswith("/dispatch/tool/invoke"):
payload = json.loads(request.content.decode("utf-8"))
seen_requests.append(
(
request.headers["X-Plugin-ID"],
payload["user_id"],
payload["data"]["provider"],
payload["data"]["tool"],
)
)
return _invoke_stream_response()
raise AssertionError(f"Unexpected request path: {request.url.path}")
tools_config = DifyPluginToolsLayerConfig(
tools=[
DifyPluginToolConfig(
plugin_id="langgenius/tools-a",
provider="search-a",
tool_name="web_search_a",
credential_type="api-key",
parameters=[_llm_only_parameter(name="query", description="Search query A")],
parameters_json_schema={
"type": "object",
"properties": {"query": {"type": "string", "description": "Search query A"}},
"required": ["query"],
},
),
DifyPluginToolConfig(
plugin_id="langgenius/tools-b",
provider="search-b",
tool_name="web_search_b",
credential_type="api-key",
parameters=[_llm_only_parameter(name="query", description="Search query B")],
parameters_json_schema={
"type": "object",
"properties": {"query": {"type": "string", "description": "Search query B"}},
"required": ["query"],
},
),
]
)
compositor = Compositor(
[
LayerNode("execution_context", _execution_context_provider()),
LayerNode("tools", DifyPluginToolsLayer, deps={"execution_context": "execution_context"}),
]
)
async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
async with compositor.enter(
configs={"execution_context": _execution_context_config(), "tools": tools_config}
) as run:
tools = await run.get_layer("tools", DifyPluginToolsLayer).get_tools(http_client=client)
await tools[0].function_schema.call({"query": "first"}, None) # pyright: ignore[reportArgumentType]
await tools[1].function_schema.call({"query": "second"}, None) # pyright: ignore[reportArgumentType]
assert seen_requests == [
("langgenius/tools-a", "user-1", "search-a", "web_search_a"),
("langgenius/tools-b", "user-1", "search-b", "web_search_b"),
]
asyncio.run(scenario())
def test_dify_plugin_tools_layer_casts_prepared_parameter_values_before_invocation() -> None:
async def scenario() -> None:
def handler(request: httpx.Request) -> httpx.Response:
if request.url.path.endswith("/dispatch/tool/invoke"):
payload = json.loads(request.content.decode("utf-8"))
assert payload["user_id"] == "user-1"
assert payload["data"]["tool_parameters"] == {
"enabled": True,
"count": 7,
"tags": ["a", "b"],
"metadata": {"source": "docs"},
"model": {"provider": "openai", "model": "gpt-4o-mini"},
}
return _invoke_stream_response()
raise AssertionError(f"Unexpected request path: {request.url.path}")
tools_config = DifyPluginToolsLayerConfig(
tools=[
DifyPluginToolConfig(
plugin_id="langgenius/tools",
provider="search",
tool_name="web_search",
credential_type="api-key",
parameters=[
DifyPluginToolParameter(
name="enabled",
type=DifyPluginToolParameterType.BOOLEAN,
form=DifyPluginToolParameterForm.LLM,
required=True,
llm_description="Enable search",
),
DifyPluginToolParameter(
name="count",
type=DifyPluginToolParameterType.NUMBER,
form=DifyPluginToolParameterForm.LLM,
required=True,
llm_description="Result count",
),
DifyPluginToolParameter(
name="tags",
type=DifyPluginToolParameterType.ARRAY,
form=DifyPluginToolParameterForm.LLM,
required=True,
llm_description="Tags",
input_schema={"type": "array", "items": {"type": "string"}},
),
DifyPluginToolParameter(
name="metadata",
type=DifyPluginToolParameterType.OBJECT,
form=DifyPluginToolParameterForm.LLM,
required=True,
llm_description="Metadata",
input_schema={"type": "object", "additionalProperties": True},
),
DifyPluginToolParameter(
name="model",
type=DifyPluginToolParameterType.MODEL_SELECTOR,
form=DifyPluginToolParameterForm.LLM,
required=True,
llm_description="Model selector",
input_schema={"type": "object", "additionalProperties": True},
),
],
parameters_json_schema={
"type": "object",
"properties": {
"enabled": {"type": "boolean", "description": "Enable search"},
"count": {"type": "number", "description": "Result count"},
"tags": {"type": "array", "items": {"type": "string"}, "description": "Tags"},
"metadata": {
"type": "object",
"additionalProperties": True,
"description": "Metadata",
},
"model": {
"type": "object",
"additionalProperties": True,
"description": "Model selector",
},
},
"required": ["enabled", "count", "tags", "metadata", "model"],
},
)
]
)
compositor = Compositor(
[
LayerNode("execution_context", _execution_context_provider()),
LayerNode("tools", DifyPluginToolsLayer, deps={"execution_context": "execution_context"}),
]
)
async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
async with compositor.enter(
configs={"execution_context": _execution_context_config(), "tools": tools_config}
) as run:
tool = (await run.get_layer("tools", DifyPluginToolsLayer).get_tools(http_client=client))[0]
result = await tool.function_schema.call(
{
"enabled": "yes",
"count": "7",
"tags": '["a", "b"]',
"metadata": '{"source": "docs"}',
"model": {"provider": "openai", "model": "gpt-4o-mini"},
},
None, # pyright: ignore[reportArgumentType]
)
assert result == 'found {"count": 1}'
asyncio.run(scenario())
def test_dify_plugin_tools_layer_sends_prepared_parameter_defaults_to_daemon() -> None:
async def scenario() -> None:
def handler(request: httpx.Request) -> httpx.Response:
if request.url.path.endswith("/dispatch/tool/invoke"):
payload = json.loads(request.content.decode("utf-8"))
assert payload["data"]["tool_parameters"] == {
"query": "dify",
"region": "global",
}
return _invoke_stream_response()
raise AssertionError(f"Unexpected request path: {request.url.path}")
tools_config = DifyPluginToolsLayerConfig(
tools=[
DifyPluginToolConfig(
plugin_id="langgenius/tools",
provider="search",
tool_name="web_search",
credential_type="api-key",
parameters=[
_llm_only_parameter(name="query", description="Search query"),
_llm_only_parameter(name="region", description="Search region", default="global"),
],
parameters_json_schema={
"type": "object",
"properties": {
"query": {"type": "string", "description": "Search query"},
"region": {"type": "string", "description": "Search region"},
},
"required": ["query"],
},
)
]
)
compositor = Compositor(
[
LayerNode("execution_context", _execution_context_provider()),
LayerNode("tools", DifyPluginToolsLayer, deps={"execution_context": "execution_context"}),
]
)
async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
async with compositor.enter(
configs={"execution_context": _execution_context_config(), "tools": tools_config}
) as run:
tool = (await run.get_layer("tools", DifyPluginToolsLayer).get_tools(http_client=client))[0]
result = await tool.function_schema.call(
{"query": "dify"},
None, # pyright: ignore[reportArgumentType]
)
assert result == 'found {"count": 1}'
asyncio.run(scenario())
def test_dify_plugin_tools_layer_requires_hidden_runtime_parameters_in_prepared_config() -> None:
async def scenario() -> None:
compositor = Compositor(
[
LayerNode("execution_context", _execution_context_provider()),
LayerNode("tools", DifyPluginToolsLayer, deps={"execution_context": "execution_context"}),
]
)
async with httpx.AsyncClient(transport=_tool_transport()) as client:
async with compositor.enter(
configs={
"execution_context": _execution_context_config(),
"tools": _missing_hidden_parameter_tools_config(),
}
) as run:
with pytest.raises(ValueError, match="requires non-LLM runtime_parameters for: auth_scope"):
await run.get_layer("tools", DifyPluginToolsLayer).get_tools(http_client=client)
asyncio.run(scenario())
def test_dify_plugin_tools_layer_returns_agent_friendly_error_text() -> None:
async def scenario() -> None:
compositor = Compositor(
[
LayerNode("execution_context", _execution_context_provider()),
LayerNode("tools", DifyPluginToolsLayer, deps={"execution_context": "execution_context"}),
]
)
async with httpx.AsyncClient(
transport=_tool_transport(
invoke_error_payload={
"error_type": "PluginDaemonBadRequestError",
"message": "missing query",
}
)
) as client:
async with compositor.enter(
configs={"execution_context": _execution_context_config(), "tools": _tools_config()}
) as run:
tool = (await run.get_layer("tools", DifyPluginToolsLayer).get_tools(http_client=client))[0]
result = await tool.function_schema.call(
{"query": "dify", "region": "global"},
None, # pyright: ignore[reportArgumentType]
)
assert result == "tool parameters validation error: missing query, please check your tool parameters"
asyncio.run(scenario())
def test_dify_plugin_tools_layer_propagates_unexpected_transport_errors() -> None:
async def scenario() -> None:
compositor = Compositor(
[
LayerNode("execution_context", _execution_context_provider()),
LayerNode("tools", DifyPluginToolsLayer, deps={"execution_context": "execution_context"}),
]
)
def handler(request: httpx.Request) -> httpx.Response:
if request.url.path.endswith("/dispatch/tool/invoke"):
raise RuntimeError("unexpected transport failure")
raise AssertionError(f"Unexpected request path: {request.url.path}")
async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
async with compositor.enter(
configs={"execution_context": _execution_context_config(), "tools": _tools_config()}
) as run:
tool = (await run.get_layer("tools", DifyPluginToolsLayer).get_tools(http_client=client))[0]
with pytest.raises(RuntimeError, match="unexpected transport failure"):
await tool.function_schema.call(
{"query": "dify", "region": "global"},
None, # pyright: ignore[reportArgumentType]
)
asyncio.run(scenario())
@pytest.mark.parametrize(
("invoke_error_payload", "expected_text"),
[
(
{
"error_type": "PluginInvokeError",
"message": json.dumps(
{
"error_type": "PluginDaemonUnauthorizedError",
"message": "invalid api key",
}
),
},
"Please check your tool provider credentials",
),
(
{
"error_type": "PluginInvokeError",
"message": json.dumps(
{
"error_type": "ToolNotFoundError",
"message": "missing plugin tool",
}
),
},
"there is not a tool named web_search",
),
],
)
def test_dify_plugin_tools_layer_maps_nested_plugin_invoke_errors_to_agent_text(
invoke_error_payload: dict[str, object],
expected_text: str,
) -> None:
async def scenario() -> None:
compositor = Compositor(
[
LayerNode("execution_context", _execution_context_provider()),
LayerNode("tools", DifyPluginToolsLayer, deps={"execution_context": "execution_context"}),
]
)
async with httpx.AsyncClient(transport=_tool_transport(invoke_error_payload=invoke_error_payload)) as client:
async with compositor.enter(
configs={"execution_context": _execution_context_config(), "tools": _tools_config()}
) as run:
tool = (await run.get_layer("tools", DifyPluginToolsLayer).get_tools(http_client=client))[0]
result = await tool.function_schema.call(
{"query": "dify", "region": "global"},
None, # pyright: ignore[reportArgumentType]
)
assert result == expected_text
asyncio.run(scenario())
def test_dify_plugin_tools_layer_merges_blob_chunks_before_observation_conversion() -> None:
async def scenario() -> None:
compositor = Compositor(
[
LayerNode("execution_context", _execution_context_provider()),
LayerNode("tools", DifyPluginToolsLayer, deps={"execution_context": "execution_context"}),
]
)
async with httpx.AsyncClient(transport=_tool_transport(chunked_blob=True)) as client:
async with compositor.enter(
configs={"execution_context": _execution_context_config(), "tools": _tools_config()}
) as run:
tool = (await run.get_layer("tools", DifyPluginToolsLayer).get_tools(http_client=client))[0]
result = await tool.function_schema.call(
{"query": "dify", "region": "global"},
None, # pyright: ignore[reportArgumentType]
)
assert "hello world" in result
assert "sequence=0" not in result
assert run.session_snapshot is not None
assert provider.client.http_client is client
assert client.is_closed is False
asyncio.run(scenario())

View File

@ -1,47 +0,0 @@
import pytest
from pydantic import ValidationError
import dify_agent.layers.execution_context as execution_context_exports
from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig
def test_execution_context_package_exports_client_safe_config_symbols_only() -> None:
assert execution_context_exports.__all__ == [
"DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID",
"DifyExecutionContextInvokeFrom",
"DifyExecutionContextLayerConfig",
]
assert execution_context_exports.DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID == "dify.execution_context"
assert not hasattr(execution_context_exports, "DifyExecutionContextLayer")
def test_execution_context_layer_config_forbids_runtime_settings_and_unknown_fields() -> None:
config = DifyExecutionContextLayerConfig(
tenant_id="tenant-1",
user_id="user-1",
workflow_id="workflow-1",
invoke_from="workflow_run",
)
assert config.tenant_id == "tenant-1"
assert config.user_id == "user-1"
assert config.workflow_id == "workflow-1"
assert config.invoke_from == "workflow_run"
with pytest.raises(ValidationError):
_ = DifyExecutionContextLayerConfig.model_validate(
{
"tenant_id": "tenant-1",
"invoke_from": "workflow_run",
"daemon_url": "http://daemon",
}
)
with pytest.raises(ValidationError):
_ = DifyExecutionContextLayerConfig.model_validate(
{
"tenant_id": "tenant-1",
"invoke_from": "workflow_run",
"unknown": "value",
}
)

View File

@ -1,107 +0,0 @@
import asyncio
import httpx
import pytest
from dify_agent.layers.execution_context import DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, DifyExecutionContextLayerConfig
from dify_agent.layers.execution_context.layer import DifyExecutionContextLayer
def _execution_context_layer() -> DifyExecutionContextLayer:
return DifyExecutionContextLayer.from_config_with_settings(
DifyExecutionContextLayerConfig(tenant_id="tenant-1", user_id="user-1", invoke_from="workflow_run"),
daemon_url="http://plugin-daemon",
daemon_api_key="daemon-secret",
)
def test_execution_context_type_id_constant_matches_implementation_class() -> None:
assert DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID == DifyExecutionContextLayer.type_id
def test_execution_context_layer_creates_daemon_provider_from_shared_http_client() -> None:
async def scenario() -> None:
execution_context = _execution_context_layer()
async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _request: httpx.Response(200))) as client:
provider = execution_context.create_daemon_provider(plugin_id="langgenius/openai", http_client=client)
assert provider.name == "DifyPlugin/langgenius/openai"
assert provider.client.http_client is client
assert provider.client.tenant_id == "tenant-1"
assert provider.client.plugin_id == "langgenius/openai"
assert provider.client.user_id == "user-1"
async with provider:
pass
assert client.is_closed is False
asyncio.run(scenario())
def test_execution_context_layer_creates_tool_client_from_shared_http_client() -> None:
async def scenario() -> None:
execution_context = _execution_context_layer()
async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _request: httpx.Response(200))) as client:
tool_client = execution_context.create_tool_client(plugin_id="langgenius/tools", http_client=client)
assert tool_client.http_client is client
assert tool_client.tenant_id == "tenant-1"
assert tool_client.user_id == "user-1"
assert tool_client.plugin_id == "langgenius/tools"
assert tool_client.plugin_daemon_url == "http://plugin-daemon"
assert tool_client.plugin_daemon_api_key == "daemon-secret"
assert client.is_closed is False
asyncio.run(scenario())
def test_execution_context_layer_rejects_closed_shared_http_client() -> None:
async def scenario() -> None:
execution_context = _execution_context_layer()
client = httpx.AsyncClient()
await client.aclose()
with pytest.raises(RuntimeError, match="open shared HTTP client"):
_ = execution_context.create_daemon_provider(plugin_id="langgenius/openai", http_client=client)
with pytest.raises(RuntimeError, match="open shared HTTP client"):
_ = execution_context.create_tool_client(plugin_id="langgenius/tools", http_client=client)
asyncio.run(scenario())
def test_execution_context_layer_lifecycle_does_not_manage_http_client() -> None:
from agenton.compositor import Compositor, LayerNode, LayerProvider
provider = LayerProvider.from_factory(
layer_type=DifyExecutionContextLayer,
create=lambda config: DifyExecutionContextLayer.from_config_with_settings(
DifyExecutionContextLayerConfig.model_validate(config),
daemon_url="http://plugin-daemon",
daemon_api_key="daemon-secret",
),
)
async def scenario() -> None:
compositor = Compositor([LayerNode("execution_context", provider)])
async with httpx.AsyncClient(transport=httpx.MockTransport(lambda _request: httpx.Response(200))) as client:
async with compositor.enter(
configs={
"execution_context": DifyExecutionContextLayerConfig(
tenant_id="tenant-1",
user_id="user-1",
invoke_from="workflow_run",
)
}
) as run:
execution_context = run.get_layer("execution_context", DifyExecutionContextLayer)
daemon_provider = execution_context.create_daemon_provider(
plugin_id="langgenius/openai",
http_client=client,
)
run.suspend_layer_on_exit("execution_context")
assert run.session_snapshot is not None
assert daemon_provider.client.http_client is client
assert client.is_closed is False
asyncio.run(scenario())

View File

@ -6,13 +6,13 @@ from agenton.compositor import CompositorSessionSnapshot
from agenton.layers import ExitIntent
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig
import dify_agent.protocol as protocol_exports
from dify_agent.layers.execution_context import DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, DifyExecutionContextLayerConfig
from dify_agent.layers.dify_plugin import DIFY_PLUGIN_LLM_LAYER_TYPE_ID, DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID
from dify_agent.layers.dify_plugin import DIFY_PLUGIN_LAYER_TYPE_ID, DIFY_PLUGIN_LLM_LAYER_TYPE_ID
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
from dify_agent.protocol import DIFY_AGENT_HISTORY_LAYER_ID, DIFY_AGENT_MODEL_LAYER_ID, DIFY_AGENT_OUTPUT_LAYER_ID
from dify_agent.protocol.schemas import (
RUN_EVENT_ADAPTER,
CreateRunRequest,
ExecutionContext,
LayerExitSignals,
PydanticAIStreamRunEvent,
RunCancelledEvent,
@ -28,14 +28,7 @@ from dify_agent.protocol.schemas import (
RunSucceededEventData,
normalize_composition,
)
from dify_agent.layers.dify_plugin.configs import (
DifyPluginLLMLayerConfig,
DifyPluginToolConfig,
DifyPluginToolParameter,
DifyPluginToolParameterForm,
DifyPluginToolParameterType,
DifyPluginToolsLayerConfig,
)
from dify_agent.layers.dify_plugin.configs import DifyPluginLLMLayerConfig, DifyPluginLayerConfig
def test_run_event_adapter_round_trips_typed_variants() -> None:
@ -94,23 +87,10 @@ def test_create_run_request_rejects_old_compositor_payload_and_model_layer_id_is
)
def test_protocol_package_no_longer_exports_execution_context_dto() -> None:
assert not hasattr(protocol_exports, "ExecutionContext")
def test_create_run_request_accepts_dto_first_public_composition_and_normalizes_graph_config() -> None:
prompt_config = PromptLayerConfig(prefix="system", user="hello")
execution_context_config = DifyExecutionContextLayerConfig(
tenant_id="tenant-1",
workflow_id="workflow-1",
workflow_run_id="workflow-run-1",
node_id="node-1",
node_execution_id="node-execution-1",
invoke_from="workflow_run",
trace_id="trace-1",
)
plugin_config = DifyPluginLayerConfig(tenant_id="tenant-1", plugin_id="langgenius/openai")
llm_config = DifyPluginLLMLayerConfig(
plugin_id="langgenius/openai",
model_provider="openai",
model="demo-model",
credentials={"api_key": "secret"},
@ -124,21 +104,26 @@ def test_create_run_request_accepts_dto_first_public_composition_and_normalizes_
}
)
request = CreateRunRequest(
execution_context=ExecutionContext(
tenant_id="tenant-1",
workflow_id="workflow-1",
workflow_run_id="workflow-run-1",
node_id="node-1",
node_execution_id="node-execution-1",
invoke_from="workflow_run",
trace_id="trace-1",
),
purpose="workflow_node",
idempotency_key="workflow-run-1:node-execution-1",
metadata={"source": "unit_test"},
composition=RunComposition(
layers=[
RunLayerSpec(name="prompt", type=PLAIN_PROMPT_LAYER_TYPE_ID, config=prompt_config),
RunLayerSpec(
name="execution_context",
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
config=execution_context_config,
),
RunLayerSpec(name="plugin", type=DIFY_PLUGIN_LAYER_TYPE_ID, config=plugin_config),
RunLayerSpec(
name=DIFY_AGENT_MODEL_LAYER_ID,
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
deps={"execution_context": "execution_context"},
deps={"plugin": "plugin"},
config=llm_config,
),
RunLayerSpec(
@ -153,9 +138,8 @@ def test_create_run_request_accepts_dto_first_public_composition_and_normalizes_
graph_config, layer_configs = normalize_composition(request.composition)
payload = request.model_dump(mode="json")
assert payload["composition"]["layers"][1]["config"] == {
assert payload["execution_context"] == {
"tenant_id": "tenant-1",
"user_id": None,
"app_id": None,
"workflow_id": "workflow-1",
"workflow_run_id": "workflow-run-1",
@ -173,16 +157,11 @@ def test_create_run_request_accepts_dto_first_public_composition_and_normalizes_
assert payload["composition"]["layers"][0]["config"] == {"prefix": "system", "user": "hello", "suffix": []}
assert [layer.model_dump(mode="json") for layer in graph_config.layers] == [
{"name": "prompt", "type": PLAIN_PROMPT_LAYER_TYPE_ID, "deps": {}, "metadata": {}},
{
"name": "execution_context",
"type": DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
"deps": {},
"metadata": {},
},
{"name": "plugin", "type": DIFY_PLUGIN_LAYER_TYPE_ID, "deps": {}, "metadata": {}},
{
"name": DIFY_AGENT_MODEL_LAYER_ID,
"type": DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
"deps": {"execution_context": "execution_context"},
"deps": {"plugin": "plugin"},
"metadata": {},
},
{
@ -194,118 +173,12 @@ def test_create_run_request_accepts_dto_first_public_composition_and_normalizes_
]
assert layer_configs == {
"prompt": prompt_config,
"execution_context": execution_context_config,
"plugin": plugin_config,
DIFY_AGENT_MODEL_LAYER_ID: llm_config,
DIFY_AGENT_OUTPUT_LAYER_ID: output_config,
}
def test_create_run_request_accepts_plugin_tools_layer_with_prepared_parameters_and_schema() -> None:
request = CreateRunRequest.model_validate(
{
"composition": {
"layers": [
{"name": "prompt", "type": PLAIN_PROMPT_LAYER_TYPE_ID, "config": {"user": "hello"}},
{
"name": "execution_context",
"type": DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
"config": {"tenant_id": "tenant-1", "invoke_from": "workflow_run"},
},
{
"name": DIFY_AGENT_MODEL_LAYER_ID,
"type": DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
"deps": {"execution_context": "execution_context"},
"config": {
"plugin_id": "langgenius/openai",
"model_provider": "openai",
"model": "demo-model",
"credentials": {"api_key": "secret"},
},
},
{
"name": "tools",
"type": DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
"deps": {"execution_context": "execution_context"},
"config": {
"tools": [
{
"plugin_id": "langgenius/search",
"provider": "search",
"tool_name": "web_search",
"credential_type": "api-key",
"runtime_parameters": {"site": "docs.dify.ai"},
"parameters": [
{
"name": "query",
"type": "string",
"form": "llm",
"required": True,
"llm_description": "Search query",
},
{
"name": "site",
"type": "string",
"form": "form",
"required": True,
"llm_description": "Hidden site",
},
],
"parameters_json_schema": {
"type": "object",
"properties": {"query": {"type": "string", "description": "Search query"}},
"required": ["query"],
},
}
]
},
},
]
}
}
)
graph_config, layer_configs = normalize_composition(request.composition)
assert [layer.type for layer in graph_config.layers] == [
PLAIN_PROMPT_LAYER_TYPE_ID,
DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
]
assert DifyPluginToolsLayerConfig.model_validate(layer_configs["tools"]) == DifyPluginToolsLayerConfig(
tools=[
DifyPluginToolConfig(
plugin_id="langgenius/search",
provider="search",
tool_name="web_search",
credential_type="api-key",
runtime_parameters={"site": "docs.dify.ai"},
parameters=[
DifyPluginToolParameter(
name="query",
type=DifyPluginToolParameterType.STRING,
form=DifyPluginToolParameterForm.LLM,
required=True,
llm_description="Search query",
),
DifyPluginToolParameter(
name="site",
type=DifyPluginToolParameterType.STRING,
form=DifyPluginToolParameterForm.FORM,
required=True,
llm_description="Hidden site",
),
],
parameters_json_schema={
"type": "object",
"properties": {"query": {"type": "string", "description": "Search query"}},
"required": ["query"],
},
)
]
)
def test_on_exit_default_to_suspend_and_are_public() -> None:
assert protocol_exports.LayerExitSignals is LayerExitSignals
assert protocol_exports.RunComposition is RunComposition
@ -333,12 +206,13 @@ def test_on_exit_accept_layer_overrides() -> None:
assert request.on_exit.layers == {"prompt": ExitIntent.SUSPEND, "llm": ExitIntent.DELETE}
def test_create_run_request_rejects_removed_top_level_execution_context() -> None:
def test_execution_context_rejects_unknown_fields() -> None:
with pytest.raises(ValidationError):
_ = CreateRunRequest.model_validate(
_ = ExecutionContext.model_validate(
{
"composition": {"layers": []},
"execution_context": {"tenant_id": "tenant-1", "invoke_from": "workflow_run"},
"tenant_id": "tenant-1",
"invoke_from": "workflow_run",
"unknown": "value",
}
)

View File

@ -6,18 +6,25 @@ import httpx
import pytest
from agenton.compositor import CompositorSessionSnapshot, LayerSessionSnapshot
from agenton.layers import LifecycleState
from agenton.layers import ExitIntent, LifecycleState
from agenton_collections.layers.pydantic_ai import PYDANTIC_AI_HISTORY_LAYER_TYPE_ID
from agenton_collections.layers.plain import PromptLayerConfig
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
from dify_agent.protocol import DIFY_AGENT_OUTPUT_LAYER_ID
from dify_agent.protocol import DIFY_AGENT_HISTORY_LAYER_ID, DIFY_AGENT_OUTPUT_LAYER_ID
from dify_agent.protocol.schemas import (
CreateRunRequest,
LayerExitSignals,
RunComposition,
RunEvent,
RunLayerSpec,
RunStatus,
)
from dify_agent.runtime.run_scheduler import RunScheduler, SchedulerStoppingError
from dify_agent.runtime.run_scheduler import (
RunRequestValidationError,
RunScheduler,
SchedulerStoppingError,
validate_run_request,
)
from dify_agent.server.schemas import RunRecord
@ -161,64 +168,390 @@ def test_shutdown_marks_unfinished_runs_failed_and_appends_event() -> None:
asyncio.run(scenario())
def test_create_run_accepts_blank_prompt_and_runner_fails_asynchronously() -> None:
def test_create_run_rejects_blank_prompt_before_persisting() -> None:
async def scenario() -> None:
store = FakeStore()
async with httpx.AsyncClient() as client:
scheduler = RunScheduler(store=store, plugin_daemon_http_client=client)
record = await scheduler.create_run(_request(["", " "]))
await asyncio.wait_for(scheduler.active_tasks[record.run_id], timeout=1)
with pytest.raises(ValueError, match="run.user_prompts must not be empty"):
await scheduler.create_run(_request(["", " "]))
assert store.records == {record.run_id: record}
assert [event.type for event in store.events[record.run_id]] == ["run_started", "run_failed"]
assert store.statuses[record.run_id] == "failed"
assert store.errors[record.run_id] == "run.user_prompts must not be empty"
assert store.records == {}
asyncio.run(scenario())
def test_create_run_accepts_invalid_output_schema_and_runner_fails_asynchronously() -> None:
def test_create_run_rejects_invalid_output_schema_before_persisting() -> None:
async def scenario() -> None:
store = FakeStore()
async with httpx.AsyncClient() as client:
scheduler = RunScheduler(store=store, plugin_daemon_http_client=client)
record = await scheduler.create_run(
_request(
output_config={
"json_schema": _recursive_output_schema(),
}
with pytest.raises(ValueError, match=r"Recursive \$defs refs are not supported"):
await scheduler.create_run(
_request(
output_config={
"json_schema": _recursive_output_schema(),
}
)
)
assert store.records == {}
asyncio.run(scenario())
def test_create_run_rejects_remote_ref_output_schema_before_persisting() -> None:
async def scenario() -> None:
store = FakeStore()
async with httpx.AsyncClient() as client:
scheduler = RunScheduler(store=store, plugin_daemon_http_client=client)
with pytest.raises(ValueError, match=r"Remote \$ref values are not supported"):
await scheduler.create_run(
_request(
output_config={
"json_schema": {
"type": "object",
"properties": {
"title": {"$ref": "https://example.com/schema.json"},
},
},
}
)
)
assert store.records == {}
asyncio.run(scenario())
def test_create_run_rejects_non_object_output_schema_before_persisting() -> None:
async def scenario() -> None:
store = FakeStore()
async with httpx.AsyncClient() as client:
scheduler = RunScheduler(store=store, plugin_daemon_http_client=client)
with pytest.raises(ValueError, match="Schema must declare an object output"):
await scheduler.create_run(
_request(
output_config={
"json_schema": {
"type": "array",
"items": {"type": "string"},
},
}
)
)
assert store.records == {}
asyncio.run(scenario())
def test_create_run_rejects_public_output_tool_name_override_before_persisting() -> None:
async def scenario() -> None:
store = FakeStore()
async with httpx.AsyncClient() as client:
scheduler = RunScheduler(store=store, plugin_daemon_http_client=client)
with pytest.raises(ValueError, match="Extra inputs are not permitted"):
await scheduler.create_run(
_request(
output_config={
"name": "incident_summary",
"json_schema": {
"type": "object",
"properties": {"title": {"type": "string"}},
"required": ["title"],
"additionalProperties": False,
},
}
)
)
assert store.records == {}
asyncio.run(scenario())
def test_create_run_rejects_non_defs_local_ref_in_direct_object_schema_before_persisting() -> None:
async def scenario() -> None:
store = FakeStore()
async with httpx.AsyncClient() as client:
scheduler = RunScheduler(store=store, plugin_daemon_http_client=client)
with pytest.raises(ValueError, match=r"Only local refs under '#/\$defs/' are supported"):
await scheduler.create_run(
_request(
output_config={
"json_schema": {
"type": "object",
"properties": {
"items": {"$ref": "#/definitions/itemArray"},
},
"required": ["items"],
"definitions": {
"itemArray": {
"type": "array",
"items": {"type": "string"},
},
},
},
}
)
)
assert store.records == {}
asyncio.run(scenario())
def test_create_run_rejects_misnamed_output_layer_before_persisting() -> None:
async def scenario() -> None:
store = FakeStore()
async with httpx.AsyncClient() as client:
scheduler = RunScheduler(store=store, plugin_daemon_http_client=client)
request = CreateRunRequest(
composition=RunComposition(
layers=[
RunLayerSpec(name="prompt", type="plain.prompt", config=PromptLayerConfig(user="hello")),
RunLayerSpec(
name="structured-output",
type=DIFY_OUTPUT_LAYER_TYPE_ID,
config=DifyOutputLayerConfig(
json_schema={
"type": "object",
"properties": {"title": {"type": "string"}},
"required": ["title"],
"additionalProperties": False,
}
),
),
]
)
)
await asyncio.wait_for(scheduler.active_tasks[record.run_id], timeout=1)
assert store.records == {record.run_id: record}
assert [event.type for event in store.events[record.run_id]] == ["run_started", "run_failed"]
assert store.statuses[record.run_id] == "failed"
assert "Recursive $defs refs are not supported" in (store.errors[record.run_id] or "")
with pytest.raises(ValueError, match="must use reserved layer name 'output'"):
await scheduler.create_run(request)
assert store.records == {}
asyncio.run(scenario())
def test_create_run_honors_explicit_empty_layer_providers_by_failing_after_persisting() -> None:
def test_create_run_rejects_multiple_output_layers_before_persisting() -> None:
async def scenario() -> None:
store = FakeStore()
async with httpx.AsyncClient() as client:
scheduler = RunScheduler(store=store, plugin_daemon_http_client=client)
request = CreateRunRequest(
composition=RunComposition(
layers=[
RunLayerSpec(name="prompt", type="plain.prompt", config=PromptLayerConfig(user="hello")),
RunLayerSpec(
name=DIFY_AGENT_OUTPUT_LAYER_ID,
type=DIFY_OUTPUT_LAYER_TYPE_ID,
config=DifyOutputLayerConfig(
json_schema={
"type": "object",
"properties": {"title": {"type": "string"}},
"required": ["title"],
"additionalProperties": False,
}
),
),
RunLayerSpec(
name="secondary-output",
type=DIFY_OUTPUT_LAYER_TYPE_ID,
config=DifyOutputLayerConfig(
json_schema={
"type": "object",
"properties": {"summary": {"type": "string"}},
"required": ["summary"],
"additionalProperties": False,
}
),
),
]
)
)
with pytest.raises(ValueError, match="Only one 'dify.output' layer is supported"):
await scheduler.create_run(request)
assert store.records == {}
asyncio.run(scenario())
def test_create_run_rejects_reserved_output_name_with_wrong_layer_type_before_persisting() -> None:
async def scenario() -> None:
store = FakeStore()
async with httpx.AsyncClient() as client:
scheduler = RunScheduler(store=store, plugin_daemon_http_client=client)
request = CreateRunRequest(
composition=RunComposition(
layers=[
RunLayerSpec(name="prompt", type="plain.prompt", config=PromptLayerConfig(user="hello")),
RunLayerSpec(
name=DIFY_AGENT_OUTPUT_LAYER_ID, type="plain.prompt", config=PromptLayerConfig(user="hi")
),
]
)
)
with pytest.raises(ValueError, match=r"Layer 'output' must be DifyOutputLayer, got PromptLayer"):
await scheduler.create_run(request)
assert store.records == {}
asyncio.run(scenario())
def test_validate_run_request_honors_explicit_empty_layer_providers() -> None:
async def scenario() -> None:
with pytest.raises(RunRequestValidationError, match="plain.prompt"):
await validate_run_request(_request(), layer_providers=())
asyncio.run(scenario())
def test_validate_run_request_rejects_misnamed_output_layer_before_provider_checks() -> None:
async def scenario() -> None:
request = CreateRunRequest(
composition=RunComposition(
layers=[
RunLayerSpec(name="prompt", type="plain.prompt", config=PromptLayerConfig(user="hello")),
RunLayerSpec(
name="structured-output",
type=DIFY_OUTPUT_LAYER_TYPE_ID,
config=DifyOutputLayerConfig(
json_schema={
"type": "object",
"properties": {"title": {"type": "string"}},
"required": ["title"],
"additionalProperties": False,
}
),
),
]
)
)
with pytest.raises(RunRequestValidationError, match="must use reserved layer name 'output'"):
await validate_run_request(request, layer_providers=())
asyncio.run(scenario())
def test_validate_run_request_accepts_reserved_history_layer() -> None:
async def scenario() -> None:
request = CreateRunRequest(
composition=RunComposition(
layers=[
RunLayerSpec(name="prompt", type="plain.prompt", config=PromptLayerConfig(user="hello")),
RunLayerSpec(name=DIFY_AGENT_HISTORY_LAYER_ID, type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID),
]
)
)
await validate_run_request(request)
asyncio.run(scenario())
def test_validate_run_request_rejects_misnamed_history_layer_before_provider_checks() -> None:
async def scenario() -> None:
request = CreateRunRequest(
composition=RunComposition(
layers=[
RunLayerSpec(name="prompt", type="plain.prompt", config=PromptLayerConfig(user="hello")),
RunLayerSpec(name="chat-history", type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID),
]
)
)
with pytest.raises(RunRequestValidationError, match="must use reserved layer name 'history'"):
await validate_run_request(request, layer_providers=())
asyncio.run(scenario())
def test_validate_run_request_rejects_multiple_history_layers_before_provider_checks() -> None:
async def scenario() -> None:
request = CreateRunRequest(
composition=RunComposition(
layers=[
RunLayerSpec(name="prompt", type="plain.prompt", config=PromptLayerConfig(user="hello")),
RunLayerSpec(name=DIFY_AGENT_HISTORY_LAYER_ID, type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID),
RunLayerSpec(name="secondary-history", type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID),
]
)
)
with pytest.raises(RunRequestValidationError, match="Only one 'pydantic_ai.history' layer is supported"):
await validate_run_request(request, layer_providers=())
asyncio.run(scenario())
def test_validate_run_request_rejects_history_layer_dependencies_before_provider_checks() -> None:
async def scenario() -> None:
request = CreateRunRequest(
composition=RunComposition(
layers=[
RunLayerSpec(name="prompt", type="plain.prompt", config=PromptLayerConfig(user="hello")),
RunLayerSpec(
name=DIFY_AGENT_HISTORY_LAYER_ID,
type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID,
deps={"prompt": "prompt"},
),
]
)
)
with pytest.raises(RunRequestValidationError, match="does not support dependencies"):
await validate_run_request(request, layer_providers=())
asyncio.run(scenario())
def test_create_run_rejects_unknown_layer_exit_signal_before_persisting() -> None:
async def scenario() -> None:
store = FakeStore()
async with httpx.AsyncClient() as client:
scheduler = RunScheduler(store=store, plugin_daemon_http_client=client)
request = _request()
request.on_exit = LayerExitSignals(layers={"missing": ExitIntent.DELETE})
with pytest.raises(ValueError, match="missing"):
await scheduler.create_run(request)
assert store.records == {}
asyncio.run(scenario())
def test_create_run_honors_explicit_empty_layer_providers_before_persisting() -> None:
async def scenario() -> None:
store = FakeStore()
async with httpx.AsyncClient() as client:
scheduler = RunScheduler(store=store, plugin_daemon_http_client=client, layer_providers=())
record = await scheduler.create_run(_request())
await asyncio.wait_for(scheduler.active_tasks[record.run_id], timeout=1)
with pytest.raises(RunRequestValidationError, match="plain.prompt"):
await scheduler.create_run(_request())
assert store.records == {record.run_id: record}
assert [event.type for event in store.events[record.run_id]] == ["run_started", "run_failed"]
assert store.statuses[record.run_id] == "failed"
assert "plain.prompt" in (store.errors[record.run_id] or "")
assert store.records == {}
asyncio.run(scenario())
def test_create_run_accepts_closed_session_snapshot_and_runner_fails_asynchronously() -> None:
def test_create_run_rejects_closed_session_snapshot_before_persisting() -> None:
async def scenario() -> None:
store = FakeStore()
async with httpx.AsyncClient() as client:
@ -234,13 +567,10 @@ def test_create_run_accepts_closed_session_snapshot_and_runner_fails_asynchronou
]
)
record = await scheduler.create_run(request)
await asyncio.wait_for(scheduler.active_tasks[record.run_id], timeout=1)
with pytest.raises(ValueError, match="CLOSED snapshots cannot be entered"):
_ = await scheduler.create_run(request)
assert store.records == {record.run_id: record}
assert [event.type for event in store.events[record.run_id]] == ["run_started", "run_failed"]
assert store.statuses[record.run_id] == "failed"
assert "CLOSED snapshots cannot be entered" in (store.errors[record.run_id] or "")
assert store.records == {}
asyncio.run(scenario())

View File

@ -1,11 +1,9 @@
import asyncio
from collections.abc import Mapping
from typing import Any, ClassVar, cast
from typing import Any
import httpx
import pytest
from pydantic import JsonValue
from pydantic_ai import Tool
from pydantic_ai.exceptions import UnexpectedModelBehavior
from pydantic_ai.messages import (
ModelMessage,
@ -20,22 +18,12 @@ from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.models.test import TestModel
from pydantic_ai.settings import ModelSettings
from agenton.compositor import CompositorSessionSnapshot, LayerProvider, LayerSessionSnapshot
from agenton.compositor import CompositorSessionSnapshot, LayerSessionSnapshot
from agenton.layers import ExitIntent, LifecycleState
from agenton_collections.layers.pydantic_ai import PYDANTIC_AI_HISTORY_LAYER_TYPE_ID, PydanticAIHistoryRuntimeState
from agenton_collections.layers.plain import PromptLayerConfig, ToolsLayer
from dify_agent.layers.execution_context import DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, DifyExecutionContextLayerConfig
from dify_agent.layers.dify_plugin.configs import (
DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
DifyPluginLLMLayerConfig,
DifyPluginToolConfig,
DifyPluginToolParameter,
DifyPluginToolParameterForm,
DifyPluginToolParameterType,
DifyPluginToolsLayerConfig,
)
from agenton_collections.layers.plain import PromptLayerConfig
from dify_agent.layers.dify_plugin.configs import DifyPluginLLMLayerConfig, DifyPluginLayerConfig
from dify_agent.layers.dify_plugin.llm_layer import DifyPluginLLMLayer
from dify_agent.layers.dify_plugin.tools_layer import DifyPluginToolsLayer
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
from dify_agent.protocol import DIFY_AGENT_HISTORY_LAYER_ID, DIFY_AGENT_MODEL_LAYER_ID, DIFY_AGENT_OUTPUT_LAYER_ID
from dify_agent.protocol.schemas import (
@ -46,20 +34,15 @@ from dify_agent.protocol.schemas import (
RunSucceededEvent,
)
from dify_agent.runtime.event_sink import InMemoryRunEventSink
from dify_agent.runtime.compositor_factory import create_default_layer_providers
from dify_agent.runtime.runner import AgentRunRunner, AgentRunValidationError
class StaticToolsTestLayer(ToolsLayer):
type_id: ClassVar[str] = "test.static.tools"
def _request(
user: str | list[str] = "hello",
*,
include_history: bool = False,
llm_layer_name: str = DIFY_AGENT_MODEL_LAYER_ID,
execution_context_layer_name: str = "execution_context",
plugin_layer_name: str = "plugin",
on_exit: LayerExitSignals | None = None,
output_config: Mapping[str, object] | DifyOutputLayerConfig | None = None,
) -> CreateRunRequest:
@ -75,16 +58,15 @@ def _request(
else []
),
RunLayerSpec(
name=execution_context_layer_name,
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
config=DifyExecutionContextLayerConfig(tenant_id="tenant-1", invoke_from="workflow_run"),
name=plugin_layer_name,
type="dify.plugin",
config=DifyPluginLayerConfig(tenant_id="tenant-1", plugin_id="langgenius/openai"),
),
RunLayerSpec(
name=llm_layer_name,
type="dify.plugin.llm",
deps={"execution_context": execution_context_layer_name},
deps={"plugin": plugin_layer_name},
config=DifyPluginLLMLayerConfig(
plugin_id="langgenius/openai",
model_provider="openai",
model="demo-model",
credentials={"api_key": "secret"},
@ -121,35 +103,6 @@ def _recursive_output_schema() -> dict[str, object]:
}
def _prepared_plugin_tool_parameters() -> list[DifyPluginToolParameter]:
return [
DifyPluginToolParameter(
name="query",
type=DifyPluginToolParameterType.STRING,
form=DifyPluginToolParameterForm.LLM,
required=True,
llm_description="Search query",
),
DifyPluginToolParameter(
name="auth_scope",
type=DifyPluginToolParameterType.STRING,
form=DifyPluginToolParameterForm.FORM,
required=True,
llm_description="Hidden auth scope",
),
]
def _prepared_plugin_tool_schema() -> dict[str, JsonValue]:
return {
"type": "object",
"properties": {
"query": {"type": "string", "description": "Search query"},
},
"required": ["query"],
}
class SequenceOutputTestModel(TestModel):
outputs: list[str | dict[str, Any] | None]
request_count: int
@ -217,7 +170,7 @@ def _history_session_snapshot(
lifecycle_state=LifecycleState.SUSPENDED,
runtime_state=PydanticAIHistoryRuntimeState(messages=messages).model_dump(mode="json"),
),
LayerSessionSnapshot(name="execution_context", lifecycle_state=LifecycleState.SUSPENDED, runtime_state={}),
LayerSessionSnapshot(name="plugin", lifecycle_state=LifecycleState.SUSPENDED, runtime_state={}),
LayerSessionSnapshot(
name=DIFY_AGENT_MODEL_LAYER_ID, lifecycle_state=LifecycleState.SUSPENDED, runtime_state={}
),
@ -245,12 +198,12 @@ def test_runner_emits_terminal_success_and_snapshot(monkeypatch: pytest.MonkeyPa
def fake_get_model(self: DifyPluginLLMLayer, *, http_client: httpx.AsyncClient):
assert self.config.model == "demo-model"
assert self.config.plugin_id == "langgenius/openai"
assert self.deps.plugin.config.plugin_id == "langgenius/openai"
seen_clients.append(http_client)
return TestModel(custom_output_text="done") # pyright: ignore[reportReturnType]
monkeypatch.setattr(DifyPluginLLMLayer, "get_model", fake_get_model)
request = _request(execution_context_layer_name="renamed-execution-context")
request = _request(plugin_layer_name="renamed-plugin")
sink = InMemoryRunEventSink()
async def scenario() -> None:
@ -277,7 +230,7 @@ def test_runner_emits_terminal_success_and_snapshot(monkeypatch: pytest.MonkeyPa
assert terminal.data.output == "done"
assert [layer.name for layer in terminal.data.session_snapshot.layers] == [
"prompt",
"renamed-execution-context",
"renamed-plugin",
DIFY_AGENT_MODEL_LAYER_ID,
]
assert [layer.lifecycle_state for layer in terminal.data.session_snapshot.layers] == [
@ -288,315 +241,6 @@ def test_runner_emits_terminal_success_and_snapshot(monkeypatch: pytest.MonkeyPa
assert sink.statuses["run-1"] == "succeeded"
def test_runner_passes_dynamic_dify_plugin_tools_to_agent(monkeypatch: pytest.MonkeyPatch) -> None:
seen_tools: list[Tool[object]] = []
async def plugin_tool() -> str:
return "tool"
def fake_get_model(_self: DifyPluginLLMLayer, *, http_client: httpx.AsyncClient):
assert http_client.is_closed is False
return TestModel(custom_output_text="done") # pyright: ignore[reportReturnType]
async def fake_get_tools(self: DifyPluginToolsLayer, *, http_client: httpx.AsyncClient) -> list[Tool[object]]:
assert self.config.tools[0].tool_name == "web_search"
assert http_client.is_closed is False
return [Tool(plugin_tool, name="web_search")]
class FakeResult:
output: str = "done"
def new_messages(self) -> list[ModelMessage]:
return []
class FakeAgent:
async def run(self, *_args: object, **_kwargs: object) -> FakeResult:
return FakeResult()
def fake_create_agent(model: object, *, tools: list[Tool[object]], output_type: object) -> FakeAgent:
del model, output_type
seen_tools.extend(tools)
return FakeAgent()
monkeypatch.setattr(DifyPluginLLMLayer, "get_model", fake_get_model)
monkeypatch.setattr(DifyPluginToolsLayer, "get_tools", fake_get_tools)
monkeypatch.setattr("dify_agent.runtime.runner.create_agent", fake_create_agent)
request = CreateRunRequest(
composition=RunComposition(
layers=[
RunLayerSpec(
name="prompt",
type="plain.prompt",
config=PromptLayerConfig(prefix="system", user="hello"),
),
RunLayerSpec(
name="execution_context",
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
config=DifyExecutionContextLayerConfig(tenant_id="tenant-1", invoke_from="workflow_run"),
),
RunLayerSpec(
name=DIFY_AGENT_MODEL_LAYER_ID,
type="dify.plugin.llm",
deps={"execution_context": "execution_context"},
config=DifyPluginLLMLayerConfig(
plugin_id="langgenius/openai",
model_provider="openai",
model="demo-model",
credentials={"api_key": "secret"},
),
),
RunLayerSpec(
name="tools",
type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
deps={"execution_context": "execution_context"},
config=DifyPluginToolsLayerConfig(
tools=[
DifyPluginToolConfig(
plugin_id="langgenius/tools",
provider="search",
tool_name="web_search",
credential_type="api-key",
parameters=_prepared_plugin_tool_parameters(),
parameters_json_schema=_prepared_plugin_tool_schema(),
)
]
),
),
]
)
)
sink = InMemoryRunEventSink()
async def scenario() -> None:
async with httpx.AsyncClient() as client:
await AgentRunRunner(
sink=sink,
request=request,
run_id="run-tools",
plugin_daemon_http_client=client,
).run()
asyncio.run(scenario())
assert [tool.name for tool in seen_tools] == ["web_search"]
terminal = sink.events["run-tools"][-1]
assert isinstance(terminal, RunSucceededEvent)
assert terminal.data.output == "done"
def test_runner_rejects_duplicate_tool_names_across_dynamic_tool_layers(
monkeypatch: pytest.MonkeyPatch,
) -> None:
create_agent_called = False
async def duplicate_tool() -> str:
return "tool"
def fake_get_model(_self: DifyPluginLLMLayer, *, http_client: httpx.AsyncClient):
assert http_client.is_closed is False
return TestModel(custom_output_text="done") # pyright: ignore[reportReturnType]
async def fake_get_tools(_self: DifyPluginToolsLayer, *, http_client: httpx.AsyncClient) -> list[Tool[object]]:
assert http_client.is_closed is False
return [Tool(duplicate_tool, name="shared_tool")]
def fake_create_agent(model: object, *, tools: list[Tool[object]], output_type: object) -> object:
del model, tools, output_type
nonlocal create_agent_called
create_agent_called = True
raise AssertionError("create_agent should not be called when duplicate tool names are detected")
monkeypatch.setattr(DifyPluginLLMLayer, "get_model", fake_get_model)
monkeypatch.setattr(DifyPluginToolsLayer, "get_tools", fake_get_tools)
monkeypatch.setattr("dify_agent.runtime.runner.create_agent", fake_create_agent)
request = CreateRunRequest(
composition=RunComposition(
layers=[
RunLayerSpec(
name="prompt",
type="plain.prompt",
config=PromptLayerConfig(prefix="system", user="hello"),
),
RunLayerSpec(
name="execution_context",
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
config=DifyExecutionContextLayerConfig(tenant_id="tenant-1", invoke_from="workflow_run"),
),
RunLayerSpec(
name=DIFY_AGENT_MODEL_LAYER_ID,
type="dify.plugin.llm",
deps={"execution_context": "execution_context"},
config=DifyPluginLLMLayerConfig(
plugin_id="langgenius/openai",
model_provider="openai",
model="demo-model",
credentials={"api_key": "secret"},
),
),
RunLayerSpec(
name="tools-1",
type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
deps={"execution_context": "execution_context"},
config=DifyPluginToolsLayerConfig(
tools=[
DifyPluginToolConfig(
plugin_id="langgenius/tools",
provider="search",
tool_name="web_search",
credential_type="api-key",
parameters=_prepared_plugin_tool_parameters(),
parameters_json_schema=_prepared_plugin_tool_schema(),
)
]
),
),
RunLayerSpec(
name="tools-2",
type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
deps={"execution_context": "execution_context"},
config=DifyPluginToolsLayerConfig(
tools=[
DifyPluginToolConfig(
plugin_id="langgenius/tools",
provider="search",
tool_name="web_search_two",
credential_type="api-key",
parameters=_prepared_plugin_tool_parameters(),
parameters_json_schema=_prepared_plugin_tool_schema(),
)
]
),
),
]
)
)
sink = InMemoryRunEventSink()
async def scenario() -> None:
async with httpx.AsyncClient() as client:
with pytest.raises(
AgentRunValidationError,
match="unique tool names across all layers, got duplicates: shared_tool",
):
await AgentRunRunner(
sink=sink,
request=request,
run_id="run-duplicate-tools",
plugin_daemon_http_client=client,
).run()
asyncio.run(scenario())
assert create_agent_called is False
assert [event.type for event in sink.events["run-duplicate-tools"]] == ["run_started", "run_failed"]
assert sink.statuses["run-duplicate-tools"] == "failed"
def test_runner_rejects_duplicate_tool_names_between_static_and_dynamic_tools(
monkeypatch: pytest.MonkeyPatch,
) -> None:
create_agent_called = False
def web_search(query: str) -> str:
return query
async def dynamic_duplicate_tool() -> str:
return "tool"
def fake_get_model(_self: DifyPluginLLMLayer, *, http_client: httpx.AsyncClient):
assert http_client.is_closed is False
return TestModel(custom_output_text="done") # pyright: ignore[reportReturnType]
async def fake_get_tools(_self: DifyPluginToolsLayer, *, http_client: httpx.AsyncClient) -> list[Tool[object]]:
assert http_client.is_closed is False
return [Tool(dynamic_duplicate_tool, name="web_search")]
def fake_create_agent(model: object, *, tools: list[Tool[object]], output_type: object) -> object:
del model, tools, output_type
nonlocal create_agent_called
create_agent_called = True
raise AssertionError("create_agent should not be called when duplicate tool names are detected")
monkeypatch.setattr(DifyPluginLLMLayer, "get_model", fake_get_model)
monkeypatch.setattr(DifyPluginToolsLayer, "get_tools", fake_get_tools)
monkeypatch.setattr("dify_agent.runtime.runner.create_agent", fake_create_agent)
static_tools_provider = LayerProvider.from_factory(
layer_type=StaticToolsTestLayer,
create=lambda _config: StaticToolsTestLayer(tool_entries=(web_search,)),
)
layer_providers = (*create_default_layer_providers(), static_tools_provider)
request = CreateRunRequest(
composition=RunComposition(
layers=[
RunLayerSpec(
name="prompt",
type="plain.prompt",
config=PromptLayerConfig(prefix="system", user="hello"),
),
RunLayerSpec(name="static-tools", type=cast(str, StaticToolsTestLayer.type_id)),
RunLayerSpec(
name="execution_context",
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
config=DifyExecutionContextLayerConfig(tenant_id="tenant-1", invoke_from="workflow_run"),
),
RunLayerSpec(
name=DIFY_AGENT_MODEL_LAYER_ID,
type="dify.plugin.llm",
deps={"execution_context": "execution_context"},
config=DifyPluginLLMLayerConfig(
plugin_id="langgenius/openai",
model_provider="openai",
model="demo-model",
credentials={"api_key": "secret"},
),
),
RunLayerSpec(
name="tools",
type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
deps={"execution_context": "execution_context"},
config=DifyPluginToolsLayerConfig(
tools=[
DifyPluginToolConfig(
plugin_id="langgenius/tools",
provider="search",
tool_name="web_search",
credential_type="api-key",
parameters=_prepared_plugin_tool_parameters(),
parameters_json_schema=_prepared_plugin_tool_schema(),
)
]
),
),
]
)
)
sink = InMemoryRunEventSink()
async def scenario() -> None:
async with httpx.AsyncClient() as client:
with pytest.raises(
AgentRunValidationError,
match="unique tool names across all layers, got duplicates: web_search",
):
await AgentRunRunner(
sink=sink,
request=request,
run_id="run-static-dynamic-duplicate-tools",
plugin_daemon_http_client=client,
layer_providers=layer_providers,
).run()
asyncio.run(scenario())
assert create_agent_called is False
assert [event.type for event in sink.events["run-static-dynamic-duplicate-tools"]] == ["run_started", "run_failed"]
assert sink.statuses["run-static-dynamic-duplicate-tools"] == "failed"
def test_runner_passes_temporary_system_prompt_prefix_without_history_layer(monkeypatch: pytest.MonkeyPatch) -> None:
model = RecordingTestModel(custom_output_text="done")
@ -627,7 +271,7 @@ def test_runner_passes_temporary_system_prompt_prefix_without_history_layer(monk
assert isinstance(terminal, RunSucceededEvent)
assert [layer.name for layer in terminal.data.session_snapshot.layers] == [
"prompt",
"execution_context",
"plugin",
DIFY_AGENT_MODEL_LAYER_ID,
]
@ -796,7 +440,7 @@ def test_runner_applies_on_exit_overrides_to_success_snapshot(monkeypatch: pytes
assert isinstance(terminal, RunSucceededEvent)
assert {layer.name: layer.lifecycle_state for layer in terminal.data.session_snapshot.layers} == {
"prompt": LifecycleState.CLOSED,
"execution_context": LifecycleState.SUSPENDED,
"plugin": LifecycleState.SUSPENDED,
DIFY_AGENT_MODEL_LAYER_ID: LifecycleState.CLOSED,
}
@ -834,12 +478,7 @@ def test_runner_passes_output_layer_spec_to_agent_and_serializes_structured_resu
)
)
sink = InMemoryRunEventSink()
expected_snapshot_layer_names = [
"prompt",
"execution_context",
DIFY_AGENT_MODEL_LAYER_ID,
DIFY_AGENT_OUTPUT_LAYER_ID,
]
expected_snapshot_layer_names = ["prompt", "plugin", DIFY_AGENT_MODEL_LAYER_ID, DIFY_AGENT_OUTPUT_LAYER_ID]
async def scenario() -> None:
async with httpx.AsyncClient() as client:
@ -1043,16 +682,15 @@ def test_runner_rejects_misnamed_output_layer_before_model_resolution(monkeypatc
config=PromptLayerConfig(prefix="system", user="hello"),
),
RunLayerSpec(
name="execution_context",
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
config=DifyExecutionContextLayerConfig(tenant_id="tenant-1", invoke_from="workflow_run"),
name="plugin",
type="dify.plugin",
config=DifyPluginLayerConfig(tenant_id="tenant-1", plugin_id="langgenius/openai"),
),
RunLayerSpec(
name=DIFY_AGENT_MODEL_LAYER_ID,
type="dify.plugin.llm",
deps={"execution_context": "execution_context"},
deps={"plugin": "plugin"},
config=DifyPluginLLMLayerConfig(
plugin_id="langgenius/openai",
model_provider="openai",
model="demo-model",
credentials={"api_key": "secret"},
@ -1112,16 +750,15 @@ def test_runner_rejects_multiple_output_layers_before_model_resolution(monkeypat
config=PromptLayerConfig(prefix="system", user="hello"),
),
RunLayerSpec(
name="execution_context",
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
config=DifyExecutionContextLayerConfig(tenant_id="tenant-1", invoke_from="workflow_run"),
name="plugin",
type="dify.plugin",
config=DifyPluginLayerConfig(tenant_id="tenant-1", plugin_id="langgenius/openai"),
),
RunLayerSpec(
name=DIFY_AGENT_MODEL_LAYER_ID,
type="dify.plugin.llm",
deps={"execution_context": "execution_context"},
deps={"plugin": "plugin"},
config=DifyPluginLLMLayerConfig(
plugin_id="langgenius/openai",
model_provider="openai",
model="demo-model",
credentials={"api_key": "secret"},
@ -1203,16 +840,15 @@ def test_runner_rejects_reserved_output_name_with_wrong_layer_type_before_model_
config=PromptLayerConfig(prefix="system", user="hello"),
),
RunLayerSpec(
name="execution_context",
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
config=DifyExecutionContextLayerConfig(tenant_id="tenant-1", invoke_from="workflow_run"),
name="plugin",
type="dify.plugin",
config=DifyPluginLayerConfig(tenant_id="tenant-1", plugin_id="langgenius/openai"),
),
RunLayerSpec(
name=DIFY_AGENT_MODEL_LAYER_ID,
type="dify.plugin.llm",
deps={"execution_context": "execution_context"},
deps={"plugin": "plugin"},
config=DifyPluginLLMLayerConfig(
plugin_id="langgenius/openai",
model_provider="openai",
model="demo-model",
credentials={"api_key": "secret"},
@ -1406,7 +1042,7 @@ def test_runner_rejects_closed_session_snapshot_as_validation_error() -> None:
runtime_state={},
),
LayerSessionSnapshot(
name="execution_context",
name="plugin",
lifecycle_state=LifecycleState.NEW,
runtime_state={},
),

View File

@ -6,9 +6,9 @@ import pytest
from fastapi.testclient import TestClient
import dify_agent.server.app as app_module
from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig
from dify_agent.layers.execution_context.layer import DifyExecutionContextLayer
from dify_agent.runtime.compositor_factory import DifyAgentLayerProvider
from dify_agent.layers.dify_plugin.configs import DifyPluginLayerConfig
from dify_agent.layers.dify_plugin.plugin_layer import DifyPluginLayer
from dify_agent.server.app import create_app, create_plugin_daemon_http_client
from dify_agent.server.settings import ServerSettings
from dify_agent.storage.redis_run_store import RedisRunStore
@ -148,15 +148,11 @@ def test_create_app_creates_scheduler_and_closes_after_shutdown(monkeypatch: pyt
assert scheduler.shutdown_grace_seconds == 5
layer_providers = scheduler.layer_providers
assert isinstance(layer_providers, tuple)
execution_context_provider = next(
provider for provider in layer_providers if provider.type_id == "dify.execution_context"
)
execution_context_layer = execution_context_provider.create_layer(
DifyExecutionContextLayerConfig(tenant_id="tenant-1", invoke_from="workflow_run")
)
assert isinstance(execution_context_layer, DifyExecutionContextLayer)
assert execution_context_layer.daemon_url == "http://plugin-daemon"
assert execution_context_layer.daemon_api_key == "daemon-secret"
plugin_provider = next(provider for provider in layer_providers if provider.type_id == "dify.plugin")
plugin_layer = plugin_provider.create_layer(DifyPluginLayerConfig(tenant_id="tenant-1", plugin_id="plugin-1"))
assert isinstance(plugin_layer, DifyPluginLayer)
assert plugin_layer.daemon_url == "http://plugin-daemon"
assert plugin_layer.daemon_api_key == "daemon-secret"
http_client = scheduler.plugin_daemon_http_client
assert http_client is fake_http_client
assert http_client.is_closed is False

View File

@ -1,7 +1,7 @@
from fastapi.testclient import TestClient
from dify_agent.protocol import DIFY_AGENT_MODEL_LAYER_ID
from dify_agent.runtime.run_scheduler import SchedulerStoppingError
from dify_agent.runtime.run_scheduler import RunRequestValidationError, SchedulerStoppingError
from dify_agent.server.routes.runs import create_runs_router
from dify_agent.server.schemas import RunRecord
@ -9,14 +9,14 @@ from dify_agent.server.schemas import RunRecord
class FakeScheduler:
async def create_run(self, request: object) -> object:
del request
return RunRecord(run_id="run-1", status="running")
raise RunRequestValidationError("run.user_prompts must not be empty")
class FakeStore:
pass
def test_create_run_accepts_effectively_blank_user_prompt_list() -> None:
def test_create_run_rejects_effectively_blank_user_prompt_list() -> None:
from fastapi import FastAPI
app = FastAPI()
@ -35,8 +35,8 @@ def test_create_run_accepts_effectively_blank_user_prompt_list() -> None:
},
)
assert response.status_code == 202
assert response.json() == {"run_id": "run-1", "status": "running"}
assert response.status_code == 422
assert response.json()["detail"] == "run.user_prompts must not be empty"
def test_create_run_returns_running_from_scheduler() -> None:
@ -104,16 +104,15 @@ def test_create_run_accepts_valid_full_plugin_graph() -> None:
"layers": [
{"name": "prompt", "type": "plain.prompt", "config": {"user": "hello"}},
{
"name": "execution-context-renamed",
"type": "dify.execution_context",
"config": {"tenant_id": "tenant-1", "invoke_from": "workflow_run"},
"name": "plugin-renamed",
"type": "dify.plugin",
"config": {"tenant_id": "tenant-1", "plugin_id": "langgenius/openai"},
},
{
"name": DIFY_AGENT_MODEL_LAYER_ID,
"type": "dify.plugin.llm",
"deps": {"execution_context": "execution-context-renamed"},
"deps": {"plugin": "plugin-renamed"},
"config": {
"plugin_id": "langgenius/openai",
"model_provider": "openai",
"model": "gpt-4o-mini",
"credentials": {"api_key": "secret"},
@ -129,12 +128,17 @@ def test_create_run_accepts_valid_full_plugin_graph() -> None:
assert response.json() == {"run_id": "run-1", "status": "running"}
def test_create_run_accepts_unknown_layer_exit_signal_request() -> None:
def test_create_run_rejects_unknown_layer_exit_signal_before_scheduling() -> None:
from fastapi import FastAPI
class UnknownSignalScheduler:
async def create_run(self, request: object) -> RunRecord:
del request
raise RunRequestValidationError("on_exit.layers references unknown layer ids: missing.")
app = FastAPI()
app.include_router(
create_runs_router(lambda: FakeStore(), lambda: FakeScheduler()) # pyright: ignore[reportArgumentType]
create_runs_router(lambda: FakeStore(), lambda: UnknownSignalScheduler()) # pyright: ignore[reportArgumentType]
)
client = TestClient(app)
@ -149,16 +153,21 @@ def test_create_run_accepts_unknown_layer_exit_signal_request() -> None:
},
)
assert response.status_code == 202
assert response.json() == {"run_id": "run-1", "status": "running"}
assert response.status_code == 422
assert "missing" in response.json()["detail"]
def test_create_run_accepts_closed_session_snapshot_request() -> None:
def test_create_run_rejects_closed_session_snapshot_with_422() -> None:
from fastapi import FastAPI
class ClosedSnapshotScheduler:
async def create_run(self, request: object) -> RunRecord:
del request
raise RunRequestValidationError("Layer 'prompt' is closed; CLOSED snapshots cannot be entered.")
app = FastAPI()
app.include_router(
create_runs_router(lambda: FakeStore(), lambda: FakeScheduler()) # pyright: ignore[reportArgumentType]
create_runs_router(lambda: FakeStore(), lambda: ClosedSnapshotScheduler()) # pyright: ignore[reportArgumentType]
)
client = TestClient(app)
@ -182,8 +191,8 @@ def test_create_run_accepts_closed_session_snapshot_request() -> None:
},
)
assert response.status_code == 202
assert response.json() == {"run_id": "run-1", "status": "running"}
assert response.status_code == 422
assert "CLOSED snapshots cannot be entered" in response.json()["detail"]
def test_create_run_returns_503_when_scheduler_is_stopping() -> None:

View File

@ -79,9 +79,8 @@ def test_protocol_and_dify_plugin_exports_do_not_import_server_only_modules() ->
blocked_imports=[
"anthropic",
"dify_agent.adapters.llm",
"dify_agent.layers.execution_context.layer",
"dify_agent.layers.dify_plugin.llm_layer",
"dify_agent.layers.dify_plugin.tools_layer",
"dify_agent.layers.dify_plugin.plugin_layer",
"dify_agent.layers.output.output_layer",
"dify_agent.runtime",
"dify_agent.server",
@ -92,16 +91,10 @@ def test_protocol_and_dify_plugin_exports_do_not_import_server_only_modules() ->
"pydantic_settings",
"redis",
],
imports=[
"dify_agent.protocol",
"dify_agent.layers.execution_context",
"dify_agent.layers.dify_plugin",
"dify_agent.layers.output",
],
imports=["dify_agent.protocol", "dify_agent.layers.dify_plugin", "dify_agent.layers.output"],
assertions=[
"assert hasattr(dify_agent_protocol, 'PydanticAIStreamRunEvent')",
"assert dify_agent_layers_execution_context.__all__ == ['DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID', 'DifyExecutionContextInvokeFrom', 'DifyExecutionContextLayerConfig']",
"assert dify_agent_layers_dify_plugin.__all__ == ['DIFY_PLUGIN_LLM_LAYER_TYPE_ID', 'DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID', 'DifyPluginCredentialValue', 'DifyPluginLLMLayerConfig', 'DifyPluginToolCredentialType', 'DifyPluginToolConfig', 'DifyPluginToolOption', 'DifyPluginToolParameter', 'DifyPluginToolParameterForm', 'DifyPluginToolParameterType', 'DifyPluginToolsLayerConfig', 'DifyPluginToolValue']",
"assert dify_agent_layers_dify_plugin.__all__ == ['DIFY_PLUGIN_LAYER_TYPE_ID', 'DIFY_PLUGIN_LLM_LAYER_TYPE_ID', 'DifyPluginCredentialValue', 'DifyPluginLLMLayerConfig', 'DifyPluginLayerConfig']",
"assert dify_agent_layers_output.__all__ == ['DIFY_OUTPUT_LAYER_TYPE_ID', 'DifyOutputLayerConfig']",
],
)

View File

@ -1757,6 +1757,14 @@
"count": 1
}
},
"web/app/components/base/textarea/index.stories.tsx": {
"no-console": {
"count": 1
},
"ts/no-explicit-any": {
"count": 1
}
},
"web/app/components/base/voice-input/__tests__/index.spec.tsx": {
"ts/no-explicit-any": {
"count": 3

View File

@ -34,7 +34,6 @@ import { FieldControl, FieldLabel, FieldRoot } from '@langgenius/dify-ui/field'
import { Form } from '@langgenius/dify-ui/form'
import { Popover, PopoverContent, PopoverTrigger } from '@langgenius/dify-ui/popover'
import { SegmentedControl, SegmentedControlItem } from '@langgenius/dify-ui/segmented-control'
import { Textarea } from '@langgenius/dify-ui/textarea'
import '@langgenius/dify-ui/styles.css' // once, in the app root
```
@ -42,17 +41,17 @@ Importing from `@langgenius/dify-ui` (no subpath) is intentionally not supported
## Primitives
| Category | Subpath | Notes |
| ---------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------ |
| Actions | `./button` | Design-system CTA primitive with `cva` variants. |
| Controls | `./segmented-control` | SegmentedControl for mode, filter, and view selection. |
| Feedback | `./meter`, `./toast` | Meter is inline status; Toast owns the `z-60` layer. |
| Form | `./form`, `./field`, `./fieldset`, `./input`, `./textarea`, `./checkbox`, `./checkbox-group`, `./radio`, `./radio-group`, `./number-field`, `./select`, `./slider`, `./switch` | Native form boundary, field semantics, and controls. |
| Layout | `./scroll-area` | Custom-styled scrollbar over the host viewport. |
| Media | `./avatar` | Avatar root, image, and fallback primitives. |
| Navigation | `./pagination`, `./tabs` | Pagination for page navigation; Tabs for panels. |
| Overlay / menu | `./alert-dialog`, `./context-menu`, `./dialog`, `./drawer`, `./dropdown-menu`, `./popover`, `./preview-card`, `./tooltip` | Portalled. See [Overlay & portal contract] below. |
| Search / pickers | `./autocomplete`, `./combobox`, `./select` | Search input, searchable picker, and closed picker. |
| Category | Subpath | Notes |
| ---------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------ |
| Actions | `./button` | Design-system CTA primitive with `cva` variants. |
| Controls | `./segmented-control` | SegmentedControl for mode, filter, and view selection. |
| Feedback | `./meter`, `./toast` | Meter is inline status; Toast owns the `z-60` layer. |
| Form | `./form`, `./field`, `./fieldset`, `./input`, `./checkbox`, `./checkbox-group`, `./radio`, `./radio-group`, `./number-field`, `./select`, `./slider`, `./switch` | Native form boundary, field semantics, and controls. |
| Layout | `./scroll-area` | Custom-styled scrollbar over the host viewport. |
| Media | `./avatar` | Avatar root, image, and fallback primitives. |
| Navigation | `./pagination`, `./tabs` | Pagination for page navigation; Tabs for panels. |
| Overlay / menu | `./alert-dialog`, `./context-menu`, `./dialog`, `./drawer`, `./dropdown-menu`, `./popover`, `./preview-card`, `./tooltip` | Portalled. See [Overlay & portal contract] below. |
| Search / pickers | `./autocomplete`, `./combobox`, `./select` | Search input, searchable picker, and closed picker. |
Utilities:
@ -73,7 +72,7 @@ Use `Form` for the submit boundary. It renders a native `<form>`, preserves Ente
Use `FieldRoot` for each standalone named field. A field must have a stable `name`, a label relationship, and either a `FieldControl` or another control that participates in the same Base UI field context. Prefer a visible label for normal form rows; when the surrounding UI already supplies the visible text, use the matching label primitive visually hidden or put `aria-label` on the actual interactive control. `FieldDescription` and `FieldError` provide the message relationships that screen readers need, while the Dify wrapper adds the default Form Input Set styling from the design system.
Choose the label primitive by the control semantics. Text-like inputs, `Textarea`, input-based `Combobox` / `Autocomplete`, single `Checkbox` / `Radio`, `Switch`, and `NumberField` use `FieldLabel`. Trigger-based `Select` fields use `SelectLabel`; `Slider` fields use `SliderLabel`, with per-thumb `aria-label` only when the thumbs need distinct names. `SelectGroupLabel` and `AutocompleteGroupLabel` only label grouped options inside their popup content; they are not field labels.
Choose the label primitive by the control semantics. Text-like inputs, input-based `Combobox` / `Autocomplete`, single `Checkbox` / `Radio`, `Switch`, and `NumberField` use `FieldLabel`. Trigger-based `Select` fields use `SelectLabel`; `Slider` fields use `SliderLabel`, with per-thumb `aria-label` only when the thumbs need distinct names. `SelectGroupLabel` and `AutocompleteGroupLabel` only label grouped options inside their popup content; they are not field labels.
Use `FieldsetRoot` and `FieldsetLegend` when one field is represented by a group of related controls, such as checkbox groups, radio groups, multi-thumb sliders, or a section that combines several inputs. For checkbox and radio groups, wrap each option with `FieldItem` and give each option its own label:

View File

@ -129,10 +129,6 @@
"types": "./src/tabs/index.tsx",
"import": "./src/tabs/index.tsx"
},
"./textarea": {
"types": "./src/textarea/index.tsx",
"import": "./src/textarea/index.tsx"
},
"./toast": {
"types": "./src/toast/index.tsx",
"import": "./src/toast/index.tsx"

View File

@ -1,187 +0,0 @@
import type { FocusEvent } from 'react'
import { render } from 'vitest-browser-react'
import {
FieldDescription,
FieldError,
FieldLabel,
FieldRoot,
} from '../../field'
import { Form } from '../../form'
import { Textarea } from '../index'
const asHTMLElement = (element: HTMLElement | SVGElement) => element as HTMLElement
const setTextareaValue = (element: HTMLElement | SVGElement, value: string) => {
const textarea = asHTMLElement(element) as HTMLTextAreaElement
const valueSetter = Object.getOwnPropertyDescriptor(window.HTMLTextAreaElement.prototype, 'value')?.set
valueSetter?.call(textarea, value)
textarea.dispatchEvent(new Event('input', { bubbles: true }))
}
describe('Textarea', () => {
it('should render a labelled textarea through Base UI Field.Control', async () => {
const screen = await render(
<FieldRoot name="description">
<FieldLabel>Description</FieldLabel>
<Textarea defaultValue="A workspace for support automation." />
<FieldDescription>Shown to workspace members.</FieldDescription>
</FieldRoot>,
)
const textarea = screen.getByRole('textbox', { name: 'Description' })
await expect.element(textarea).toHaveValue('A workspace for support automation.')
await expect.element(textarea).toHaveAccessibleDescription('Shown to workspace members.')
await expect.element(textarea).toHaveClass('min-h-20', 'overflow-auto', 'rounded-lg', 'system-sm-regular')
expect(asHTMLElement(textarea.element()).tagName).toBe('TEXTAREA')
})
it('should apply size variants and custom classes', async () => {
const screen = await render(
<label>
Prompt
<Textarea size="large" className="resize-none" />
</label>,
)
await expect.element(screen.getByRole('textbox', { name: 'Prompt' })).toHaveClass(
'rounded-[10px]',
'px-4',
'py-2',
'system-md-regular',
'resize-none',
)
})
it('should call onValueChange and stay controlled until value changes', async () => {
const onValueChange = vi.fn()
const screen = await render(
<label>
Notes
<Textarea value="" onValueChange={onValueChange} />
</label>,
)
const textarea = screen.getByRole('textbox', { name: 'Notes' })
setTextareaValue(textarea.element(), 'a')
expect(onValueChange).toHaveBeenCalledWith('a', expect.any(Object))
await expect.element(textarea).toHaveValue('')
await screen.rerender(
<label>
Notes
<Textarea value="a" onValueChange={onValueChange} />
</label>,
)
await expect.element(screen.getByRole('textbox', { name: 'Notes' })).toHaveValue('a')
})
it('should submit valid values and show validation errors through Base UI Form', async () => {
const onFormSubmit = vi.fn()
const screen = await render(
<Form aria-label="dataset form" onFormSubmit={onFormSubmit}>
<FieldRoot name="summary">
<FieldLabel>Summary</FieldLabel>
<Textarea required minLength={10} />
<FieldError match="valueMissing">Summary is required.</FieldError>
<FieldError match="tooShort">Summary is too short.</FieldError>
</FieldRoot>
<button type="submit">Save</button>
</Form>,
)
const saveButton = asHTMLElement(screen.getByRole('button', { name: 'Save' }).element())
saveButton.click()
await vi.waitFor(async () => {
await expect.element(screen.getByText('Summary is required.')).toBeInTheDocument()
await expect.element(screen.getByRole('textbox', { name: 'Summary' })).toHaveAttribute('aria-invalid', 'true')
})
expect(onFormSubmit).not.toHaveBeenCalled()
await screen.rerender(
<Form aria-label="dataset form" onFormSubmit={onFormSubmit}>
<FieldRoot name="summary">
<FieldLabel>Summary</FieldLabel>
<Textarea key="valid-summary" required minLength={10} defaultValue="Long enough summary" />
<FieldError match="valueMissing">Summary is required.</FieldError>
<FieldError match="tooShort">Summary is too short.</FieldError>
</FieldRoot>
<button type="submit">Save</button>
</Form>,
)
asHTMLElement(screen.getByRole('button', { name: 'Save' }).element()).click()
expect(onFormSubmit).toHaveBeenCalledTimes(1)
expect(onFormSubmit.mock.calls[0]?.[0]).toMatchObject({ summary: 'Long enough summary' })
})
it('should pass maxLength to the textarea without rendering a counter', async () => {
const screen = await render(
<label>
Release notes
<Textarea defaultValue="Draft" maxLength={20} />
</label>,
)
const textarea = screen.getByRole('textbox', { name: 'Release notes' })
await expect.element(textarea).toHaveAttribute('maxLength', '20')
expect(screen.container.textContent).not.toContain('5/20')
})
it('should route field props through Base UI Field.Control and textarea-only props to textarea', async () => {
const onFormSubmit = vi.fn()
const onBlur = vi.fn((event: FocusEvent<HTMLTextAreaElement>) => {
expect(event.currentTarget.tagName).toBe('TEXTAREA')
})
const screen = await render(
<Form aria-label="profile form" onFormSubmit={onFormSubmit}>
<FieldRoot name="profileSummary">
<FieldLabel>Profile summary</FieldLabel>
<Textarea
id="profile-summary"
name="ignoredControlName"
defaultValue="Long enough summary"
rows={6}
cols={40}
wrap="soft"
maxLength={80}
onBlur={onBlur}
/>
</FieldRoot>
<FieldRoot disabled>
<FieldLabel>Disabled note</FieldLabel>
<Textarea name="disabledNote" defaultValue="Disabled value" />
</FieldRoot>
<button type="submit">Save</button>
</Form>,
)
const profileSummary = screen.getByRole('textbox', { name: 'Profile summary' })
expect(
asHTMLElement(screen.getByText('Profile summary').element()).getAttribute('for'),
).toBe('profile-summary')
await expect.element(profileSummary).toHaveAttribute('id', 'profile-summary')
await expect.element(profileSummary).toHaveAttribute('name', 'profileSummary')
await expect.element(profileSummary).toHaveAttribute('rows', '6')
await expect.element(profileSummary).toHaveAttribute('cols', '40')
await expect.element(profileSummary).toHaveAttribute('wrap', 'soft')
await expect.element(profileSummary).toHaveAttribute('maxLength', '80')
await expect.element(screen.getByRole('textbox', { name: 'Disabled note' })).toBeDisabled()
asHTMLElement(profileSummary.element()).focus()
const saveButton = asHTMLElement(screen.getByRole('button', { name: 'Save' }).element())
saveButton.focus()
expect(onBlur).toHaveBeenCalledTimes(1)
saveButton.click()
expect(onFormSubmit).toHaveBeenCalledTimes(1)
expect(onFormSubmit.mock.calls[0]?.[0]).toMatchObject({
profileSummary: 'Long enough summary',
})
expect(onFormSubmit.mock.calls[0]?.[0]).not.toHaveProperty('ignoredControlName')
expect(onFormSubmit.mock.calls[0]?.[0]).not.toHaveProperty('disabledNote')
})
})

View File

@ -1,193 +0,0 @@
import type { Meta, StoryObj } from '@storybook/react-vite'
import { useState } from 'react'
import { Button } from '../button'
import {
FieldDescription,
FieldError,
FieldLabel,
FieldRoot,
} from '../field'
import { Form } from '../form'
import { Textarea } from './index'
const meta = {
title: 'Base/Form/Textarea',
component: Textarea,
parameters: {
layout: 'centered',
docs: {
description: {
component: 'Multiline text control built on Base UI Field.Control. Use it with FieldRoot for labelled, described, and validated form fields.',
},
},
},
tags: ['autodocs'],
} satisfies Meta<typeof Textarea>
export default meta
type Story = StoryObj<typeof meta>
export const Basic: Story = {
render: () => (
<div className="w-80">
<label htmlFor="workspace-description" className="mb-1 block w-fit py-1 text-text-secondary system-sm-medium">
Workspace description
</label>
<Textarea
id="workspace-description"
name="workspaceDescription"
placeholder="Describe how this workspace is used..."
/>
</div>
),
}
export const Sizes: Story = {
render: () => (
<div className="grid w-80 gap-3">
<label className="grid gap-1 text-text-secondary system-sm-medium" htmlFor="small-textarea">
Small
<Textarea id="small-textarea" size="small" name="smallTextarea" placeholder="Short note..." rows={3} />
</label>
<label className="grid gap-1 text-text-secondary system-sm-medium" htmlFor="medium-textarea">
Medium
<Textarea id="medium-textarea" name="mediumTextarea" placeholder="Add context..." rows={3} />
</label>
<label className="grid gap-1 text-text-secondary system-sm-medium" htmlFor="large-textarea">
Large
<Textarea id="large-textarea" size="large" name="largeTextarea" placeholder="Write a longer instruction..." rows={3} />
</label>
</div>
),
}
export const States: Story = {
render: () => (
<div className="grid w-80 gap-3">
<FieldRoot name="placeholderState">
<FieldLabel>Placeholder</FieldLabel>
<Textarea placeholder="Add a description..." rows={3} />
</FieldRoot>
<FieldRoot name="filledState">
<FieldLabel>Filled</FieldLabel>
<Textarea defaultValue="Use this dataset for support articles and product FAQs." rows={3} />
</FieldRoot>
<FieldRoot name="invalidState" invalid>
<FieldLabel>Invalid</FieldLabel>
<Textarea defaultValue="Too short" rows={3} />
<FieldError match>Use at least 20 characters.</FieldError>
</FieldRoot>
<FieldRoot name="disabledState">
<FieldLabel>Disabled</FieldLabel>
<Textarea disabled placeholder="Editing is unavailable..." rows={3} />
</FieldRoot>
<FieldRoot name="readonlyState">
<FieldLabel>Read-only</FieldLabel>
<Textarea readOnly defaultValue="Generated from the published workflow configuration." rows={3} />
</FieldRoot>
</div>
),
}
const FormDemo = () => {
const [savedDescription, setSavedDescription] = useState<string | null>(null)
return (
<Form
aria-label="Dataset settings"
className="grid w-80 gap-4"
onFormSubmit={(values) => {
setSavedDescription(String(values.description ?? ''))
}}
>
<FieldRoot name="description">
<FieldLabel>Description</FieldLabel>
<Textarea
required
minLength={20}
maxLength={160}
placeholder="Describe what this dataset contains..."
rows={4}
className="resize-y"
/>
<FieldDescription>Shown to teammates when they choose a knowledge source.</FieldDescription>
<FieldError match="valueMissing">Description is required.</FieldError>
<FieldError match="tooShort">Use at least 20 characters.</FieldError>
</FieldRoot>
<div className="flex justify-end">
<Button type="submit" variant="primary">Save Settings</Button>
</div>
{savedDescription && (
<div className="rounded-lg bg-background-section px-3 py-2 text-text-secondary system-xs-regular">
Saved:
{' '}
{savedDescription}
</div>
)}
</Form>
)
}
export const WithField: Story = {
render: () => <FormDemo />,
}
const ControlledDemo = () => {
const [value, setValue] = useState('Summarize customer feedback into actionable product themes.')
return (
<FieldRoot name="prompt">
<FieldLabel>Prompt</FieldLabel>
<Textarea
value={value}
onValueChange={nextValue => setValue(nextValue)}
rows={4}
className="resize-y"
/>
<FieldDescription>The saved value is updated from the controlled state.</FieldDescription>
</FieldRoot>
)
}
export const Controlled: Story = {
render: () => (
<div className="w-80">
<ControlledDemo />
</div>
),
}
const CharacterCounterDemo = () => {
const maxLength = 120
const [value, setValue] = useState('Summarize customer feedback into actionable product themes.')
return (
<FieldRoot name="limitedPrompt">
<FieldLabel>Prompt</FieldLabel>
<div className="relative">
<Textarea
value={value}
onValueChange={nextValue => setValue(nextValue)}
maxLength={maxLength}
rows={4}
className="resize-y pb-8"
/>
<div className="pointer-events-none absolute right-2 bottom-2 flex h-5 items-center rounded-md bg-background-section px-1 text-text-quaternary system-xs-medium">
<span>{value.length}</span>
/
<span className="text-text-tertiary">{maxLength}</span>
</div>
</div>
<FieldDescription>Character counters are composed at the usage site when the workflow needs one.</FieldDescription>
</FieldRoot>
)
}
export const WithCharacterCounter: Story = {
render: () => (
<div className="w-80">
<CharacterCounterDemo />
</div>
),
}

View File

@ -1,103 +0,0 @@
'use client'
import type { Field as BaseFieldNS } from '@base-ui/react/field'
import type { VariantProps } from 'class-variance-authority'
import type { ComponentPropsWithRef } from 'react'
import { Field as BaseField } from '@base-ui/react/field'
import { cva } from 'class-variance-authority'
import { cn } from '../cn'
const textareaVariants = cva(
[
'min-h-20 w-full appearance-none overflow-auto border border-transparent bg-components-input-bg-normal text-components-input-text-filled caret-primary-600 outline-hidden transition-[background-color,border-color,box-shadow]',
'placeholder:text-components-input-text-placeholder',
'hover:border-components-input-border-hover hover:bg-components-input-bg-hover',
'focus:border-components-input-border-active focus:bg-components-input-bg-active focus:shadow-xs',
'data-invalid:border-components-input-border-destructive data-invalid:bg-components-input-bg-destructive',
'read-only:cursor-default read-only:shadow-none read-only:hover:border-transparent read-only:hover:bg-components-input-bg-normal read-only:focus:border-transparent read-only:focus:bg-components-input-bg-normal read-only:focus:shadow-none',
'disabled:cursor-not-allowed disabled:border-transparent disabled:bg-components-input-bg-disabled disabled:text-components-input-text-filled-disabled',
'disabled:hover:border-transparent disabled:hover:bg-components-input-bg-disabled',
'motion-reduce:transition-none',
],
{
variants: {
size: {
small: 'rounded-md px-2 py-1 system-xs-regular',
medium: 'rounded-lg px-3 py-2 system-sm-regular',
large: 'rounded-[10px] px-4 py-2 system-md-regular',
},
},
defaultVariants: {
size: 'medium',
},
},
)
type TextareaValue = string | number
export type TextareaSize = NonNullable<VariantProps<typeof textareaVariants>['size']>
export type TextareaChangeEventDetails = BaseFieldNS.Control.ChangeEventDetails
type TextareaOnValueChange = (value: string, eventDetails: TextareaChangeEventDetails) => void
type ControlledTextareaProps = {
value: TextareaValue
defaultValue?: never
onValueChange: TextareaOnValueChange
}
type UncontrolledTextareaProps = {
value?: never
defaultValue?: TextareaValue
onValueChange?: TextareaOnValueChange
}
type TextareaNativeProps = ComponentPropsWithRef<'textarea'>
type TextareaOnlyProps = Pick<TextareaNativeProps, 'cols' | 'rows' | 'wrap'>
type TextareaElementProps = Omit<
TextareaNativeProps,
'children' | 'className' | 'cols' | 'defaultValue' | 'onChange' | 'rows' | 'size' | 'value' | 'wrap'
>
type TextareaControlProps = ControlledTextareaProps | UncontrolledTextareaProps
type TextareaVariantProps = VariantProps<typeof textareaVariants>
type FieldControlTextareaProps = Omit<
BaseFieldNS.Control.Props,
'className' | 'defaultValue' | 'onValueChange' | 'render' | 'value'
>
export type TextareaProps
= TextareaElementProps
& TextareaOnlyProps
& TextareaControlProps
& TextareaVariantProps
& {
children?: never
className?: string
}
export function Textarea({
className,
cols,
defaultValue,
onValueChange,
ref,
rows,
size = 'medium',
value,
wrap,
...controlProps
}: TextareaProps) {
// Base UI types Field.Control as an input even when render replaces it with a textarea.
const fieldControlProps = controlProps as FieldControlTextareaProps
return (
<BaseField.Control
{...fieldControlProps}
className={cn(textareaVariants({ size }), className)}
defaultValue={defaultValue}
onValueChange={onValueChange}
ref={ref}
render={<textarea cols={cols} rows={rows} wrap={wrap} />}
value={value}
/>
)
}

View File

@ -493,8 +493,8 @@ describe('Capacity Full Components Integration', () => {
expect(screen.getByText(/upgradeBtn\.encourageShort/i)).toBeInTheDocument()
// Should show usage/total fraction "5/5"
expect(screen.getByText(/5\/5/)).toBeInTheDocument()
// Should have an accessible meter rendered
expect(screen.getByRole('meter', { name: /usagePage\.buildApps/i })).toBeInTheDocument()
// Should have a meter rendered
expect(screen.getByRole('meter')).toBeInTheDocument()
})
it('should display upgrade tip and upgrade button for professional plan', () => {

View File

@ -1,10 +1,10 @@
'use client'
import { Button } from '@langgenius/dify-ui/button'
import { Dialog, DialogContent, DialogTitle } from '@langgenius/dify-ui/dialog'
import { Textarea } from '@langgenius/dify-ui/textarea'
import { toast } from '@langgenius/dify-ui/toast'
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Textarea from '@/app/components/base/textarea'
import { useAppContext } from '@/context/app-context'
import { useRouter } from '@/next/navigation'
import { useLogout } from '@/service/use-common'
@ -63,12 +63,11 @@ export default function FeedBack(props: DeleteAccountProps) {
</DialogTitle>
<label className="mt-3 mb-1 flex items-center system-sm-semibold text-text-secondary">{t('account.feedbackLabel', { ns: 'common' })}</label>
<Textarea
aria-label={t('account.feedbackLabel', { ns: 'common' }) as string}
rows={6}
value={userFeedback}
placeholder={t('account.feedbackPlaceholder', { ns: 'common' }) as string}
onValueChange={(value) => {
setUserFeedback(value)
onChange={(e) => {
setUserFeedback(e.target.value)
}}
/>
<div className="mt-3 flex w-full flex-col gap-2">

View File

@ -1,9 +1,9 @@
'use client'
import type { FC } from 'react'
import { Textarea } from '@langgenius/dify-ui/textarea'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import { Robot, User } from '@/app/components/base/icons/src/public/avatar'
import Textarea from '@/app/components/base/textarea'
export enum EditItemType {
Query = 'query',
@ -33,9 +33,8 @@ const EditItem: FC<Props> = ({
<div className="grow">
<div className="mb-1 system-xs-semibold text-text-primary">{name}</div>
<Textarea
aria-label={name}
value={content}
onValueChange={value => onChange(value)}
onChange={(e: React.ChangeEvent<HTMLTextAreaElement>) => onChange(e.target.value)}
placeholder={placeholder}
autoFocus
/>

View File

@ -2,12 +2,12 @@
import type { FC } from 'react'
import { Button } from '@langgenius/dify-ui/button'
import { cn } from '@langgenius/dify-ui/cn'
import { Textarea } from '@langgenius/dify-ui/textarea'
import { RiDeleteBinLine, RiEditFill, RiEditLine } from '@remixicon/react'
import * as React from 'react'
import { useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { Robot, User } from '@/app/components/base/icons/src/public/avatar'
import Textarea from '@/app/components/base/textarea'
export enum EditItemType {
Query = 'query',
@ -130,9 +130,8 @@ const EditItem: FC<Props> = ({
<div className="mt-3">
<EditTitle title={editTitle} />
<Textarea
aria-label={editTitle}
value={newContent}
onValueChange={value => setNewContent(value)}
onChange={(e: React.ChangeEvent<HTMLTextAreaElement>) => setNewContent(e.target.value)}
placeholder={placeholder}
autoFocus
/>

View File

@ -3,12 +3,12 @@ import type { VersionHistory } from '@/types/workflow'
import { Button } from '@langgenius/dify-ui/button'
import { Dialog, DialogContent } from '@langgenius/dify-ui/dialog'
import { FieldControl, FieldLabel, FieldRoot } from '@langgenius/dify-ui/field'
import { Textarea } from '@langgenius/dify-ui/textarea'
import { toast } from '@langgenius/dify-ui/toast'
import { RiCloseLine } from '@remixicon/react'
import * as React from 'react'
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Textarea from '../../base/textarea'
type VersionInfoModalProps = {
isOpen: boolean
@ -57,8 +57,8 @@ const VersionInfoModal: FC<VersionInfoModalProps> = ({
onClose()
}
const handleDescriptionChange = useCallback((value: string) => {
setReleaseNotes(value)
const handleDescriptionChange = useCallback((e: React.ChangeEvent<HTMLTextAreaElement>) => {
setReleaseNotes(e.target.value)
}, [])
return (
@ -95,16 +95,17 @@ const VersionInfoModal: FC<VersionInfoModalProps> = ({
onValueChange={setTitle}
/>
</FieldRoot>
<FieldRoot name="releaseNotes" invalid={releaseNotesError} className="gap-y-1">
<FieldLabel className="flex h-6 items-center py-0 system-sm-semibold text-text-secondary">
<div className="flex flex-col gap-y-1">
<div className="flex h-6 items-center system-sm-semibold text-text-secondary">
{t('versionHistory.editField.releaseNotes', { ns: 'workflow' })}
</FieldLabel>
</div>
<Textarea
value={releaseNotes}
placeholder={`${t('versionHistory.releaseNotesPlaceholder', { ns: 'workflow' })}${t('panel.optional', { ns: 'workflow' })}`}
onValueChange={handleDescriptionChange}
onChange={handleDescriptionChange}
destructive={releaseNotesError}
/>
</FieldRoot>
</div>
</div>
<div className="flex justify-end p-6 pt-5">
<div className="flex items-center gap-x-3">

View File

@ -13,12 +13,12 @@ import {
SelectTrigger,
SelectValue,
} from '@langgenius/dify-ui/select'
import { Textarea } from '@langgenius/dify-ui/textarea'
import * as React from 'react'
import { Trans } from 'react-i18next'
import { FileUploaderInAttachmentWrapper } from '@/app/components/base/file-uploader'
import { Infotip } from '@/app/components/base/infotip'
import Input from '@/app/components/base/input'
import Textarea from '@/app/components/base/textarea'
import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor'
import FileUploadSetting from '@/app/components/workflow/nodes/_base/components/file-upload-setting'
import { CodeLanguage } from '@/app/components/workflow/nodes/code/types'
@ -121,9 +121,8 @@ const ConfigModalFormFields: FC<ConfigModalFormFieldsProps> = ({
{type === InputVarType.paragraph && (
<Field title={t('variableConfig.defaultValue', { ns: 'appDebug' })}>
<Textarea
aria-label={t('variableConfig.defaultValue', { ns: 'appDebug' })}
value={String(tempPayload.default ?? '')}
onValueChange={value => onPayloadChange('default')(value || undefined)}
onChange={e => onPayloadChange('default')(e.target.value || undefined)}
placeholder={t('variableConfig.inputPlaceholder', { ns: 'appDebug' })}
/>
</Field>

View File

@ -1,11 +1,11 @@
'use client'
import type { FC } from 'react'
import { cn } from '@langgenius/dify-ui/cn'
import { Textarea } from '@langgenius/dify-ui/textarea'
import { useBoolean } from 'ahooks'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import { ArrowDownRoundFill } from '@/app/components/base/icons/src/vender/solid/general'
import Textarea from '@/app/components/base/textarea'
const i18nPrefix = 'generate'
@ -40,11 +40,10 @@ const IdeaOutput: FC<Props> = ({
</div>
{!isFoldIdeaOutput && (
<Textarea
aria-label={t(`${i18nPrefix}.idealOutput`, { ns: 'appDebug' })}
className="h-[80px]"
placeholder={t(`${i18nPrefix}.idealOutputPlaceholder`, { ns: 'appDebug' })}
value={value}
onValueChange={value => onChange(value)}
onChange={e => onChange(e.target.value)}
/>
)}
</div>

View File

@ -4,13 +4,13 @@ import type { DataSet } from '@/models/datasets'
import type { RetrievalConfig } from '@/types/app'
import { Button } from '@langgenius/dify-ui/button'
import { cn } from '@langgenius/dify-ui/cn'
import { Textarea } from '@langgenius/dify-ui/textarea'
import { toast } from '@langgenius/dify-ui/toast'
import { RiCloseLine } from '@remixicon/react'
import { isEqual } from 'es-toolkit/predicate'
import { useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Input from '@/app/components/base/input'
import Textarea from '@/app/components/base/textarea'
import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
import { IndexingType } from '@/app/components/datasets/create/step-two'
import IndexMethod from '@/app/components/datasets/settings/index-method'
@ -224,9 +224,8 @@ const SettingsModal: FC<SettingsModalProps> = ({
</div>
<div className="w-full">
<Textarea
aria-label={t('form.desc', { ns: 'datasetSettings' })}
value={localeCurrentDataset.description || ''}
onValueChange={value => handleValueChange('description', value)}
onChange={e => handleValueChange('description', e.target.value)}
className="resize-none"
placeholder={t('form.descPlaceholder', { ns: 'datasetSettings' }) || ''}
/>

View File

@ -84,6 +84,25 @@ vi.mock('@langgenius/dify-ui/select', async () => {
}
})
vi.mock('@/app/components/base/textarea', () => ({
default: ({ value, onChange, placeholder, readOnly, className }: {
value: string
onChange: (e: { target: { value: string } }) => void
placeholder?: string
readOnly?: boolean
className?: string
}) => (
<textarea
data-testid={`textarea-${placeholder}`}
value={value}
onChange={onChange}
placeholder={placeholder}
readOnly={readOnly}
className={className}
/>
),
}))
vi.mock('@/app/components/workflow/nodes/_base/components/before-run-form/bool-input', () => ({
default: ({ name, value, required, onChange, readonly }: {
name: string
@ -204,7 +223,7 @@ describe('ChatUserInput', () => {
}))
render(<ChatUserInput inputs={{}} />)
expect(screen.getByRole('textbox', { name: 'Description' })).toBeInTheDocument()
expect(screen.getByTestId('textarea-Description')).toBeInTheDocument()
})
it('should render select input type', () => {
@ -256,7 +275,7 @@ describe('ChatUserInput', () => {
render(<ChatUserInput inputs={{}} />)
expect(screen.getByTestId('input-Name')).toBeInTheDocument()
expect(screen.getByRole('textbox', { name: 'Description' })).toBeInTheDocument()
expect(screen.getByTestId('textarea-Description')).toBeInTheDocument()
expect(screen.getByTestId('select-input')).toBeInTheDocument()
})
@ -315,7 +334,7 @@ describe('ChatUserInput', () => {
}))
render(<ChatUserInput inputs={{ desc: 'Long text here' }} />)
expect(screen.getByRole('textbox', { name: 'Description' })).toHaveValue('Long text here')
expect(screen.getByTestId('textarea-Description')).toHaveValue('Long text here')
})
it('should display existing input values for number type', () => {
@ -399,7 +418,7 @@ describe('ChatUserInput', () => {
}))
render(<ChatUserInput inputs={{}} />)
fireEvent.change(screen.getByRole('textbox', { name: 'Description' }), { target: { value: 'New Description' } })
fireEvent.change(screen.getByTestId('textarea-Description'), { target: { value: 'New Description' } })
expect(mockSetInputs).toHaveBeenCalledWith({ desc: 'New Description' })
})
@ -507,7 +526,7 @@ describe('ChatUserInput', () => {
}))
render(<ChatUserInput inputs={{}} />)
expect(screen.getByRole('textbox', { name: 'Description' })).toHaveAttribute('readonly')
expect(screen.getByTestId('textarea-Description')).toHaveAttribute('readonly')
})
it('should disable select when readonly is true', () => {

View File

@ -1,12 +1,12 @@
import type { Inputs } from '@/models/debug'
import { cn } from '@langgenius/dify-ui/cn'
import { Select, SelectContent, SelectItem, SelectItemIndicator, SelectItemText, SelectTrigger } from '@langgenius/dify-ui/select'
import { Textarea } from '@langgenius/dify-ui/textarea'
import * as React from 'react'
import { useEffect } from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import Input from '@/app/components/base/input'
import Textarea from '@/app/components/base/textarea'
import BoolInput from '@/app/components/workflow/nodes/_base/components/before-run-form/bool-input'
import ConfigContext from '@/context/debug-configuration'
@ -94,10 +94,9 @@ const ChatUserInput = ({
{type === 'paragraph' && (
<Textarea
className="h-[120px] grow"
aria-label={name || key}
placeholder={name}
value={inputs[key] ? `${inputs[key]}` : ''}
onValueChange={(value) => { handleInputValueChange(key, value) }}
onChange={(e) => { handleInputValueChange(key, e.target.value) }}
readOnly={readonly}
/>
)}

View File

@ -5,7 +5,6 @@ import type { VisionFile, VisionSettings } from '@/types/app'
import { Button } from '@langgenius/dify-ui/button'
import { cn } from '@langgenius/dify-ui/cn'
import { Select, SelectContent, SelectItem, SelectItemIndicator, SelectItemText, SelectTrigger } from '@langgenius/dify-ui/select'
import { Textarea } from '@langgenius/dify-ui/textarea'
import { Tooltip, TooltipContent, TooltipTrigger } from '@langgenius/dify-ui/tooltip'
import {
RiArrowDownSLine,
@ -20,6 +19,7 @@ import { useStore as useAppStore } from '@/app/components/app/store'
import FeatureBar from '@/app/components/base/features/new-feature-panel/feature-bar'
import TextGenerationImageUploader from '@/app/components/base/image-uploader/text-generation-image-uploader'
import Input from '@/app/components/base/input'
import Textarea from '@/app/components/base/textarea'
import BoolInput from '@/app/components/workflow/nodes/_base/components/before-run-form/bool-input'
import ConfigContext from '@/context/debug-configuration'
import { AppModeEnum, ModelModeType } from '@/types/app'
@ -151,11 +151,10 @@ const PromptValuePanel: FC<IPromptValuePanelProps> = ({
)}
{type === 'paragraph' && (
<Textarea
aria-label={name}
className="h-[120px] grow"
placeholder={name}
value={inputs[key] ? `${inputs[key]}` : ''}
onValueChange={(value) => { handleInputValueChange(key, value) }}
onChange={(e) => { handleInputValueChange(key, e.target.value) }}
readOnly={readonly}
/>
)}

View File

@ -4,7 +4,6 @@ import type { AppIconSelection } from '../../base/app-icon-picker'
import { Button } from '@langgenius/dify-ui/button'
import { cn } from '@langgenius/dify-ui/cn'
import { Textarea } from '@langgenius/dify-ui/textarea'
import { toast } from '@langgenius/dify-ui/toast'
import { RiArrowRightLine, RiArrowRightSLine, RiExchange2Fill } from '@remixicon/react'
import { useDebounceFn, useKeyPress } from 'ahooks'
@ -14,6 +13,7 @@ import AppIcon from '@/app/components/base/app-icon'
import Divider from '@/app/components/base/divider'
import { BubbleTextMod, ChatBot, ListSparkle, Logic } from '@/app/components/base/icons/src/vender/solid/communication'
import Input from '@/app/components/base/input'
import Textarea from '@/app/components/base/textarea'
import AppsFull from '@/app/components/billing/apps-full-in-dialog'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { useAppContext } from '@/context/app-context'
@ -241,11 +241,10 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }:
</span>
</div>
<Textarea
aria-label={t('newApp.captionDescription', { ns: 'app' })}
className="resize-none"
placeholder={t('newApp.appDescriptionPlaceholder', { ns: 'app' }) || ''}
value={description}
onValueChange={value => setDescription(value)}
onChange={e => setDescription(e.target.value)}
/>
</div>
</div>

View File

@ -8,7 +8,6 @@ import { cn } from '@langgenius/dify-ui/cn'
import { Dialog, DialogCloseButton, DialogContent, DialogTitle } from '@langgenius/dify-ui/dialog'
import { Select, SelectContent, SelectItem, SelectItemIndicator, SelectItemText, SelectTrigger } from '@langgenius/dify-ui/select'
import { Switch } from '@langgenius/dify-ui/switch'
import { Textarea } from '@langgenius/dify-ui/textarea'
import { toast } from '@langgenius/dify-ui/toast'
import { Tooltip, TooltipContent, TooltipTrigger } from '@langgenius/dify-ui/tooltip'
import * as React from 'react'
@ -19,6 +18,7 @@ import AppIconPicker from '@/app/components/base/app-icon-picker'
import Divider from '@/app/components/base/divider'
import Input from '@/app/components/base/input'
import { PremiumBadgeButton } from '@/app/components/base/premium-badge'
import Textarea from '@/app/components/base/textarea'
import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants'
import { useModalContext } from '@/context/modal-context'
import { useProviderContext } from '@/context/provider-context'
@ -289,10 +289,9 @@ const SettingsModal: FC<ISettingsModalProps> = ({
<div className="relative">
<div className={cn('py-1 system-sm-semibold text-text-secondary')}>{t(`${prefixSettings}.webDesc`, { ns: 'appOverview' })}</div>
<Textarea
aria-label={t(`${prefixSettings}.webDesc`, { ns: 'appOverview' })}
className="mt-1"
value={inputInfo.desc}
onValueChange={onDesChange}
onChange={e => onDesChange(e.target.value)}
placeholder={t(`${prefixSettings}.webDescPlaceholder`, { ns: 'appOverview' }) as string}
/>
<p className={cn('pb-0.5 body-xs-regular text-text-tertiary')}>{t(`${prefixSettings}.webDescTip`, { ns: 'appOverview' })}</p>
@ -465,10 +464,9 @@ const SettingsModal: FC<ISettingsModalProps> = ({
<div className={cn('py-1 system-sm-semibold text-text-secondary')}>{t(`${prefixSettings}.more.customDisclaimer`, { ns: 'appOverview' })}</div>
<p className={cn('pb-0.5 body-xs-regular text-text-tertiary')}>{t(`${prefixSettings}.more.customDisclaimerTip`, { ns: 'appOverview' })}</p>
<Textarea
aria-label={t(`${prefixSettings}.more.customDisclaimer`, { ns: 'appOverview' })}
className="mt-1"
value={inputInfo.customDisclaimer}
onValueChange={value => setInputInfo(item => ({ ...item, customDisclaimer: value }))}
onChange={onChange('customDisclaimer')}
placeholder={t(`${prefixSettings}.more.customDisclaimerPlaceholder`, { ns: 'appOverview' }) as string}
/>
</div>

View File

@ -7,8 +7,8 @@ import {
SelectTrigger,
SelectValue,
} from '@langgenius/dify-ui/select'
import { Textarea } from '@langgenius/dify-ui/textarea'
import Input from '@/app/components/base/input'
import Textarea from '@/app/components/base/textarea'
import { InputVarType } from '@/app/components/workflow/types'
@ -74,7 +74,7 @@ const WorkflowHiddenInputFields = ({
<Textarea
id={fieldId}
value={typeof fieldValue === 'string' ? fieldValue : ''}
onValueChange={value => onValueChange(variable.variable, value)}
onChange={(event: ChangeEvent<HTMLTextAreaElement>) => onValueChange(variable.variable, event.target.value)}
placeholder={label}
maxLength={variable.max_length}
className="min-h-24"

View File

@ -1,10 +1,10 @@
import { Select, SelectContent, SelectItem, SelectItemIndicator, SelectItemText, SelectTrigger } from '@langgenius/dify-ui/select'
import { Textarea } from '@langgenius/dify-ui/textarea'
import * as React from 'react'
import { memo, useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import { FileUploaderInAttachmentWrapper } from '@/app/components/base/file-uploader'
import Input from '@/app/components/base/input'
import Textarea from '@/app/components/base/textarea'
import BoolInput from '@/app/components/workflow/nodes/_base/components/before-run-form/bool-input'
import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor'
import { CodeLanguage } from '@/app/components/workflow/nodes/code/types'
@ -71,9 +71,8 @@ const InputsFormContent = ({ showTip }: Props) => {
)}
{form.type === InputVarType.paragraph && (
<Textarea
aria-label={form.label}
value={inputsFormValue?.[form.variable] || ''}
onValueChange={value => handleFormChange(form.variable, value)}
onChange={e => handleFormChange(form.variable, e.target.value)}
placeholder={form.label}
/>
)}

View File

@ -1,8 +1,8 @@
import type { ContentItemProps } from './type'
import { Textarea } from '@langgenius/dify-ui/textarea'
import * as React from 'react'
import { useMemo } from 'react'
import { Markdown } from '@/app/components/base/markdown'
import Textarea from '@/app/components/base/textarea'
const ContentItem = ({
content,
@ -42,10 +42,9 @@ const ContentItem = ({
<div className="py-3">
{formInputField.type === 'paragraph' && (
<Textarea
aria-label={fieldName}
className="h-[104px] sm:text-xs"
value={inputs[fieldName]!}
onValueChange={(value) => { onInputChange(fieldName, value) }}
onChange={(e) => { onInputChange(fieldName, e.target.value) }}
data-testid="content-item-textarea"
/>
)}

View File

@ -6,7 +6,6 @@ import type {
import { Button } from '@langgenius/dify-ui/button'
import { cn } from '@langgenius/dify-ui/cn'
import { Dialog, DialogCloseButton, DialogContent, DialogDescription, DialogTitle } from '@langgenius/dify-ui/dialog'
import { Textarea } from '@langgenius/dify-ui/textarea'
import { toast } from '@langgenius/dify-ui/toast'
import { Tooltip, TooltipContent, TooltipTrigger } from '@langgenius/dify-ui/tooltip'
import copy from 'copy-to-clipboard'
@ -22,6 +21,7 @@ import ActionButton, { ActionButtonState } from '@/app/components/base/action-bu
import Log from '@/app/components/base/chat/chat/log'
import AnnotationCtrlButton from '@/app/components/base/features/new-feature-panel/annotation-reply/annotation-ctrl-button'
import NewAudioButton from '@/app/components/base/new-audio-button'
import Textarea from '@/app/components/base/textarea'
import { useChatContext } from '../context'
type OperationProps = {
@ -394,7 +394,7 @@ function Operation({
id={feedbackTextareaId}
name="feedback-content"
value={feedbackContent}
onValueChange={value => setFeedbackContent(value)}
onChange={e => setFeedbackContent(e.target.value)}
placeholder={t('feedback.placeholder', { ns: 'common' }) || 'Please describe what went wrong or how we can improve…'}
rows={4}
className="w-full"

View File

@ -1,10 +1,10 @@
import { Select, SelectContent, SelectItem, SelectItemIndicator, SelectItemText, SelectTrigger } from '@langgenius/dify-ui/select'
import { Textarea } from '@langgenius/dify-ui/textarea'
import * as React from 'react'
import { memo, useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import { FileUploaderInAttachmentWrapper } from '@/app/components/base/file-uploader'
import Input from '@/app/components/base/input'
import Textarea from '@/app/components/base/textarea'
import BoolInput from '@/app/components/workflow/nodes/_base/components/before-run-form/bool-input'
import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor'
import { CodeLanguage } from '@/app/components/workflow/nodes/code/types'
@ -71,9 +71,8 @@ const InputsFormContent = ({ showTip }: Props) => {
)}
{form.type === InputVarType.paragraph && (
<Textarea
aria-label={form.label}
value={inputsFormValue?.[form.variable] || ''}
onValueChange={value => handleFormChange(form.variable, value)}
onChange={e => handleFormChange(form.variable, e.target.value)}
placeholder={form.label}
/>
)}

View File

@ -12,10 +12,10 @@ import { FieldItem, FieldRoot } from '@langgenius/dify-ui/field'
import { FieldsetLegend, FieldsetRoot } from '@langgenius/dify-ui/fieldset'
import { RadioControl, RadioRoot } from '@langgenius/dify-ui/radio'
import { RadioGroup } from '@langgenius/dify-ui/radio-group'
import { Textarea } from '@langgenius/dify-ui/textarea'
import { produce } from 'immer'
import { useCallback, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Textarea from '@/app/components/base/textarea'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal'
@ -220,10 +220,9 @@ const FollowUpSettingModal = ({
</div>
{promptMode === PROMPT_MODE.custom && (
<Textarea
aria-label={t('feature.suggestedQuestionsAfterAnswer.modal.customPromptOption', { ns: 'appDebug' })}
className="mt-3 min-h-32 resize-y border-components-input-border-active bg-components-input-bg-normal"
value={prompt}
onValueChange={value => setPrompt(value)}
onChange={e => setPrompt(e.target.value)}
maxLength={CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH}
placeholder={t('feature.suggestedQuestionsAfterAnswer.modal.promptPlaceholder', { ns: 'appDebug' }) || ''}
/>

View File

@ -2,7 +2,7 @@ import type { FC } from 'react'
import type { CodeBasedExtensionForm } from '@/models/common'
import type { ModerationConfig } from '@/models/debug'
import { Select, SelectContent, SelectItem, SelectItemIndicator, SelectItemText, SelectTrigger } from '@langgenius/dify-ui/select'
import { Textarea } from '@langgenius/dify-ui/textarea'
import Textarea from '@/app/components/base/textarea'
import { useLocale } from '@/context/i18n'
type FormGenerationProps = {
@ -55,11 +55,10 @@ const FormGeneration: FC<FormGenerationProps> = ({
form.type === 'paragraph' && (
<div className="relative">
<Textarea
aria-label={locale === 'zh-Hans' ? form.label['zh-Hans'] : form.label['en-US']}
className="resize-none"
value={value?.[form.variable] || ''}
placeholder={form.placeholder}
onValueChange={value => handleFormChange(form.variable, value)}
onChange={e => handleFormChange(form.variable, e.target.value)}
/>
</div>
)

View File

@ -1,7 +1,6 @@
import type { FC } from 'react'
import type { ModerationContentConfig } from '@/models/debug'
import { Switch } from '@langgenius/dify-ui/switch'
import { Textarea } from '@langgenius/dify-ui/textarea'
import { useTranslation } from 'react-i18next'
type ModerationContentProps = {
@ -51,14 +50,12 @@ const ModerationContent: FC<ModerationContentProps> = ({
{t('feature.moderation.modal.content.preset', { ns: 'appDebug' })}
<span className="text-xs font-normal text-text-tertiary">{t('feature.moderation.modal.content.supportMarkdown', { ns: 'appDebug' })}</span>
</div>
{/* Keep this counter composed locally; extract only if more textarea counter cases repeat. */}
<div className="relative h-20">
<Textarea
aria-label={t('feature.moderation.modal.content.preset', { ns: 'appDebug' }) as string}
<div className="relative h-20 rounded-lg bg-components-input-bg-normal px-3 py-2">
<textarea
value={config.preset_response || ''}
className="size-full resize-none pb-8"
className="block size-full resize-none appearance-none bg-transparent text-sm text-text-secondary outline-hidden"
placeholder={t('feature.moderation.modal.content.placeholder', { ns: 'appDebug' }) || ''}
onValueChange={value => handleConfigChange('preset_response', value)}
onChange={e => handleConfigChange('preset_response', e.target.value)}
/>
<div className="absolute right-2 bottom-2 flex h-5 items-center rounded-md bg-background-section px-1 text-xs font-medium text-text-quaternary">
<span>{(config.preset_response || '').length}</span>

View File

@ -1,10 +1,9 @@
import type { FC } from 'react'
import type { ChangeEvent, FC } from 'react'
import type { CodeBasedExtensionItem } from '@/models/common'
import type { ModerationConfig, ModerationContentConfig } from '@/models/debug'
import { Button } from '@langgenius/dify-ui/button'
import { cn } from '@langgenius/dify-ui/cn'
import { Dialog, DialogContent } from '@langgenius/dify-ui/dialog'
import { Textarea } from '@langgenius/dify-ui/textarea'
import { toast } from '@langgenius/dify-ui/toast'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
@ -104,7 +103,9 @@ const ModerationSettingModal: FC<ModerationSettingModalProps> = ({
})
}
const handleDataKeywordsChange = (value: string) => {
const handleDataKeywordsChange = (e: ChangeEvent<HTMLTextAreaElement>) => {
const value = e.target.value
const arr = value.split('\n').reduce((prev: string[], next: string) => {
if (next !== '')
prev.push(next.slice(0, 100))
@ -291,13 +292,11 @@ const ModerationSettingModal: FC<ModerationSettingModalProps> = ({
<div className="py-2">
<div className="mb-1 text-sm font-medium text-text-primary">{t('feature.moderation.modal.provider.keywords', { ns: 'appDebug' })}</div>
<div className="mb-2 text-xs text-text-tertiary">{t('feature.moderation.modal.keywords.tip', { ns: 'appDebug' })}</div>
{/* Keep this counter composed locally; extract only if more textarea counter cases repeat. */}
<div className="relative h-[88px]">
<Textarea
aria-label={t('feature.moderation.modal.provider.keywords', { ns: 'appDebug' }) as string}
<div className="relative h-[88px] rounded-lg bg-components-input-bg-normal px-3 py-2">
<textarea
value={localeData.config?.keywords || ''}
onValueChange={handleDataKeywordsChange}
className="size-full resize-none pb-8"
onChange={handleDataKeywordsChange}
className="block size-full resize-none appearance-none bg-transparent text-sm text-text-secondary outline-hidden"
placeholder={t('feature.moderation.modal.keywords.placeholder', { ns: 'appDebug' }) || ''}
/>
<div className="absolute right-2 bottom-2 flex h-5 items-center rounded-md bg-background-section px-1 text-xs font-medium text-text-quaternary">

View File

@ -1,4 +1,3 @@
import type { ComponentProps } from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import TextAreaField from '../text-area'
@ -31,20 +30,4 @@ describe('TextAreaField', () => {
fireEvent.change(screen.getByLabelText('Note'), { target: { value: 'Updated note' } })
expect(mockField.handleChange).toHaveBeenCalledWith('Updated note')
})
it('should keep form writeback when external props contain onValueChange', () => {
const externalOnValueChange = vi.fn()
render(
<TextAreaField
label="Note"
{...({ onValueChange: externalOnValueChange } as Partial<ComponentProps<typeof TextAreaField>>)}
/>,
)
fireEvent.change(screen.getByLabelText('Note'), { target: { value: 'Updated note' } })
expect(mockField.handleChange).toHaveBeenCalledWith('Updated note')
expect(externalOnValueChange).not.toHaveBeenCalled()
})
})

View File

@ -1,16 +1,16 @@
import type { TextareaProps } from '@langgenius/dify-ui/textarea'
import type { TextareaProps } from '../../../textarea'
import type { LabelProps } from '../label'
import { cn } from '@langgenius/dify-ui/cn'
import { Textarea } from '@langgenius/dify-ui/textarea'
import * as React from 'react'
import { useFieldContext } from '../..'
import Textarea from '../../../textarea'
import Label from '../label'
type TextAreaFieldProps = {
label: string
labelOptions?: Omit<LabelProps, 'htmlFor' | 'label'>
className?: string
} & Omit<TextareaProps, 'className' | 'defaultValue' | 'onBlur' | 'onValueChange' | 'value' | 'id'>
} & Omit<TextareaProps, 'className' | 'onChange' | 'onBlur' | 'value' | 'id'>
const TextAreaField = ({
label,
@ -28,11 +28,11 @@ const TextAreaField = ({
{...(labelOptions ?? {})}
/>
<Textarea
{...inputProps}
id={field.name}
value={field.state.value}
onValueChange={value => field.handleChange(value)}
onChange={e => field.handleChange(e.target.value)}
onBlur={field.handleBlur}
{...inputProps}
/>
</div>
)

View File

@ -3,7 +3,6 @@ import type { Dayjs } from 'dayjs'
import { Button } from '@langgenius/dify-ui/button'
import { Checkbox } from '@langgenius/dify-ui/checkbox'
import { Select, SelectContent, SelectItem, SelectItemIndicator, SelectItemText, SelectTrigger, SelectValue } from '@langgenius/dify-ui/select'
import { Textarea } from '@langgenius/dify-ui/textarea'
import * as React from 'react'
import { useCallback, useMemo, useState } from 'react'
import { useChatContext } from '@/app/components/base/chat/chat/context'
@ -11,6 +10,7 @@ import DatePicker from '@/app/components/base/date-and-time-picker/date-picker'
import TimePicker from '@/app/components/base/date-and-time-picker/time-picker'
import { formatDateForOutput, toDayjs } from '@/app/components/base/date-and-time-picker/utils/dayjs'
import Input from '@/app/components/base/input'
import Textarea from '@/app/components/base/textarea'
const DATA_FORMAT = {
TEXT: 'text',
@ -372,12 +372,11 @@ const MarkdownForm = ({ node }: { node: HastElement }) => {
return null
return (
<Textarea
aria-label={name}
key={key}
name={name}
placeholder={str(child.properties.placeholder)}
value={str(formValues[name])}
onValueChange={value => updateValue(name, value)}
onChange={e => updateValue(name, e.target.value)}
/>
)
}

View File

@ -2,12 +2,12 @@
import type { FC } from 'react'
import type { ValueSelector, Var } from '@/app/components/workflow/types'
import { cn } from '@langgenius/dify-ui/cn'
import { Textarea } from '@langgenius/dify-ui/textarea'
import * as React from 'react'
import { useCallback, useEffect, useState } from 'react'
import { Trans, useTranslation } from 'react-i18next'
import VarReferencePicker from '@/app/components/workflow/nodes/_base/components/variable/var-reference-picker'
import { VarType } from '@/app/components/workflow/types'
import Textarea from '../../../textarea'
import TagLabel from './tag-label'
import TypeSwitch from './type-switch'
@ -72,7 +72,6 @@ const PrePopulate: FC<Props> = ({
value,
onValueChange,
}) => {
const { t } = useTranslation()
const [onPlaceholderClicked, setOnPlaceholderClicked] = useState(false)
const handleTypeChange = useCallback((isVar: boolean) => {
setOnPlaceholderClicked(true)
@ -128,10 +127,9 @@ const PrePopulate: FC<Props> = ({
return (
<div className={cn('relative min-h-[80px] rounded-lg border border-transparent bg-components-input-bg-normal pb-1', isFocus && 'border-components-input-border-active bg-components-input-bg-active shadow-xs')}>
<Textarea
aria-label={t(`${i18nPrefix}.staticContent`, { ns: 'workflow' })}
value={value || ''}
className="h-[43px] min-h-[43px] rounded-none border-none bg-transparent px-3 hover:bg-transparent focus:bg-transparent focus:shadow-none"
onValueChange={value => onValueChange?.(value)}
onChange={e => onValueChange?.(e.target.value)}
onFocus={() => {
setOnPlaceholderClicked(true)
setIsFocus(true)

View File

@ -0,0 +1,77 @@
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { describe, expect, it, vi } from 'vitest'
import TextArea from '../index'
describe('TextArea', () => {
it('should render correctly with default props', () => {
render(<TextArea value="" onChange={vi.fn()} />)
const textarea = screen.getByTestId('text-area')
expect(textarea).toBeInTheDocument()
expect(textarea).toHaveValue('')
})
it('should handle value and onChange correctly', async () => {
const user = userEvent.setup()
const handleChange = vi.fn()
const { rerender } = render(<TextArea value="initial" onChange={handleChange} />)
const textarea = screen.getByTestId('text-area')
expect(textarea).toHaveValue('initial')
await user.type(textarea, ' updated')
expect(handleChange).toHaveBeenCalled()
rerender(<TextArea value="initial updated" onChange={handleChange} />)
expect(textarea).toHaveValue('initial updated')
})
it('should handle autoFocus correctly', () => {
render(<TextArea value="" onChange={vi.fn()} autoFocus />)
const textarea = screen.getByTestId('text-area')
expect(textarea).toHaveFocus()
})
it('should handle disabled state', () => {
render(<TextArea value="" onChange={vi.fn()} disabled />)
const textarea = screen.getByTestId('text-area')
expect(textarea).toBeDisabled()
expect(textarea).toHaveClass('cursor-not-allowed')
})
it('should handle placeholder', () => {
render(<TextArea value="" onChange={vi.fn()} placeholder="Enter text here" />)
expect(screen.getByPlaceholderText('Enter text here')).toBeInTheDocument()
})
it('should handle className', () => {
render(<TextArea value="" onChange={vi.fn()} className="custom-class" />)
expect(screen.getByTestId('text-area')).toHaveClass('custom-class')
})
it('should handle size variants', () => {
const { rerender } = render(<TextArea value="" onChange={vi.fn()} size="small" />)
expect(screen.getByTestId('text-area')).toHaveClass('py-1')
rerender(<TextArea value="" onChange={vi.fn()} size="large" />)
expect(screen.getByTestId('text-area')).toHaveClass('px-4')
})
it('should handle destructive state', () => {
render(<TextArea value="" onChange={vi.fn()} destructive />)
expect(screen.getByTestId('text-area')).toHaveClass('border-components-input-border-destructive')
})
it('should handle onFocus and onBlur', async () => {
const user = userEvent.setup()
const handleFocus = vi.fn()
const handleBlur = vi.fn()
render(<TextArea value="" onChange={vi.fn()} onFocus={handleFocus} onBlur={handleBlur} />)
const textarea = screen.getByTestId('text-area')
await user.click(textarea)
expect(handleFocus).toHaveBeenCalled()
await user.tab()
expect(handleBlur).toHaveBeenCalled()
})
})

View File

@ -0,0 +1,562 @@
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
import { useState } from 'react'
import Textarea from '.'
const meta = {
title: 'Base/Data Entry/Textarea',
component: Textarea,
parameters: {
layout: 'centered',
docs: {
description: {
component: 'Textarea component with multiple sizes (small, regular, large). Built with class-variance-authority for consistent styling.',
},
},
},
tags: ['autodocs'],
argTypes: {
size: {
control: 'select',
options: ['small', 'regular', 'large'],
description: 'Textarea size',
},
value: {
control: 'text',
description: 'Textarea value',
},
placeholder: {
control: 'text',
description: 'Placeholder text',
},
disabled: {
control: 'boolean',
description: 'Disabled state',
},
destructive: {
control: 'boolean',
description: 'Error/destructive state',
},
rows: {
control: 'number',
description: 'Number of visible text rows',
},
},
} satisfies Meta<typeof Textarea>
export default meta
type Story = StoryObj<typeof meta>
// Interactive demo wrapper
const TextareaDemo = (args: any) => {
const [value, setValue] = useState(args.value || '')
return (
<div style={{ width: '500px' }}>
<Textarea
{...args}
value={value}
onChange={(e) => {
setValue(e.target.value)
console.log('Textarea changed:', e.target.value)
}}
/>
{value && (
<div className="mt-3 text-sm text-gray-600">
Character count:
{' '}
<span className="font-semibold">{value.length}</span>
</div>
)}
</div>
)
}
// Default state
export const Default: Story = {
render: args => <TextareaDemo {...args} />,
args: {
size: 'regular',
placeholder: 'Enter text...',
rows: 4,
value: '',
},
}
// Small size
export const SmallSize: Story = {
render: args => <TextareaDemo {...args} />,
args: {
size: 'small',
placeholder: 'Small textarea...',
rows: 3,
value: '',
},
}
// Large size
export const LargeSize: Story = {
render: args => <TextareaDemo {...args} />,
args: {
size: 'large',
placeholder: 'Large textarea...',
rows: 5,
value: '',
},
}
// With initial value
export const WithInitialValue: Story = {
render: args => <TextareaDemo {...args} />,
args: {
size: 'regular',
value: 'This is some initial text content.\n\nIt spans multiple lines.',
rows: 4,
},
}
// Disabled state
export const Disabled: Story = {
render: args => <TextareaDemo {...args} />,
args: {
size: 'regular',
value: 'This textarea is disabled and cannot be edited.',
disabled: true,
rows: 3,
},
}
// Destructive/error state
export const DestructiveState: Story = {
render: args => <TextareaDemo {...args} />,
args: {
size: 'regular',
value: 'This content has an error.',
destructive: true,
rows: 3,
},
}
// Size comparison
const SizeComparisonDemo = () => {
const [small, setSmall] = useState('')
const [regular, setRegular] = useState('')
const [large, setLarge] = useState('')
return (
<div style={{ width: '600px' }} className="space-y-4">
<div>
<label className="mb-2 block text-xs font-medium text-gray-600">Small</label>
<Textarea
size="small"
value={small}
onChange={e => setSmall(e.target.value)}
placeholder="Small textarea..."
rows={3}
/>
</div>
<div>
<label className="mb-2 block text-xs font-medium text-gray-600">Regular</label>
<Textarea
size="regular"
value={regular}
onChange={e => setRegular(e.target.value)}
placeholder="Regular textarea..."
rows={4}
/>
</div>
<div>
<label className="mb-2 block text-xs font-medium text-gray-600">Large</label>
<Textarea
size="large"
value={large}
onChange={e => setLarge(e.target.value)}
placeholder="Large textarea..."
rows={5}
/>
</div>
</div>
)
}
export const SizeComparison: Story = {
render: () => <SizeComparisonDemo />,
parameters: { controls: { disable: true } },
} as unknown as Story
// State comparison
const StateComparisonDemo = () => {
const [normal, setNormal] = useState('Normal state')
const [error, setError] = useState('Error state')
return (
<div style={{ width: '500px' }} className="space-y-4">
<div>
<label className="mb-2 block text-sm font-medium text-gray-700">Normal</label>
<Textarea
value={normal}
onChange={e => setNormal(e.target.value)}
rows={3}
/>
</div>
<div>
<label className="mb-2 block text-sm font-medium text-gray-700">Destructive</label>
<Textarea
value={error}
onChange={e => setError(e.target.value)}
destructive
rows={3}
/>
</div>
<div>
<label className="mb-2 block text-sm font-medium text-gray-700">Disabled</label>
<Textarea
value="Disabled state"
onChange={() => undefined}
disabled
rows={3}
/>
</div>
</div>
)
}
export const StateComparison: Story = {
render: () => <StateComparisonDemo />,
parameters: { controls: { disable: true } },
} as unknown as Story
// Real-world example - Comment form
const CommentFormDemo = () => {
const [comment, setComment] = useState('')
const maxLength = 500
return (
<div style={{ width: '600px' }} className="rounded-lg border border-gray-200 bg-white p-6">
<h3 className="mb-4 text-lg font-semibold">Leave a Comment</h3>
<Textarea
value={comment}
onChange={e => setComment(e.target.value)}
placeholder="Share your thoughts..."
rows={5}
maxLength={maxLength}
/>
<div className="mt-2 flex items-center justify-between">
<span className="text-xs text-gray-500">
{comment.length}
{' '}
/
{maxLength}
{' '}
characters
</span>
<button
className="rounded-lg bg-blue-600 px-4 py-2 text-sm font-medium text-white hover:bg-blue-700 disabled:cursor-not-allowed disabled:opacity-50"
disabled={comment.trim().length === 0}
>
Post Comment
</button>
</div>
</div>
)
}
export const CommentForm: Story = {
render: () => <CommentFormDemo />,
parameters: { controls: { disable: true } },
} as unknown as Story
// Real-world example - Feedback form
const FeedbackFormDemo = () => {
const [feedback, setFeedback] = useState('')
const [email, setEmail] = useState('')
return (
<div style={{ width: '600px' }} className="rounded-lg border border-gray-200 bg-white p-6">
<h3 className="mb-2 text-lg font-semibold">Send Feedback</h3>
<p className="mb-4 text-sm text-gray-600">Help us improve our product</p>
<div className="space-y-4">
<div>
<label className="mb-2 block text-sm font-medium text-gray-700">Your Email</label>
<input
type="email"
className="w-full rounded-lg border border-gray-300 px-3 py-2 text-sm"
value={email}
onChange={e => setEmail(e.target.value)}
placeholder="email@example.com"
/>
</div>
<div>
<label className="mb-2 block text-sm font-medium text-gray-700">Your Feedback</label>
<Textarea
value={feedback}
onChange={e => setFeedback(e.target.value)}
placeholder="Tell us what you think..."
rows={6}
/>
</div>
<button className="w-full rounded-lg bg-green-600 px-4 py-2 text-sm font-medium text-white hover:bg-green-700">
Submit Feedback
</button>
</div>
</div>
)
}
export const FeedbackForm: Story = {
render: () => <FeedbackFormDemo />,
parameters: { controls: { disable: true } },
} as unknown as Story
// Real-world example - Code snippet
const CodeSnippetDemo = () => {
const [code, setCode] = useState(`function hello() {
console.log("Hello, world!");
}`)
return (
<div style={{ width: '600px' }} className="rounded-lg border border-gray-200 bg-white p-6">
<h3 className="mb-4 text-lg font-semibold">Code Editor</h3>
<Textarea
value={code}
onChange={e => setCode(e.target.value)}
className="font-mono"
rows={8}
/>
<div className="mt-4 flex gap-2">
<button className="rounded-lg bg-blue-600 px-4 py-2 text-sm font-medium text-white hover:bg-blue-700">
Run Code
</button>
<button className="rounded-lg bg-gray-200 px-4 py-2 text-sm font-medium text-gray-700 hover:bg-gray-300">
Copy
</button>
</div>
</div>
)
}
export const CodeSnippet: Story = {
render: () => <CodeSnippetDemo />,
parameters: { controls: { disable: true } },
} as unknown as Story
// Real-world example - Message composer
const MessageComposerDemo = () => {
const [message, setMessage] = useState('')
return (
<div style={{ width: '600px' }} className="rounded-lg border border-gray-200 bg-white p-6">
<h3 className="mb-4 text-lg font-semibold">Compose Message</h3>
<div className="space-y-4">
<div>
<label className="mb-2 block text-sm font-medium text-gray-700">To</label>
<input
type="text"
className="w-full rounded-lg border border-gray-300 px-3 py-2 text-sm"
placeholder="Recipient name"
/>
</div>
<div>
<label className="mb-2 block text-sm font-medium text-gray-700">Subject</label>
<input
type="text"
className="w-full rounded-lg border border-gray-300 px-3 py-2 text-sm"
placeholder="Message subject"
/>
</div>
<div>
<label className="mb-2 block text-sm font-medium text-gray-700">Message</label>
<Textarea
value={message}
onChange={e => setMessage(e.target.value)}
placeholder="Type your message here..."
rows={8}
/>
</div>
<div className="flex gap-2">
<button className="rounded-lg bg-blue-600 px-4 py-2 text-sm font-medium text-white hover:bg-blue-700">
Send Message
</button>
<button className="rounded-lg bg-gray-200 px-4 py-2 text-sm font-medium text-gray-700 hover:bg-gray-300">
Save Draft
</button>
</div>
</div>
</div>
)
}
export const MessageComposer: Story = {
render: () => <MessageComposerDemo />,
parameters: { controls: { disable: true } },
} as unknown as Story
// Real-world example - Bio editor
const BioEditorDemo = () => {
const [bio, setBio] = useState('Software developer passionate about building great products.')
const maxLength = 200
return (
<div style={{ width: '600px' }} className="rounded-lg border border-gray-200 bg-white p-6">
<h3 className="mb-4 text-lg font-semibold">Edit Your Bio</h3>
<Textarea
value={bio}
onChange={e => setBio(e.target.value.slice(0, maxLength))}
placeholder="Tell us about yourself..."
rows={4}
/>
<div className="mt-2 flex items-center justify-between text-xs">
<span className={bio.length > maxLength * 0.9 ? 'text-orange-600' : 'text-gray-500'}>
{bio.length}
{' '}
/
{maxLength}
{' '}
characters
</span>
{bio.length > maxLength * 0.9 && (
<span className="text-orange-600">
{maxLength - bio.length}
{' '}
characters remaining
</span>
)}
</div>
<div className="mt-4 rounded-lg bg-gray-50 p-4">
<div className="mb-2 text-xs font-medium text-gray-600">Preview:</div>
<p className="text-sm text-gray-800">{bio || 'Your bio will appear here...'}</p>
</div>
</div>
)
}
export const BioEditor: Story = {
render: () => <BioEditorDemo />,
parameters: { controls: { disable: true } },
} as unknown as Story
// Real-world example - JSON editor
const JSONEditorDemo = () => {
const [json, setJson] = useState(`{
"name": "John Doe",
"age": 30,
"email": "john@example.com"
}`)
const [isValid, setIsValid] = useState(true)
const validateJSON = (value: string) => {
try {
JSON.parse(value)
setIsValid(true)
}
catch {
setIsValid(false)
}
}
return (
<div style={{ width: '600px' }} className="rounded-lg border border-gray-200 bg-white p-6">
<div className="mb-4 flex items-center justify-between">
<h3 className="text-lg font-semibold">JSON Editor</h3>
<span className={`rounded-sm px-2 py-1 text-xs ${isValid ? 'bg-green-100 text-green-700' : 'bg-red-100 text-red-700'}`}>
{isValid ? '✓ Valid' : '✗ Invalid'}
</span>
</div>
<Textarea
value={json}
onChange={(e) => {
setJson(e.target.value)
validateJSON(e.target.value)
}}
className="font-mono"
destructive={!isValid}
rows={10}
/>
<div className="mt-4 flex gap-2">
<button className="rounded-lg bg-blue-600 px-4 py-2 text-sm font-medium text-white hover:bg-blue-700 disabled:opacity-50" disabled={!isValid}>
Save JSON
</button>
<button
className="rounded-lg bg-gray-200 px-4 py-2 text-sm font-medium text-gray-700 hover:bg-gray-300"
onClick={() => {
try {
const formatted = JSON.stringify(JSON.parse(json), null, 2)
setJson(formatted)
}
catch {
// Invalid JSON, do nothing
}
}}
>
Format
</button>
</div>
</div>
)
}
export const JSONEditor: Story = {
render: () => <JSONEditorDemo />,
parameters: { controls: { disable: true } },
} as unknown as Story
// Real-world example - Task description
const TaskDescriptionDemo = () => {
const [title, setTitle] = useState('Implement user authentication')
const [description, setDescription] = useState('Add login and registration functionality with JWT tokens.')
return (
<div style={{ width: '600px' }} className="rounded-lg border border-gray-200 bg-white p-6">
<h3 className="mb-4 text-lg font-semibold">Create New Task</h3>
<div className="space-y-4">
<div>
<label className="mb-2 block text-sm font-medium text-gray-700">Task Title</label>
<input
type="text"
className="w-full rounded-lg border border-gray-300 px-3 py-2 text-sm"
value={title}
onChange={e => setTitle(e.target.value)}
/>
</div>
<div>
<label className="mb-2 block text-sm font-medium text-gray-700">Description</label>
<Textarea
value={description}
onChange={e => setDescription(e.target.value)}
placeholder="Describe the task in detail..."
rows={6}
/>
</div>
<div>
<label className="mb-2 block text-sm font-medium text-gray-700">Priority</label>
<select className="w-full rounded-lg border border-gray-300 px-3 py-2 text-sm">
<option>Low</option>
<option>Medium</option>
<option>High</option>
<option>Urgent</option>
</select>
</div>
<button className="w-full rounded-lg bg-blue-600 px-4 py-2 text-sm font-medium text-white hover:bg-blue-700">
Create Task
</button>
</div>
</div>
)
}
export const TaskDescription: Story = {
render: () => <TaskDescriptionDemo />,
parameters: { controls: { disable: true } },
} as unknown as Story
// Interactive playground
export const Playground: Story = {
render: args => <TextareaDemo {...args} />,
args: {
size: 'regular',
placeholder: 'Enter text...',
rows: 4,
disabled: false,
destructive: false,
value: '',
},
}

View File

@ -0,0 +1,60 @@
import type { VariantProps } from 'class-variance-authority'
import type { CSSProperties } from 'react'
import { cn } from '@langgenius/dify-ui/cn'
import { cva } from 'class-variance-authority'
import * as React from 'react'
const textareaVariants = cva(
'',
{
variants: {
size: {
small: 'rounded-md py-1 system-xs-regular',
regular: 'rounded-md px-3 system-sm-regular',
large: 'rounded-lg px-4 system-md-regular',
},
},
defaultVariants: {
size: 'regular',
},
},
)
export type TextareaProps = {
value: string | number
disabled?: boolean
destructive?: boolean
styleCss?: CSSProperties
ref?: React.Ref<HTMLTextAreaElement>
onFocus?: React.FocusEventHandler<HTMLTextAreaElement>
onBlur?: React.FocusEventHandler<HTMLTextAreaElement>
} & React.TextareaHTMLAttributes<HTMLTextAreaElement> & VariantProps<typeof textareaVariants>
const Textarea = React.forwardRef<HTMLTextAreaElement, TextareaProps>(
({ className, value, onChange, disabled, size, destructive, styleCss, onFocus, onBlur, ...props }, ref) => {
return (
<textarea
ref={ref}
onFocus={onFocus}
onBlur={onBlur}
style={styleCss}
className={cn(
'min-h-20 w-full appearance-none border border-transparent bg-components-input-bg-normal p-2 text-components-input-text-filled caret-primary-600 outline-hidden placeholder:text-components-input-text-placeholder hover:border-components-input-border-hover hover:bg-components-input-bg-hover focus:border-components-input-border-active focus:bg-components-input-bg-active focus:shadow-xs',
textareaVariants({ size }),
disabled && 'cursor-not-allowed border-transparent bg-components-input-bg-disabled text-components-input-text-filled-disabled hover:border-transparent hover:bg-components-input-bg-disabled',
destructive && 'border-components-input-border-destructive bg-components-input-bg-destructive text-components-input-text-filled hover:border-components-input-border-destructive hover:bg-components-input-bg-destructive focus:border-components-input-border-destructive focus:bg-components-input-bg-destructive',
className,
)}
value={value ?? ''}
onChange={onChange}
disabled={disabled}
data-testid="text-area"
{...props}
>
</textarea>
)
},
)
Textarea.displayName = 'Textarea'
export default Textarea

View File

@ -25,7 +25,6 @@ const AppsFull: FC<{ loc: string, className?: string }> = ({
const total = plan.total.buildApps
const percent = total > 0 ? (usage / total) * 100 : 0
const tone: MeterTone = percent >= 80 ? 'error' : percent >= 50 ? 'warning' : 'neutral'
const buildAppsLabel = t('usagePage.buildApps', { ns: 'billing' })
return (
<div className={cn(
'flex flex-col gap-3 rounded-xl border-[0.5px] border-components-panel-border-subtle bg-components-panel-on-panel-item-bg p-4 shadow-xs backdrop-blur-xs',
@ -62,14 +61,14 @@ const AppsFull: FC<{ loc: string, className?: string }> = ({
</div>
<div className="flex flex-col gap-2">
<div className="flex items-center justify-between system-xs-medium text-text-secondary">
<div>{buildAppsLabel}</div>
<div>{t('usagePage.buildApps', { ns: 'billing' })}</div>
<div>
{usage}
/
{total}
</div>
</div>
<MeterRoot value={Math.min(percent, 100)} max={100} aria-label={buildAppsLabel}>
<MeterRoot value={Math.min(percent, 100)} max={100}>
<MeterTrack>
<MeterIndicator tone={tone} />
</MeterTrack>

View File

@ -229,7 +229,7 @@ describe('UsageInfo', () => {
/>,
)
expect(screen.getByRole('meter', { name: 'Storage' })).toBeInTheDocument()
expect(screen.getByRole('meter')).toBeInTheDocument()
expect(container.querySelector('[aria-hidden="true"]')).toBeNull()
})
@ -270,7 +270,7 @@ describe('UsageInfo', () => {
/>,
)
expect(screen.getByRole('meter', { name: 'Storage' })).toBeInTheDocument()
expect(screen.getByRole('meter')).toBeInTheDocument()
expect(container.querySelector('[aria-hidden="true"]')).toBeNull()
})

View File

@ -144,13 +144,13 @@ const UsageInfo: FC<Props> = ({
<div
className={cn(
'h-1 rounded-md bg-progress-bar-indeterminate-stripe',
isSandboxPlan ? 'w-full' : 'w-7.5',
isSandboxPlan ? 'w-full' : 'w-[30px]',
)}
/>
</div>
)
: (
<MeterRoot value={effectivePercent} max={100} aria-label={name}>
<MeterRoot value={effectivePercent} max={100}>
<MeterTrack>
<MeterIndicator tone={tone} />
</MeterTrack>
@ -162,7 +162,7 @@ const UsageInfo: FC<Props> = ({
return (
<Tooltip>
<TooltipTrigger render={<div className="cursor-default">{children}</div>} />
<TooltipContent className="w-50 max-w-50">
<TooltipContent className="w-[200px] max-w-[200px]">
{storageTooltip}
</TooltipContent>
</Tooltip>

View File

@ -1,7 +1,6 @@
import type { AppIconSelection } from '@/app/components/base/app-icon-picker'
import type { PipelineTemplate } from '@/models/pipeline'
import { Button } from '@langgenius/dify-ui/button'
import { Textarea } from '@langgenius/dify-ui/textarea'
import { toast } from '@langgenius/dify-ui/toast'
import { RiCloseLine } from '@remixicon/react'
import * as React from 'react'
@ -10,6 +9,7 @@ import { useTranslation } from 'react-i18next'
import AppIcon from '@/app/components/base/app-icon'
import AppIconPicker from '@/app/components/base/app-icon-picker'
import Input from '@/app/components/base/input'
import Textarea from '@/app/components/base/textarea'
import { useInvalidCustomizedTemplateList, useUpdateTemplateInfo } from '@/service/use-pipeline'
type EditPipelineInfoProps = {
@ -45,7 +45,8 @@ const EditPipelineInfo = ({
setAppIcon(icon)
}, [])
const handleDescriptionChange = useCallback((value: string) => {
const handleDescriptionChange = useCallback((event: React.ChangeEvent<HTMLTextAreaElement>) => {
const value = event.target.value
setDescription(value)
}, [])
@ -120,8 +121,7 @@ const EditPipelineInfo = ({
{t('knowledgeDescription', { ns: 'datasetPipeline' })}
</label>
<Textarea
aria-label={t('knowledgeDescription', { ns: 'datasetPipeline' })}
onValueChange={handleDescriptionChange}
onChange={handleDescriptionChange}
value={description}
placeholder={t('knowledgeDescriptionPlaceholder', { ns: 'datasetPipeline' })}
/>

View File

@ -244,15 +244,6 @@ describe('DatasetCard Component', () => {
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-1/documents')
})
it('should not change background color on hover', () => {
const dataset = createMockDataset()
render(<DatasetCard dataset={dataset} />)
const card = screen.getByText('Test Dataset').closest('[data-disable-nprogress]')
expect(card).toHaveClass('bg-components-card-bg')
expect(card).not.toHaveClass('hover:bg-components-card-bg-alt')
})
it('should navigate to hitTesting for external provider', () => {
const dataset = createMockDataset({ provider: 'external' })
render(<DatasetCard dataset={dataset} />)

View File

@ -63,7 +63,7 @@ const DatasetCard = ({
return (
<>
<div
className="group relative col-span-1 flex h-47.5 cursor-pointer flex-col rounded-xl border-[0.5px] border-solid border-components-card-border bg-components-card-bg shadow-xs shadow-shadow-shadow-3 transition-all duration-200 ease-in-out hover:shadow-md hover:shadow-shadow-shadow-5"
className="group relative col-span-1 flex h-[190px] cursor-pointer flex-col rounded-xl border-[0.5px] border-solid border-components-card-border bg-components-card-bg shadow-xs shadow-shadow-shadow-3 transition-all duration-200 ease-in-out hover:bg-components-card-bg-alt hover:shadow-md hover:shadow-shadow-shadow-5"
data-disable-nprogress={true}
onClick={handleCardClick}
>

View File

@ -5,12 +5,12 @@ import type { DataSet } from '@/models/datasets'
import { Button } from '@langgenius/dify-ui/button'
import { cn } from '@langgenius/dify-ui/cn'
import { Dialog, DialogContent } from '@langgenius/dify-ui/dialog'
import { Textarea } from '@langgenius/dify-ui/textarea'
import { toast } from '@langgenius/dify-ui/toast'
import { RiCloseLine } from '@remixicon/react'
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Input from '@/app/components/base/input'
import Textarea from '@/app/components/base/textarea'
import { updateDatasetSetting } from '@/service/datasets'
import AppIcon from '../../base/app-icon'
import AppIconPicker from '../../base/app-icon-picker'
@ -108,7 +108,7 @@ const RenameDatasetModal = ({ show, dataset, onSuccess, onClose }: RenameDataset
{t('form.desc', { ns: 'datasetSettings' })}
</div>
<div className="w-full">
<Textarea aria-label={t('form.desc', { ns: 'datasetSettings' })} value={description} onValueChange={value => setDescription(value)} className="resize-none" placeholder={t('form.descPlaceholder', { ns: 'datasetSettings' }) || ''} />
<Textarea value={description} onChange={e => setDescription(e.target.value)} className="resize-none" placeholder={t('form.descPlaceholder', { ns: 'datasetSettings' }) || ''} />
</div>
</div>
</div>

View File

@ -3,11 +3,11 @@ import type { AppIconSelection } from '@/app/components/base/app-icon-picker'
import type { Member } from '@/models/common'
import type { DataSet, DatasetPermission, IconInfo } from '@/models/datasets'
import type { AppIconType } from '@/types/app'
import { Textarea } from '@langgenius/dify-ui/textarea'
import { useTranslation } from 'react-i18next'
import AppIcon from '@/app/components/base/app-icon'
import AppIconPicker from '@/app/components/base/app-icon-picker'
import Input from '@/app/components/base/input'
import Textarea from '@/app/components/base/textarea'
import PermissionSelector from '../../permission-selector'
const rowClass = 'flex gap-x-1'
@ -85,12 +85,11 @@ const BasicInfoSection = ({
</div>
<div className="grow">
<Textarea
aria-label={t('form.desc', { ns: 'datasetSettings' })}
disabled={!currentDataset?.embedding_available}
className="resize-none"
placeholder={t('form.descPlaceholder', { ns: 'datasetSettings' }) || ''}
value={description}
onValueChange={value => setDescription(value)}
onChange={e => setDescription(e.target.value)}
/>
</div>
</div>

View File

@ -1,7 +1,7 @@
import type { ChangeEvent } from 'react'
import type { DefaultModel } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { SummaryIndexSetting as SummaryIndexSettingType } from '@/models/datasets'
import { Switch } from '@langgenius/dify-ui/switch'
import { Textarea } from '@langgenius/dify-ui/textarea'
import {
memo,
useCallback,
@ -9,6 +9,7 @@ import {
} from 'react'
import { useTranslation } from 'react-i18next'
import { Infotip } from '@/app/components/base/infotip'
import Textarea from '@/app/components/base/textarea'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
@ -52,9 +53,9 @@ const SummaryIndexSetting = ({
})
}, [onSummaryIndexSettingChange])
const handleSummaryIndexPromptChange = useCallback((value: string) => {
const handleSummaryIndexPromptChange = useCallback((e: ChangeEvent<HTMLTextAreaElement>) => {
onSummaryIndexSettingChange?.({
summary_prompt: value,
summary_prompt: e.target.value,
})
}, [onSummaryIndexSettingChange])
@ -94,9 +95,8 @@ const SummaryIndexSetting = ({
{t('form.summaryInstructions', { ns: 'datasetSettings' })}
</div>
<Textarea
aria-label={t('form.summaryInstructions', { ns: 'datasetSettings' })}
value={summaryIndexSetting?.summary_prompt ?? ''}
onValueChange={handleSummaryIndexPromptChange}
onChange={handleSummaryIndexPromptChange}
disabled={readonly}
placeholder={t('form.summaryInstructionsPlaceholder', { ns: 'datasetSettings' })}
/>
@ -166,9 +166,8 @@ const SummaryIndexSetting = ({
</div>
<div className="grow">
<Textarea
aria-label={t('form.summaryInstructions', { ns: 'datasetSettings' })}
value={summaryIndexSetting?.summary_prompt ?? ''}
onValueChange={handleSummaryIndexPromptChange}
onChange={handleSummaryIndexPromptChange}
disabled={readonly}
placeholder={t('form.summaryInstructionsPlaceholder', { ns: 'datasetSettings' })}
/>
@ -215,9 +214,8 @@ const SummaryIndexSetting = ({
{t('form.summaryInstructions', { ns: 'datasetSettings' })}
</div>
<Textarea
aria-label={t('form.summaryInstructions', { ns: 'datasetSettings' })}
value={summaryIndexSetting?.summary_prompt ?? ''}
onValueChange={handleSummaryIndexPromptChange}
onChange={handleSummaryIndexPromptChange}
disabled={readonly}
placeholder={t('form.summaryInstructionsPlaceholder', { ns: 'datasetSettings' })}
/>

View File

@ -3,7 +3,6 @@ import type { AppIconType } from '@/types/app'
import { Button } from '@langgenius/dify-ui/button'
import { Dialog, DialogCloseButton, DialogContent, DialogTitle } from '@langgenius/dify-ui/dialog'
import { Switch } from '@langgenius/dify-ui/switch'
import { Textarea } from '@langgenius/dify-ui/textarea'
import { toast } from '@langgenius/dify-ui/toast'
import { useDebounceFn, useKeyPress } from 'ahooks'
import * as React from 'react'
@ -11,6 +10,7 @@ import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import AppIcon from '@/app/components/base/app-icon'
import Input from '@/app/components/base/input'
import Textarea from '@/app/components/base/textarea'
import AppsFull from '@/app/components/billing/apps-full-in-dialog'
import { useProviderContext } from '@/context/provider-context'
import { AppModeEnum } from '@/types/app'
@ -145,11 +145,10 @@ const CreateAppModal = ({
<div className="pt-2">
<div className="py-2 text-sm leading-[20px] font-medium text-text-primary">{t('newApp.captionDescription', { ns: 'app' })}</div>
<Textarea
aria-label={t('newApp.captionDescription', { ns: 'app' })}
className="resize-none"
placeholder={t('newApp.appDescriptionPlaceholder', { ns: 'app' }) || ''}
value={description}
onValueChange={value => setDescription(value)}
onChange={e => setDescription(e.target.value)}
/>
</div>
{/* answer icon */}

Some files were not shown because too many files have changed in this diff Show More