Merge remote-tracking branch 'origin/main' into optional-plugin-invoke

# Conflicts:
#	api/dify_graph/nodes/llm/llm_utils.py
#	api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py
#	api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
#	api/tests/unit_tests/libs/test_login.py
This commit is contained in:
WH-2099
2026-03-24 19:35:09 +08:00
95 changed files with 9989 additions and 2581 deletions

View File

@ -1,6 +1,9 @@
from __future__ import annotations
from collections.abc import Sequence
import json
import logging
import re
from collections.abc import Mapping, Sequence
from typing import Any
from dify_graph.file import FileType, file_manager
@ -37,6 +40,11 @@ from .runtime_protocols import PreparedLLMProtocol
CONTEXT_PLACEHOLDER = "{{#context#}}"
logger = logging.getLogger(__name__)
VARIABLE_PATTERN = re.compile(r"\{\{#[^#]+#\}\}")
MAX_RESOLVED_VALUE_LENGTH = 1024
def fetch_model_schema(*, model_instance: PreparedLLMProtocol) -> AIModelEntity:
model_schema = model_instance.get_model_schema()
@ -477,3 +485,61 @@ def _append_file_prompts(
prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
def _coerce_resolved_value(raw: str) -> int | float | bool | str:
"""Try to restore the original type from a resolved template string.
Variable references are always resolved to text, but completion params may
expect numeric or boolean values (e.g. a variable that holds "0.7" mapped to
the ``temperature`` parameter). This helper attempts a JSON parse so that
``"0.7"`` → ``0.7``, ``"true"`` → ``True``, etc. Plain strings that are not
valid JSON literals are returned as-is.
"""
stripped = raw.strip()
if not stripped:
return raw
try:
parsed: object = json.loads(stripped)
except (json.JSONDecodeError, ValueError):
return raw
if isinstance(parsed, (int, float, bool)):
return parsed
return raw
def resolve_completion_params_variables(
    completion_params: Mapping[str, Any],
    variable_pool: VariablePool,
) -> dict[str, Any]:
    """Resolve variable references (``{{#node_id.var#}}``) in string-typed completion params.

    Security notes:
    - Resolved values are length-capped to ``MAX_RESOLVED_VALUE_LENGTH`` to
      prevent denial-of-service through excessively large variable payloads.
    - This follows the same ``VariablePool.convert_template`` pattern used across
      Dify (Answer Node, HTTP Request Node, Agent Node, etc.). The downstream
      model plugin receives these values as structured JSON key-value pairs — they
      are never concatenated into raw HTTP headers or SQL queries.
    - Numeric/boolean coercion is applied so that variables holding ``"0.7"`` are
      restored to their native type rather than sent as a bare string.
    """
    out: dict[str, Any] = {}
    for name, raw in completion_params.items():
        # Only string values containing a {{#...#}} reference need resolution;
        # everything else (numbers, bools, lists, plain strings) passes through.
        if not (isinstance(raw, str) and VARIABLE_PATTERN.search(raw)):
            out[name] = raw
            continue
        rendered = variable_pool.convert_template(raw).text
        if len(rendered) > MAX_RESOLVED_VALUE_LENGTH:
            logger.warning(
                "Resolved value for param '%s' truncated from %d to %d chars",
                name,
                len(rendered),
                MAX_RESOLVED_VALUE_LENGTH,
            )
            rendered = rendered[:MAX_RESOLVED_VALUE_LENGTH]
        out[name] = _coerce_resolved_value(rendered)
    return out

View File

@ -207,6 +207,10 @@ class LLMNode(Node[LLMNodeData]):
# fetch model config
model_instance = self._model_instance
# Resolve variable references in string-typed completion params
model_instance.parameters = llm_utils.resolve_completion_params_variables(
model_instance.parameters, variable_pool
)
model_name = model_instance.model_name
model_provider = model_instance.provider
model_stop = model_instance.stop

View File

@ -159,6 +159,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
)
model_instance = self._model_instance
# Resolve variable references in string-typed completion params
model_instance.parameters = llm_utils.resolve_completion_params_variables(
model_instance.parameters, variable_pool
)
try:
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
except ValueError as exc:

View File

