diff --git a/.gemini/config.yaml b/.gemini/config.yaml new file mode 100644 index 0000000000..15c697730a --- /dev/null +++ b/.gemini/config.yaml @@ -0,0 +1,13 @@ +have_fun: false +memory_config: + disabled: false +code_review: + disable: true + comment_severity_threshold: MEDIUM + max_review_comments: -1 + pull_request_opened: + help: false + summary: false + code_review: false + include_drafts: false +ignore_patterns: [] diff --git a/api/dify_graph/nodes/llm/llm_utils.py b/api/dify_graph/nodes/llm/llm_utils.py index 87aa645ed1..cbfb461c74 100644 --- a/api/dify_graph/nodes/llm/llm_utils.py +++ b/api/dify_graph/nodes/llm/llm_utils.py @@ -1,6 +1,9 @@ from __future__ import annotations -from collections.abc import Sequence +import json +import logging +import re +from collections.abc import Mapping, Sequence from typing import Any, cast from sqlalchemy import select @@ -47,6 +50,11 @@ from .exc import ( ) from .protocols import TemplateRenderer +logger = logging.getLogger(__name__) + +VARIABLE_PATTERN = re.compile(r"\{\{#[^#]+#\}\}") +MAX_RESOLVED_VALUE_LENGTH = 1024 + def fetch_model_config(*, tenant_id: str, node_data_model: ModelConfig) -> tuple[ModelInstance, Any]: from core.app.llm.model_access import build_dify_model_access @@ -688,3 +696,61 @@ def _append_file_prompts( prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents) else: prompt_messages.append(UserPromptMessage(content=file_prompts)) + + +def _coerce_resolved_value(raw: str) -> int | float | bool | str: + """Try to restore the original type from a resolved template string. + + Variable references are always resolved to text, but completion params may + expect numeric or boolean values (e.g. a variable that holds "0.7" mapped to + the ``temperature`` parameter). This helper attempts a JSON parse so that + ``"0.7"`` → ``0.7``, ``"true"`` → ``True``, etc. Plain strings that are not + valid JSON literals are returned as-is. + """ + stripped = raw.strip() + if not stripped: + return raw + + try: + parsed: object = json.loads(stripped) + except (json.JSONDecodeError, ValueError): + return raw + + if isinstance(parsed, (int, float, bool)): + return parsed + return raw + + +def resolve_completion_params_variables( + completion_params: Mapping[str, Any], + variable_pool: VariablePool, +) -> dict[str, Any]: + """Resolve variable references (``{{#node_id.var#}}``) in string-typed completion params. + + Security notes: + - Resolved values are length-capped to ``MAX_RESOLVED_VALUE_LENGTH`` to + prevent denial-of-service through excessively large variable payloads. + - This follows the same ``VariablePool.convert_template`` pattern used across + Dify (Answer Node, HTTP Request Node, Agent Node, etc.). The downstream + model plugin receives these values as structured JSON key-value pairs — they + are never concatenated into raw HTTP headers or SQL queries. + - Numeric/boolean coercion is applied so that variables holding ``"0.7"`` are + restored to their native type rather than sent as a bare string. + """ + resolved: dict[str, Any] = {} + for key, value in completion_params.items(): + if isinstance(value, str) and VARIABLE_PATTERN.search(value): + segment_group = variable_pool.convert_template(value) + text = segment_group.text + if len(text) > MAX_RESOLVED_VALUE_LENGTH: + logger.warning( + "Resolved value for param '%s' truncated from %d to %d chars", + key, + len(text), + MAX_RESOLVED_VALUE_LENGTH, + ) + text = text[:MAX_RESOLVED_VALUE_LENGTH] + resolved[key] = _coerce_resolved_value(text) + else: + resolved[key] = value + return resolved diff --git a/api/dify_graph/nodes/llm/node.py b/api/dify_graph/nodes/llm/node.py index 53ff609f6c..3a308f59bf 100644 --- a/api/dify_graph/nodes/llm/node.py +++ b/api/dify_graph/nodes/llm/node.py @@ -246,6 +246,13 @@ class LLMNode(Node[LLMNodeData]): node_data_model=self.node_data.model, tenant_id=self.tenant_id, ) + resolved_completion_params = llm_utils.resolve_completion_params_variables( + model_config.parameters, + variable_pool, + ) + model_instance.parameters = resolved_completion_params + model_config.parameters = resolved_completion_params + self.node_data.model.completion_params = resolved_completion_params # fetch memory memory = llm_utils.fetch_memory( diff --git a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py index 2dedd5e162..7b14b7adc1 100644 --- a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py @@ -164,6 +164,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): ) model_instance = self._model_instance + # Resolve variable references in string-typed completion params + model_instance.parameters = llm_utils.resolve_completion_params_variables( + model_instance.parameters, variable_pool + ) if not isinstance(model_instance.model_type_instance, LargeLanguageModel): raise InvalidModelTypeError("Model is not a Large Language Model") diff --git a/api/dify_graph/nodes/question_classifier/question_classifier_node.py b/api/dify_graph/nodes/question_classifier/question_classifier_node.py index 1a9e6a4ca1..6f85af045c 100644 --- a/api/dify_graph/nodes/question_classifier/question_classifier_node.py +++ b/api/dify_graph/nodes/question_classifier/question_classifier_node.py @@ -113,6 +113,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): query = variable.value if variable else None variables = {"query": query} model_instance = self._model_instance + # Resolve variable references in string-typed completion params + model_instance.parameters = llm_utils.resolve_completion_params_variables( + model_instance.parameters, variable_pool + ) memory = self._memory # fetch instruction node_data.instruction = node_data.instruction or "" diff --git a/api/libs/login.py b/api/libs/login.py index bd5cb5f30d..dce332b01d 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -18,15 +18,23 @@ if TYPE_CHECKING: from models.model import EndUser +def _resolve_current_user() -> EndUser | Account | None: + """ + Resolve the current user proxy to its underlying user object. + This keeps unit tests working when they patch `current_user` directly + instead of bootstrapping a full Flask-Login manager. + """ + user_proxy = current_user + get_current_object = getattr(user_proxy, "_get_current_object", None) + return get_current_object() if callable(get_current_object) else user_proxy # type: ignore + + def current_account_with_tenant(): """ Resolve the underlying account for the current user proxy and ensure tenant context exists. Allows tests to supply plain Account mocks without the LocalProxy helper. """ - user_proxy = current_user - - get_current_object = getattr(user_proxy, "_get_current_object", None) - user = get_current_object() if callable(get_current_object) else user_proxy # type: ignore + user = _resolve_current_user() if not isinstance(user, Account): raise ValueError("current_user must be an Account instance") @@ -79,9 +87,10 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue] if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: return current_app.ensure_sync(func)(*args, **kwargs) - user = _get_user() + user = _resolve_current_user() if user is None or not user.is_authenticated: return current_app.login_manager.unauthorized() # type: ignore + g._login_user = user # we put csrf validation here for less conflicts # TODO: maybe find a better place for it. check_csrf_token(request, user.id) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py index 03c4b983a9..57380a3842 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py @@ -23,6 +23,15 @@ from dify_graph.nodes.llm.llm_utils import ( from dify_graph.runtime import VariablePool +@pytest.fixture +def variable_pool() -> VariablePool: + pool = VariablePool.empty() + pool.add(["node1", "output"], "resolved_value") + pool.add(["node2", "text"], "hello world") + pool.add(["start", "user_input"], "dynamic_param") + return pool + + class TestTruncateMultimodalContent: """Tests for _truncate_multimodal_content function.""" @@ -399,6 +408,309 @@ def _fetch_prompt_messages_with_mocked_content(content): ) +class TestTypeCoercionViaResolve: + """Type coercion is tested through the public resolve_completion_params_variables API.""" + + def test_numeric_string_coerced_to_float(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "0.7") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == 0.7 + + def test_integer_string_coerced_to_int(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "1024") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == 1024 + + def test_boolean_string_coerced_to_bool(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "true") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] is True + + def test_plain_string_stays_string(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "json_object") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == "json_object" + + def test_json_object_string_stays_string(self): + pool = VariablePool.empty() + pool.add(["n", "v"], '{"key": "val"}') + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == '{"key": "val"}' + + def test_mixed_text_and_variable_stays_string(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "0.7") + result = llm_utils.resolve_completion_params_variables({"p": "val={{#n.v#}}"}, pool) + assert result["p"] == "val=0.7" + + +class TestResolveCompletionParamsVariables: + def test_plain_string_values_unchanged(self, variable_pool: VariablePool): + params = {"response_format": "json", "custom_param": "static_value"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"response_format": "json", "custom_param": "static_value"} + + def test_numeric_values_unchanged(self, variable_pool: VariablePool): + params = {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024} + + def test_boolean_values_unchanged(self, variable_pool: VariablePool): + params = {"stream": True, "echo": False} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"stream": True, "echo": False} + + def test_list_values_unchanged(self, variable_pool: VariablePool): + params = {"stop": ["Human:", "Assistant:"]} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"stop": ["Human:", "Assistant:"]} + + def test_single_variable_reference_resolved(self, variable_pool: VariablePool): + params = {"response_format": "{{#node1.output#}}"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"response_format": "resolved_value"} + + def test_multiple_variable_references_resolved(self, variable_pool: VariablePool): + params = { + "param_a": "{{#node1.output#}}", + "param_b": "{{#node2.text#}}", + } + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"param_a": "resolved_value", "param_b": "hello world"} + + def test_mixed_text_and_variable_resolved(self, variable_pool: VariablePool): + params = {"prompt_prefix": "prefix_{{#node1.output#}}_suffix"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"prompt_prefix": "prefix_resolved_value_suffix"} + + def test_mixed_params_types(self, variable_pool: VariablePool): + params = { + "temperature": 0.7, + "response_format": "{{#node1.output#}}", + "custom_string": "no_vars_here", + "max_tokens": 512, + "stop": ["\n"], + } + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == { + "temperature": 0.7, + "response_format": "resolved_value", + "custom_string": "no_vars_here", + "max_tokens": 512, + "stop": ["\n"], + } + + def test_empty_params(self, variable_pool: VariablePool): + result = llm_utils.resolve_completion_params_variables({}, variable_pool) + + assert result == {} + + def test_unresolvable_variable_keeps_selector_text(self): + pool = VariablePool.empty() + params = {"format": "{{#nonexistent.var#}}"} + + result = llm_utils.resolve_completion_params_variables(params, pool) + + assert result["format"] == "nonexistent.var" + + def test_multiple_variables_in_single_value(self, variable_pool: VariablePool): + params = {"combined": "{{#node1.output#}} and {{#node2.text#}}"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"combined": "resolved_value and hello world"} + + def test_original_params_not_mutated(self, variable_pool: VariablePool): + original = {"response_format": "{{#node1.output#}}", "temperature": 0.5} + original_copy = dict(original) + + _ = llm_utils.resolve_completion_params_variables(original, variable_pool) + + assert original == original_copy + + def test_long_value_truncated(self): + pool = VariablePool.empty() + pool.add(["node1", "big"], "x" * 2000) + params = {"param": "{{#node1.big#}}"} + + result = llm_utils.resolve_completion_params_variables(params, pool) + + assert len(result["param"]) == llm_utils.MAX_RESOLVED_VALUE_LENGTH + + +class TestTypeCoercionViaResolve: + """Type coercion is tested through the public resolve_completion_params_variables API.""" + + def test_numeric_string_coerced_to_float(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "0.7") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == 0.7 + + def test_integer_string_coerced_to_int(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "1024") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == 1024 + + def test_boolean_string_coerced_to_bool(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "true") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] is True + + def test_plain_string_stays_string(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "json_object") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == "json_object" + + def test_json_object_string_stays_string(self): + pool = VariablePool.empty() + pool.add(["n", "v"], '{"key": "val"}') + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == '{"key": "val"}' + + def test_mixed_text_and_variable_stays_string(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "0.7") + result = llm_utils.resolve_completion_params_variables({"p": "val={{#n.v#}}"}, pool) + assert result["p"] == "val=0.7" + + +class TestResolveCompletionParamsVariables: + def test_plain_string_values_unchanged(self, variable_pool: VariablePool): + params = {"response_format": "json", "custom_param": "static_value"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"response_format": "json", "custom_param": "static_value"} + + def test_numeric_values_unchanged(self, variable_pool: VariablePool): + params = {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024} + + def test_boolean_values_unchanged(self, variable_pool: VariablePool): + params = {"stream": True, "echo": False} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"stream": True, "echo": False} + + def test_list_values_unchanged(self, variable_pool: VariablePool): + params = {"stop": ["Human:", "Assistant:"]} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"stop": ["Human:", "Assistant:"]} + + def test_single_variable_reference_resolved(self, variable_pool: VariablePool): + params = {"response_format": "{{#node1.output#}}"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"response_format": "resolved_value"} + + def test_multiple_variable_references_resolved(self, variable_pool: VariablePool): + params = { + "param_a": "{{#node1.output#}}", + "param_b": "{{#node2.text#}}", + } + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"param_a": "resolved_value", "param_b": "hello world"} + + def test_mixed_text_and_variable_resolved(self, variable_pool: VariablePool): + params = {"prompt_prefix": "prefix_{{#node1.output#}}_suffix"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"prompt_prefix": "prefix_resolved_value_suffix"} + + def test_mixed_params_types(self, variable_pool: VariablePool): + """Non-string params pass through; string params with variables get resolved.""" + params = { + "temperature": 0.7, + "response_format": "{{#node1.output#}}", + "custom_string": "no_vars_here", + "max_tokens": 512, + "stop": ["\n"], + } + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == { + "temperature": 0.7, + "response_format": "resolved_value", + "custom_string": "no_vars_here", + "max_tokens": 512, + "stop": ["\n"], + } + + def test_empty_params(self, variable_pool: VariablePool): + result = llm_utils.resolve_completion_params_variables({}, variable_pool) + + assert result == {} + + def test_unresolvable_variable_keeps_selector_text(self): + """When a referenced variable doesn't exist in the pool, convert_template + falls back to the raw selector path (e.g. 'nonexistent.var').""" + pool = VariablePool.empty() + params = {"format": "{{#nonexistent.var#}}"} + + result = llm_utils.resolve_completion_params_variables(params, pool) + + assert result["format"] == "nonexistent.var" + + def test_multiple_variables_in_single_value(self, variable_pool: VariablePool): + params = {"combined": "{{#node1.output#}} and {{#node2.text#}}"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"combined": "resolved_value and hello world"} + + def test_original_params_not_mutated(self, variable_pool: VariablePool): + original = {"response_format": "{{#node1.output#}}", "temperature": 0.5} + original_copy = dict(original) + + _ = llm_utils.resolve_completion_params_variables(original, variable_pool) + + assert original == original_copy + + def test_long_value_truncated(self): + pool = VariablePool.empty() + pool.add(["node1", "big"], "x" * 2000) + params = {"param": "{{#node1.big#}}"} + + result = llm_utils.resolve_completion_params_variables(params, pool) + + assert len(result["param"]) == llm_utils.MAX_RESOLVED_VALUE_LENGTH + + def test_fetch_prompt_messages_skips_messages_when_all_contents_are_filtered_out(): with pytest.raises(NoPromptFoundError): _fetch_prompt_messages_with_mocked_content( diff --git a/api/tests/unit_tests/libs/test_login.py b/api/tests/unit_tests/libs/test_login.py index a94ba0c00b..8613d89215 100644 --- a/api/tests/unit_tests/libs/test_login.py +++ b/api/tests/unit_tests/libs/test_login.py @@ -130,6 +130,25 @@ class TestLoginRequired: assert result == "Synced content" setup_app.ensure_sync.assert_called_once() + @patch("libs.login.check_csrf_token", mock_csrf_check) + def test_patched_current_user_without_login_manager(self, app: Flask): + """Test that patched current_user bypasses login manager bootstrapping.""" + + @login_required + def protected_view(): + return "Protected content" + + mock_user = MockUser("test_user", is_authenticated=True) + mock_proxy = MagicMock() + mock_proxy._get_current_object.return_value = mock_user + + with app.test_request_context(): + app.ensure_sync = lambda func: func + with patch("libs.login.current_user", mock_proxy): + result = protected_view() + assert result == "Protected content" + assert g._login_user == mock_user + @patch("libs.login.check_csrf_token", mock_csrf_check) def test_flask_1_compatibility(self, setup_app: Flask): """Test Flask 1.x compatibility without ensure_sync.""" diff --git a/web/app/components/billing/pricing/__tests__/header.spec.tsx b/web/app/components/billing/pricing/__tests__/header.spec.tsx index 0aadc3b0ce..cb8991ff42 100644 --- a/web/app/components/billing/pricing/__tests__/header.spec.tsx +++ b/web/app/components/billing/pricing/__tests__/header.spec.tsx @@ -1,12 +1,14 @@ import { fireEvent, render, screen } from '@testing-library/react' import * as React from 'react' -import { Dialog } from '@/app/components/base/ui/dialog' +import { Dialog, DialogContent } from '@/app/components/base/ui/dialog' import Header from '../header' function renderHeader(onClose: () => void) { return render( -
+ +
+
, ) } @@ -24,7 +26,7 @@ describe('Header', () => { expect(screen.getByText('billing.plansCommon.title.plans')).toBeInTheDocument() expect(screen.getByText('billing.plansCommon.title.description')).toBeInTheDocument() - expect(screen.getByRole('button')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.close' })).toBeInTheDocument() }) }) @@ -33,7 +35,7 @@ describe('Header', () => { const handleClose = vi.fn() renderHeader(handleClose) - fireEvent.click(screen.getByRole('button')) + fireEvent.click(screen.getByRole('button', { name: 'common.operation.close' })) expect(handleClose).toHaveBeenCalledTimes(1) }) @@ -41,11 +43,11 @@ describe('Header', () => { describe('Edge Cases', () => { it('should render structural elements with translation keys', () => { - const { container } = renderHeader(vi.fn()) + renderHeader(vi.fn()) - expect(container.querySelector('span')).toBeInTheDocument() - expect(container.querySelector('p')).toBeInTheDocument() - expect(screen.getByRole('button')).toBeInTheDocument() + expect(screen.getByText('billing.plansCommon.title.plans')).toBeInTheDocument() + expect(screen.getByText('billing.plansCommon.title.description')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.close' })).toBeInTheDocument() }) }) }) diff --git a/web/app/components/billing/pricing/__tests__/index.spec.tsx b/web/app/components/billing/pricing/__tests__/index.spec.tsx index 36848cd463..a8d0a4329e 100644 --- a/web/app/components/billing/pricing/__tests__/index.spec.tsx +++ b/web/app/components/billing/pricing/__tests__/index.spec.tsx @@ -68,6 +68,7 @@ describe('Pricing', () => { it('should render pricing header and localized footer link', () => { render() + expect(screen.getByRole('dialog', { name: 'billing.plansCommon.title.plans' })).toBeInTheDocument() expect(screen.getByText('billing.plansCommon.title.plans')).toBeInTheDocument() expect(screen.getByTestId('pricing-link')).toHaveAttribute('href', 'https://dify.ai/en/pricing#plans-and-features') }) diff --git a/web/app/components/billing/pricing/footer.tsx b/web/app/components/billing/pricing/footer.tsx index 0d3fd965b0..1422ec1cb1 100644 --- a/web/app/components/billing/pricing/footer.tsx +++ b/web/app/components/billing/pricing/footer.tsx @@ -28,8 +28,9 @@ const Footer = ({ {t('plansCommon.comparePlanAndFeatures', { ns: 'billing' })} diff --git a/web/app/components/billing/pricing/header.tsx b/web/app/components/billing/pricing/header.tsx index d0ffe100db..5ab1895667 100644 --- a/web/app/components/billing/pricing/header.tsx +++ b/web/app/components/billing/pricing/header.tsx @@ -1,5 +1,6 @@ import * as React from 'react' import { useTranslation } from 'react-i18next' +import { DialogDescription, DialogTitle } from '@/app/components/base/ui/dialog' import { cn } from '@/utils/classnames' import Button from '../../base/button' import DifyLogo from '../../base/logo/dify-logo' @@ -18,24 +19,25 @@ const Header = ({
-
+ - {t('plansCommon.title.plans', { ns: 'billing' })} - +
-

+ {t('plansCommon.title.description', { ns: 'billing' })} -

+ @@ -119,7 +125,6 @@ describe('ModelParameterModal', () => { beforeEach(() => { vi.clearAllMocks() - isAPIKeySet = true isRulesLoading = false parameterRules = [ { @@ -233,6 +238,26 @@ describe('ModelParameterModal', () => { expect(screen.getByTestId('model-selector')).toBeInTheDocument() }) + it('should pass nodesOutputVars and availableNodes to ParameterItem', () => { + const mockNodesOutputVars = [{ nodeId: 'n1', title: 'Node', vars: [] }] + const mockAvailableNodes = [{ id: 'n1', data: { title: 'Node', type: 'llm' } }] + + render( + , + ) + + fireEvent.click(screen.getByText('Open Settings')) + + const paramEl = screen.getByTestId('param-temperature') + expect(paramEl).toHaveAttribute('data-has-nodes-output-vars', 'true') + expect(paramEl).toHaveAttribute('data-has-available-nodes', 'true') + }) + it('should support custom triggers, workflow mode, and missing default model values', async () => { render( ({ @@ -18,6 +23,29 @@ vi.mock('@/app/components/base/tag-input', () => ({ ), })) +let promptEditorOnChange: ((text: string) => void) | undefined +let capturedWorkflowNodesMap: Record | undefined + +vi.mock('@/app/components/base/prompt-editor', () => ({ + default: ({ value, onChange, workflowVariableBlock }: { + value: string + onChange: (text: string) => void + workflowVariableBlock?: { + show: boolean + variables: NodeOutPutVar[] + workflowNodesMap?: Record + } + }) => { + promptEditorOnChange = onChange + capturedWorkflowNodesMap = workflowVariableBlock?.workflowNodesMap + return ( +
+ {value} +
+ ) + }, +})) + describe('ParameterItem', () => { const createRule = (overrides: Partial = {}): ModelParameterRule => ({ name: 'temp', @@ -30,9 +58,10 @@ describe('ParameterItem', () => { beforeEach(() => { vi.clearAllMocks() + promptEditorOnChange = undefined + capturedWorkflowNodesMap = undefined }) - // Float tests it('should render float controls and clamp numeric input to max', () => { const onChange = vi.fn() render() @@ -50,7 +79,6 @@ describe('ParameterItem', () => { expect(onChange).toHaveBeenCalledWith(0.1) }) - // Int tests it('should render int controls and clamp numeric input', () => { const onChange = vi.fn() render() @@ -75,22 +103,17 @@ describe('ParameterItem', () => { it('should render int input without slider if min or max is missing', () => { render() expect(screen.queryByRole('slider')).not.toBeInTheDocument() - // No max -> precision step expect(screen.getByRole('spinbutton')).toHaveAttribute('step', '0') }) - // Slider events (uses generic value mock for slider) it('should handle slide change and clamp values', () => { const onChange = vi.fn() render() - // Test that the actual slider triggers the onChange logic correctly - // The implementation of Slider uses onChange(val) directly via the mock fireEvent.click(screen.getByTestId('slider-btn')) expect(onChange).toHaveBeenCalledWith(2) }) - // Text & String tests it('should render exact string input and propagate text changes', () => { const onChange = vi.fn() render() @@ -109,21 +132,17 @@ describe('ParameterItem', () => { it('should render select for string with options', () => { render() - // Select renders the selected value in the trigger expect(screen.getByText('a')).toBeInTheDocument() }) - // Tag Tests it('should render tag input for tag type', () => { const onChange = vi.fn() render() expect(screen.getByText('placeholder')).toBeInTheDocument() - // Trigger mock tag input fireEvent.click(screen.getByTestId('tag-input')) expect(onChange).toHaveBeenCalledWith(['tag1', 'tag2']) }) - // Boolean tests it('should render boolean radios and update value on click', () => { const onChange = vi.fn() render() @@ -131,7 +150,6 @@ describe('ParameterItem', () => { expect(onChange).toHaveBeenCalledWith(false) }) - // Switch tests it('should call onSwitch with current value when optional switch is toggled off', () => { const onSwitch = vi.fn() render() @@ -146,7 +164,6 @@ describe('ParameterItem', () => { expect(screen.queryByRole('switch')).not.toBeInTheDocument() }) - // Default Value Fallbacks (rendering without value) it('should use default values if value is undefined', () => { const { rerender } = render() expect(screen.getByRole('spinbutton')).toHaveValue(0.5) @@ -158,26 +175,102 @@ describe('ParameterItem', () => { expect(screen.getByText('True')).toBeInTheDocument() expect(screen.getByText('False')).toBeInTheDocument() - // Without default - rerender() // min is 0 by default in createRule + rerender() expect(screen.getByRole('spinbutton')).toHaveValue(0) }) - // Input Blur it('should reset input to actual bound value on blur', () => { render() const input = screen.getByRole('spinbutton') - // change local state (which triggers clamp internally to let's say 1.4 -> 1 but leaves input text, though handleInputChange updates local state) - // Actually our test fires a change so localValue = 1, then blur sets it fireEvent.change(input, { target: { value: '5' } }) fireEvent.blur(input) expect(input).toHaveValue(1) }) - // Unsupported it('should render no input for unsupported parameter type', () => { render() expect(screen.queryByRole('textbox')).not.toBeInTheDocument() expect(screen.queryByRole('spinbutton')).not.toBeInTheDocument() }) + + describe('workflow variable reference', () => { + const mockNodesOutputVars: NodeOutPutVar[] = [ + { nodeId: 'node1', title: 'LLM Node', vars: [] }, + ] + const mockAvailableNodes: Node[] = [ + { id: 'node1', type: 'custom', position: { x: 0, y: 0 }, data: { title: 'LLM Node', type: BlockEnum.LLM } } as Node, + { id: 'start', type: 'custom', position: { x: 0, y: 0 }, data: { title: 'Start', type: BlockEnum.Start } } as Node, + ] + + it('should build workflowNodesMap and render PromptEditor for string type', () => { + const onChange = vi.fn() + render( + , + ) + + const editor = screen.getByTestId('prompt-editor') + expect(editor).toBeInTheDocument() + expect(editor).toHaveAttribute('data-has-workflow-vars', 'true') + expect(capturedWorkflowNodesMap).toBeDefined() + expect(capturedWorkflowNodesMap!.node1.title).toBe('LLM Node') + expect(capturedWorkflowNodesMap!.sys.title).toBe('workflow.blocks.start') + expect(capturedWorkflowNodesMap!.sys.type).toBe(BlockEnum.Start) + + promptEditorOnChange?.('updated text') + expect(onChange).toHaveBeenCalledWith('updated text') + }) + + it('should build workflowNodesMap and render PromptEditor for text type', () => { + const onChange = vi.fn() + render( + , + ) + + const editor = screen.getByTestId('prompt-editor') + expect(editor).toBeInTheDocument() + expect(editor).toHaveAttribute('data-has-workflow-vars', 'true') + expect(capturedWorkflowNodesMap).toBeDefined() + + promptEditorOnChange?.('new long text') + expect(onChange).toHaveBeenCalledWith('new long text') + }) + + it('should fall back to plain input when not in workflow mode for string type', () => { + render( + , + ) + + expect(screen.queryByTestId('prompt-editor')).not.toBeInTheDocument() + expect(screen.getByRole('textbox')).toBeInTheDocument() + }) + + it('should return undefined workflowNodesMap when not in workflow mode', () => { + render( + , + ) + + expect(capturedWorkflowNodesMap).toBeUndefined() + }) + }) }) diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx index ac245a4154..6494c81734 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx @@ -9,6 +9,10 @@ import type { } from '../declarations' import type { ParameterValue } from './parameter-item' import type { TriggerProps } from './trigger' +import type { + Node, + NodeOutPutVar, +} from '@/app/components/workflow/types' import { useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { ArrowNarrowLeft } from '@/app/components/base/icons/src/vender/line/arrows' @@ -46,6 +50,8 @@ export type ModelParameterModalProps = { readonly?: boolean isInWorkflow?: boolean scope?: string + nodesOutputVars?: NodeOutPutVar[] + availableNodes?: Node[] } const ModelParameterModal: FC = ({ @@ -62,11 +68,18 @@ const ModelParameterModal: FC = ({ renderTrigger, readonly, isInWorkflow, + nodesOutputVars, + availableNodes, }) => { const { t } = useTranslation() const [open, setOpen] = useState(false) const settingsIconRef = useRef(null) - const { data: parameterRulesData, isLoading } = useModelParameterRules(provider, modelId) + const { + data: parameterRulesData, + isPending, + isLoading, + } = useModelParameterRules(provider, modelId) + const isRulesLoading = isPending || isLoading const { currentProvider, currentModel, @@ -192,7 +205,7 @@ const ModelParameterModal: FC = ({ }
{ - isLoading + isRulesLoading ?
: ( [ @@ -206,6 +219,8 @@ const ModelParameterModal: FC = ({ onChange={v => handleParamChange(parameter.name, v)} onSwitch={(checked, assignValue) => handleSwitch(parameter.name, checked, assignValue)} isInWorkflow={isInWorkflow} + nodesOutputVars={nodesOutputVars} + availableNodes={availableNodes} /> )) ) @@ -214,7 +229,7 @@ const ModelParameterModal: FC = ({ ) } { - !parameterRules.length && isLoading && ( + !parameterRules.length && isRulesLoading && (
) } diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx index 86fb6d81d0..01e3f45371 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx @@ -1,11 +1,18 @@ import type { ModelParameterRule } from '../declarations' -import { useEffect, useRef, useState } from 'react' +import type { + Node, + NodeOutPutVar, +} from '@/app/components/workflow/types' +import { useEffect, useMemo, useRef, useState } from 'react' +import { useTranslation } from 'react-i18next' +import PromptEditor from '@/app/components/base/prompt-editor' import Radio from '@/app/components/base/radio' import Slider from '@/app/components/base/slider' import Switch from '@/app/components/base/switch' import TagInput from '@/app/components/base/tag-input' import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/app/components/base/ui/select' import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip' +import { BlockEnum } from '@/app/components/workflow/types' import { cn } from '@/utils/classnames' import { useLanguage } from '../hooks' import { isNullOrUndefined } from '../utils' @@ -18,18 +25,43 @@ type ParameterItemProps = { onChange?: (value: ParameterValue) => void onSwitch?: (checked: boolean, assignValue: ParameterValue) => void isInWorkflow?: boolean + nodesOutputVars?: NodeOutPutVar[] + availableNodes?: Node[] } + function ParameterItem({ parameterRule, value, onChange, onSwitch, isInWorkflow, + nodesOutputVars, + availableNodes = [], }: ParameterItemProps) { + const { t } = useTranslation() const language = useLanguage() const [localValue, setLocalValue] = useState(value) const numberInputRef = useRef(null) + const workflowNodesMap = useMemo(() => { + if (!isInWorkflow || !availableNodes.length) + return undefined + + return availableNodes.reduce>>((acc, node) => { + acc[node.id] = { + title: node.data.title, + type: node.data.type, + } + if (node.data.type === BlockEnum.Start) { + acc.sys = { + title: t('blocks.start', { ns: 'workflow' }), + type: BlockEnum.Start, + } + } + return acc + }, {}) + }, [availableNodes, isInWorkflow, t]) + const getDefaultValue = () => { let defaultValue: ParameterValue @@ -196,6 +228,25 @@ function ParameterItem({ } if (parameterRule.type === 'string' && !parameterRule.options?.length) { + if (isInWorkflow && nodesOutputVars) { + return ( +
+ { handleInputChange(text) }} + workflowVariableBlock={{ + show: true, + variables: nodesOutputVars, + workflowNodesMap, + }} + editable + /> +
+ ) + } + return ( + { handleInputChange(text) }} + workflowVariableBlock={{ + show: true, + variables: nodesOutputVars, + workflowNodesMap, + }} + editable + /> +
+ ) + } + return (