mirror of https://github.com/langgenius/dify.git
Merge remote-tracking branch 'origin/main' into optional-plugin-invoke
# Conflicts:
#	api/dify_graph/nodes/llm/llm_utils.py
#	api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py
#	api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
#	api/tests/unit_tests/libs/test_login.py
@@ -1,6 +1,9 @@
 from __future__ import annotations

-from collections.abc import Sequence
+import json
+import logging
+import re
+from collections.abc import Mapping, Sequence
 from typing import Any

 from dify_graph.file import FileType, file_manager
@@ -37,6 +40,11 @@ from .runtime_protocols import PreparedLLMProtocol

 CONTEXT_PLACEHOLDER = "{{#context#}}"

+logger = logging.getLogger(__name__)
+
+VARIABLE_PATTERN = re.compile(r"\{\{#[^#]+#\}\}")
+MAX_RESOLVED_VALUE_LENGTH = 1024
+

 def fetch_model_schema(*, model_instance: PreparedLLMProtocol) -> AIModelEntity:
     model_schema = model_instance.get_model_schema()
@@ -477,3 +485,61 @@ def _append_file_prompts(
         prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents)
     else:
         prompt_messages.append(UserPromptMessage(content=file_prompts))
+
+
+def _coerce_resolved_value(raw: str) -> int | float | bool | str:
+    """Try to restore the original type from a resolved template string.
+
+    Variable references are always resolved to text, but completion params may
+    expect numeric or boolean values (e.g. a variable that holds "0.7" mapped to
+    the ``temperature`` parameter). This helper attempts a JSON parse so that
+    ``"0.7"`` → ``0.7``, ``"true"`` → ``True``, etc. Plain strings that are not
+    valid JSON literals are returned as-is.
+    """
+    stripped = raw.strip()
+    if not stripped:
+        return raw
+
+    try:
+        parsed: object = json.loads(stripped)
+    except (json.JSONDecodeError, ValueError):
+        return raw
+
+    if isinstance(parsed, (int, float, bool)):
+        return parsed
+    return raw
+
+
+def resolve_completion_params_variables(
+    completion_params: Mapping[str, Any],
+    variable_pool: VariablePool,
+) -> dict[str, Any]:
+    """Resolve variable references (``{{#node_id.var#}}``) in string-typed completion params.
+
+    Security notes:
+    - Resolved values are length-capped to ``MAX_RESOLVED_VALUE_LENGTH`` to
+      prevent denial-of-service through excessively large variable payloads.
+    - This follows the same ``VariablePool.convert_template`` pattern used across
+      Dify (Answer Node, HTTP Request Node, Agent Node, etc.). The downstream
+      model plugin receives these values as structured JSON key-value pairs — they
+      are never concatenated into raw HTTP headers or SQL queries.
+    - Numeric/boolean coercion is applied so that variables holding ``"0.7"`` are
+      restored to their native type rather than sent as a bare string.
+    """
+    resolved: dict[str, Any] = {}
+    for key, value in completion_params.items():
+        if isinstance(value, str) and VARIABLE_PATTERN.search(value):
+            segment_group = variable_pool.convert_template(value)
+            text = segment_group.text
+            if len(text) > MAX_RESOLVED_VALUE_LENGTH:
+                logger.warning(
+                    "Resolved value for param '%s' truncated from %d to %d chars",
+                    key,
+                    len(text),
+                    MAX_RESOLVED_VALUE_LENGTH,
+                )
+                text = text[:MAX_RESOLVED_VALUE_LENGTH]
+            resolved[key] = _coerce_resolved_value(text)
+        else:
+            resolved[key] = value
+    return resolved
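For orientation, a minimal usage sketch of the new helper (the import paths here are assumptions; VariablePool.empty() and pool.add() follow the test fixtures further down):

    from dify_graph.entities import VariablePool  # assumed import path
    from dify_graph.nodes.llm import llm_utils    # assumed import path

    pool = VariablePool.empty()
    pool.add(["start", "temp"], "0.7")  # workflow variables always resolve to text

    params = {"temperature": "{{#start.temp#}}", "max_tokens": 512}
    resolved = llm_utils.resolve_completion_params_variables(params, pool)

    # "0.7" is JSON-coerced back to a float; non-string params pass through untouched.
    assert resolved == {"temperature": 0.7, "max_tokens": 512}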
@@ -207,6 +207,10 @@ class LLMNode(Node[LLMNodeData]):

         # fetch model config
         model_instance = self._model_instance
+        # Resolve variable references in string-typed completion params
+        model_instance.parameters = llm_utils.resolve_completion_params_variables(
+            model_instance.parameters, variable_pool
+        )
         model_name = model_instance.model_name
         model_provider = model_instance.provider
         model_stop = model_instance.stop
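The same resolution block is added to the Parameter Extractor and Question Classifier nodes below, so each of the three nodes touched here rewrites its parameters just before the model call. The net effect on a node's mapping, using the fixture values from the tests further down:

    # Authored config: one variable reference, one literal
    model_instance.parameters = {"response_format": "{{#node1.output#}}", "temperature": 0.7}
    # After resolve_completion_params_variables (node1.output == "resolved_value"):
    # {"response_format": "resolved_value", "temperature": 0.7}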
@@ -159,6 +159,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         )

         model_instance = self._model_instance
+        # Resolve variable references in string-typed completion params
+        model_instance.parameters = llm_utils.resolve_completion_params_variables(
+            model_instance.parameters, variable_pool
+        )
         try:
             model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
         except ValueError as exc:
@@ -111,6 +111,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         variables = {"query": query}
         # fetch model instance
         model_instance = self._model_instance
+        # Resolve variable references in string-typed completion params
+        model_instance.parameters = llm_utils.resolve_completion_params_variables(
+            model_instance.parameters, variable_pool
+        )
         memory = self._memory
         # fetch instruction
         node_data.instruction = node_data.instruction or ""
@@ -18,15 +18,23 @@ if TYPE_CHECKING:
     from models.model import EndUser


+def _resolve_current_user() -> EndUser | Account | None:
+    """
+    Resolve the current user proxy to its underlying user object.
+    This keeps unit tests working when they patch `current_user` directly
+    instead of bootstrapping a full Flask-Login manager.
+    """
+    user_proxy = current_user
+    get_current_object = getattr(user_proxy, "_get_current_object", None)
+    return get_current_object() if callable(get_current_object) else user_proxy  # type: ignore
+
+
 def current_account_with_tenant():
     """
     Resolve the underlying account for the current user proxy and ensure tenant context exists.
     Allows tests to supply plain Account mocks without the LocalProxy helper.
     """
-    user_proxy = current_user
-
-    get_current_object = getattr(user_proxy, "_get_current_object", None)
-    user = get_current_object() if callable(get_current_object) else user_proxy  # type: ignore
+    user = _resolve_current_user()

     if not isinstance(user, Account):
         raise ValueError("current_user must be an Account instance")
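A quick illustration of the two paths through _resolve_current_user, assuming Werkzeug's LocalProxy (the type behind Flask-Login's current_user); the SimpleNamespace stand-in is hypothetical:

    from types import SimpleNamespace
    from werkzeug.local import LocalProxy

    account = SimpleNamespace(is_authenticated=True)

    # Real request context: current_user is a proxy, so the helper unwraps it.
    proxy = LocalProxy(lambda: account)
    assert proxy._get_current_object() is account

    # Unit test: current_user is patched with a plain object that has no
    # _get_current_object, so getattr(...) returns None and the helper
    # returns the object itself.
    assert getattr(account, "_get_current_object", None) is None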
@@ -77,9 +85,10 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]
         if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
             return current_app.ensure_sync(func)(*args, **kwargs)

-        user = _get_user()
+        user = _resolve_current_user()
         if user is None or not user.is_authenticated:
             return current_app.login_manager.unauthorized()  # type: ignore
         g._login_user = user
         # we put csrf validation here for less conflicts
         # TODO: maybe find a better place for it.
+        check_csrf_token(request, user.id)
@@ -5,7 +5,11 @@ import pytest
 from core.model_manager import ModelInstance
 from dify_graph.file import FileTransferMethod, FileType
 from dify_graph.file.models import File
-from dify_graph.model_runtime.entities import ImagePromptMessageContent, PromptMessageRole, TextPromptMessageContent
+from dify_graph.model_runtime.entities import (
+    ImagePromptMessageContent,
+    PromptMessageRole,
+    TextPromptMessageContent,
+)
 from dify_graph.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     SystemPromptMessage,
@@ -80,6 +84,15 @@ def _build_image_file(
     )


+@pytest.fixture
+def variable_pool() -> VariablePool:
+    pool = VariablePool.empty()
+    pool.add(["node1", "output"], "resolved_value")
+    pool.add(["node2", "text"], "hello world")
+    pool.add(["start", "user_input"], "dynamic_param")
+    return pool
+
+
 def _fetch_prompt_messages_with_mocked_content(content):
     variable_pool = VariablePool.empty()
     model_instance = mock.MagicMock(spec=ModelInstance)
@@ -122,6 +135,159 @@ def _fetch_prompt_messages_with_mocked_content(content):
     )


+class TestTypeCoercionViaResolve:
+    """Type coercion is tested through the public resolve_completion_params_variables API."""
+
+    def test_numeric_string_coerced_to_float(self):
+        pool = VariablePool.empty()
+        pool.add(["n", "v"], "0.7")
+        result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool)
+        assert result["p"] == 0.7
+
+    def test_integer_string_coerced_to_int(self):
+        pool = VariablePool.empty()
+        pool.add(["n", "v"], "1024")
+        result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool)
+        assert result["p"] == 1024
+
+    def test_boolean_string_coerced_to_bool(self):
+        pool = VariablePool.empty()
+        pool.add(["n", "v"], "true")
+        result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool)
+        assert result["p"] is True
+
+    def test_plain_string_stays_string(self):
+        pool = VariablePool.empty()
+        pool.add(["n", "v"], "json_object")
+        result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool)
+        assert result["p"] == "json_object"
+
+    def test_json_object_string_stays_string(self):
+        pool = VariablePool.empty()
+        pool.add(["n", "v"], '{"key": "val"}')
+        result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool)
+        assert result["p"] == '{"key": "val"}'
+
+    def test_mixed_text_and_variable_stays_string(self):
+        pool = VariablePool.empty()
+        pool.add(["n", "v"], "0.7")
+        result = llm_utils.resolve_completion_params_variables({"p": "val={{#n.v#}}"}, pool)
+        assert result["p"] == "val=0.7"
+
+
+class TestResolveCompletionParamsVariables:
+    def test_plain_string_values_unchanged(self, variable_pool: VariablePool):
+        params = {"response_format": "json", "custom_param": "static_value"}
+
+        result = llm_utils.resolve_completion_params_variables(params, variable_pool)
+
+        assert result == {"response_format": "json", "custom_param": "static_value"}
+
+    def test_numeric_values_unchanged(self, variable_pool: VariablePool):
+        params = {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024}
+
+        result = llm_utils.resolve_completion_params_variables(params, variable_pool)
+
+        assert result == {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024}
+
+    def test_boolean_values_unchanged(self, variable_pool: VariablePool):
+        params = {"stream": True, "echo": False}
+
+        result = llm_utils.resolve_completion_params_variables(params, variable_pool)
+
+        assert result == {"stream": True, "echo": False}
+
+    def test_list_values_unchanged(self, variable_pool: VariablePool):
+        params = {"stop": ["Human:", "Assistant:"]}
+
+        result = llm_utils.resolve_completion_params_variables(params, variable_pool)
+
+        assert result == {"stop": ["Human:", "Assistant:"]}
+
+    def test_single_variable_reference_resolved(self, variable_pool: VariablePool):
+        params = {"response_format": "{{#node1.output#}}"}
+
+        result = llm_utils.resolve_completion_params_variables(params, variable_pool)
+
+        assert result == {"response_format": "resolved_value"}
+
+    def test_multiple_variable_references_resolved(self, variable_pool: VariablePool):
+        params = {
+            "param_a": "{{#node1.output#}}",
+            "param_b": "{{#node2.text#}}",
+        }
+
+        result = llm_utils.resolve_completion_params_variables(params, variable_pool)
+
+        assert result == {"param_a": "resolved_value", "param_b": "hello world"}
+
+    def test_mixed_text_and_variable_resolved(self, variable_pool: VariablePool):
+        params = {"prompt_prefix": "prefix_{{#node1.output#}}_suffix"}
+
+        result = llm_utils.resolve_completion_params_variables(params, variable_pool)
+
+        assert result == {"prompt_prefix": "prefix_resolved_value_suffix"}
+
+    def test_mixed_params_types(self, variable_pool: VariablePool):
+        """Non-string params pass through; string params with variables get resolved."""
+        params = {
+            "temperature": 0.7,
+            "response_format": "{{#node1.output#}}",
+            "custom_string": "no_vars_here",
+            "max_tokens": 512,
+            "stop": ["\n"],
+        }
+
+        result = llm_utils.resolve_completion_params_variables(params, variable_pool)
+
+        assert result == {
+            "temperature": 0.7,
+            "response_format": "resolved_value",
+            "custom_string": "no_vars_here",
+            "max_tokens": 512,
+            "stop": ["\n"],
+        }
+
+    def test_empty_params(self, variable_pool: VariablePool):
+        result = llm_utils.resolve_completion_params_variables({}, variable_pool)
+
+        assert result == {}
+
+    def test_unresolvable_variable_keeps_selector_text(self):
+        """When a referenced variable doesn't exist in the pool, convert_template
+        falls back to the raw selector path (e.g. 'nonexistent.var')."""
+        pool = VariablePool.empty()
+        params = {"format": "{{#nonexistent.var#}}"}
+
+        result = llm_utils.resolve_completion_params_variables(params, pool)
+
+        assert result["format"] == "nonexistent.var"
+
+    def test_multiple_variables_in_single_value(self, variable_pool: VariablePool):
+        params = {"combined": "{{#node1.output#}} and {{#node2.text#}}"}
+
+        result = llm_utils.resolve_completion_params_variables(params, variable_pool)
+
+        assert result == {"combined": "resolved_value and hello world"}
+
+    def test_original_params_not_mutated(self, variable_pool: VariablePool):
+        original = {"response_format": "{{#node1.output#}}", "temperature": 0.5}
+        original_copy = dict(original)
+
+        _ = llm_utils.resolve_completion_params_variables(original, variable_pool)
+
+        assert original == original_copy
+
+    def test_long_value_truncated(self):
+        pool = VariablePool.empty()
+        pool.add(["node1", "big"], "x" * 2000)
+        params = {"param": "{{#node1.big#}}"}
+
+        result = llm_utils.resolve_completion_params_variables(params, pool)
+
+        assert len(result["param"]) == llm_utils.MAX_RESOLVED_VALUE_LENGTH
+
+
 def test_fetch_prompt_messages_skips_messages_when_all_contents_are_filtered_out():
     with pytest.raises(NoPromptFoundError):
         _fetch_prompt_messages_with_mocked_content(
@@ -139,7 +139,6 @@ class TestLoginRequired:

         with login_app.test_request_context(method=method):
             result = protected_view()

         assert result == "Protected content"
-        get_user.assert_not_called()
         ensure_sync_spy.assert_called_once_with(protected_view.__wrapped__)