@ -111,6 +111,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
variables = {"query": query}
# fetch model instance
model_instance = self._model_instance
# Resolve variable references in string-typed completion params
model_instance.parameters = llm_utils.resolve_completion_params_variables(
model_instance.parameters, variable_pool
)
memory = self._memory
# fetch instruction
node_data.instruction = node_data.instruction or ""

View File

@ -18,15 +18,23 @@ if TYPE_CHECKING:
from models.model import EndUser
def _resolve_current_user() -> EndUser | Account | None:
    """
    Resolve the current user proxy to its underlying user object.

    This keeps unit tests working when they patch `current_user` directly
    instead of bootstrapping a full Flask-Login manager.
    """
    proxy = current_user
    unwrap = getattr(proxy, "_get_current_object", None)
    if callable(unwrap):
        # Real LocalProxy: unwrap to the underlying user object.
        return unwrap()  # type: ignore
    # Plain mock/object patched in by tests: use it as-is.
    return proxy  # type: ignore
def current_account_with_tenant():
"""
Resolve the underlying account for the current user proxy and ensure tenant context exists.
Allows tests to supply plain Account mocks without the LocalProxy helper.
"""
user_proxy = current_user
get_current_object = getattr(user_proxy, "_get_current_object", None)
user = get_current_object() if callable(get_current_object) else user_proxy # type: ignore
user = _resolve_current_user()
if not isinstance(user, Account):
raise ValueError("current_user must be an Account instance")
@ -77,9 +85,10 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
return current_app.ensure_sync(func)(*args, **kwargs)
user = _get_user()
user = _resolve_current_user()
if user is None or not user.is_authenticated:
return current_app.login_manager.unauthorized() # type: ignore
g._login_user = user
# we put csrf validation here for less conflicts
# TODO: maybe find a better place for it.
check_csrf_token(request, user.id)

View File

@ -5,7 +5,11 @@ import pytest
from core.model_manager import ModelInstance
from dify_graph.file import FileTransferMethod, FileType
from dify_graph.file.models import File
from dify_graph.model_runtime.entities import ImagePromptMessageContent, PromptMessageRole, TextPromptMessageContent
from dify_graph.model_runtime.entities import (
ImagePromptMessageContent,
PromptMessageRole,
TextPromptMessageContent,
)
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
SystemPromptMessage,
@ -80,6 +84,15 @@ def _build_image_file(
)
@pytest.fixture
def variable_pool() -> VariablePool:
    """Pool pre-seeded with the selectors the resolution tests reference."""
    pool = VariablePool.empty()
    seeds = (
        (["node1", "output"], "resolved_value"),
        (["node2", "text"], "hello world"),
        (["start", "user_input"], "dynamic_param"),
    )
    for selector, value in seeds:
        pool.add(selector, value)
    return pool
def _fetch_prompt_messages_with_mocked_content(content):
variable_pool = VariablePool.empty()
model_instance = mock.MagicMock(spec=ModelInstance)
@ -122,6 +135,159 @@ def _fetch_prompt_messages_with_mocked_content(content):
)
class TestTypeCoercionViaResolve:
    """Type coercion is tested through the public resolve_completion_params_variables API."""

    @staticmethod
    def _resolve(stored: str, template: str = "{{#n.v#}}"):
        """Stash *stored* under selector n.v and resolve *template* as param 'p'."""
        vp = VariablePool.empty()
        vp.add(["n", "v"], stored)
        return llm_utils.resolve_completion_params_variables({"p": template}, vp)["p"]

    def test_numeric_string_coerced_to_float(self):
        assert self._resolve("0.7") == 0.7

    def test_integer_string_coerced_to_int(self):
        assert self._resolve("1024") == 1024

    def test_boolean_string_coerced_to_bool(self):
        assert self._resolve("true") is True

    def test_plain_string_stays_string(self):
        assert self._resolve("json_object") == "json_object"

    def test_json_object_string_stays_string(self):
        assert self._resolve('{"key": "val"}') == '{"key": "val"}'

    def test_mixed_text_and_variable_stays_string(self):
        # Coercion only applies to the whole value, not to embedded variables.
        assert self._resolve("0.7", "val={{#n.v#}}") == "val=0.7"
class TestResolveCompletionParamsVariables:
    def test_plain_string_values_unchanged(self, variable_pool: VariablePool):
        out = llm_utils.resolve_completion_params_variables(
            {"response_format": "json", "custom_param": "static_value"}, variable_pool
        )
        assert out == {"response_format": "json", "custom_param": "static_value"}

    def test_numeric_values_unchanged(self, variable_pool: VariablePool):
        out = llm_utils.resolve_completion_params_variables(
            {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024}, variable_pool
        )
        assert out == {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024}

    def test_boolean_values_unchanged(self, variable_pool: VariablePool):
        out = llm_utils.resolve_completion_params_variables({"stream": True, "echo": False}, variable_pool)
        assert out == {"stream": True, "echo": False}

    def test_list_values_unchanged(self, variable_pool: VariablePool):
        out = llm_utils.resolve_completion_params_variables({"stop": ["Human:", "Assistant:"]}, variable_pool)
        assert out == {"stop": ["Human:", "Assistant:"]}

    def test_single_variable_reference_resolved(self, variable_pool: VariablePool):
        out = llm_utils.resolve_completion_params_variables(
            {"response_format": "{{#node1.output#}}"}, variable_pool
        )
        assert out == {"response_format": "resolved_value"}

    def test_multiple_variable_references_resolved(self, variable_pool: VariablePool):
        given = {"param_a": "{{#node1.output#}}", "param_b": "{{#node2.text#}}"}
        out = llm_utils.resolve_completion_params_variables(given, variable_pool)
        assert out == {"param_a": "resolved_value", "param_b": "hello world"}

    def test_mixed_text_and_variable_resolved(self, variable_pool: VariablePool):
        out = llm_utils.resolve_completion_params_variables(
            {"prompt_prefix": "prefix_{{#node1.output#}}_suffix"}, variable_pool
        )
        assert out == {"prompt_prefix": "prefix_resolved_value_suffix"}

    def test_mixed_params_types(self, variable_pool: VariablePool):
        """Non-string params pass through; string params with variables get resolved."""
        given = {
            "temperature": 0.7,
            "response_format": "{{#node1.output#}}",
            "custom_string": "no_vars_here",
            "max_tokens": 512,
            "stop": ["\n"],
        }
        out = llm_utils.resolve_completion_params_variables(given, variable_pool)
        assert out == {
            "temperature": 0.7,
            "response_format": "resolved_value",
            "custom_string": "no_vars_here",
            "max_tokens": 512,
            "stop": ["\n"],
        }

    def test_empty_params(self, variable_pool: VariablePool):
        assert llm_utils.resolve_completion_params_variables({}, variable_pool) == {}

    def test_unresolvable_variable_keeps_selector_text(self):
        """When a referenced variable doesn't exist in the pool, convert_template
        falls back to the raw selector path (e.g. 'nonexistent.var')."""
        empty_pool = VariablePool.empty()
        out = llm_utils.resolve_completion_params_variables(
            {"format": "{{#nonexistent.var#}}"}, empty_pool
        )
        assert out["format"] == "nonexistent.var"

    def test_multiple_variables_in_single_value(self, variable_pool: VariablePool):
        out = llm_utils.resolve_completion_params_variables(
            {"combined": "{{#node1.output#}} and {{#node2.text#}}"}, variable_pool
        )
        assert out == {"combined": "resolved_value and hello world"}

    def test_original_params_not_mutated(self, variable_pool: VariablePool):
        given = {"response_format": "{{#node1.output#}}", "temperature": 0.5}
        snapshot = dict(given)
        llm_utils.resolve_completion_params_variables(given, variable_pool)
        assert given == snapshot

    def test_long_value_truncated(self):
        pool = VariablePool.empty()
        pool.add(["node1", "big"], "x" * 2000)
        out = llm_utils.resolve_completion_params_variables({"param": "{{#node1.big#}}"}, pool)
        assert len(out["param"]) == llm_utils.MAX_RESOLVED_VALUE_LENGTH
def test_fetch_prompt_messages_skips_messages_when_all_contents_are_filtered_out():
with pytest.raises(NoPromptFoundError):
_fetch_prompt_messages_with_mocked_content(

View File

@ -139,7 +139,6 @@ class TestLoginRequired:
with login_app.test_request_context(method=method):
result = protected_view()
assert result == "Protected content"
get_user.assert_not_called()
ensure_sync_spy.assert_called_once_with(protected_view.__wrapped__)