refactor(workflow): inject credential/model access ports into LLM nodes (#32569)

Author: -LAN- <laipz8200@outlook.com>
Date: 2026-02-27 14:36:41 +08:00 (committed by GitHub)
Parent: d20880d102
Commit: a694533fc9
Signed-off-by: -LAN- <laipz8200@outlook.com>

38 changed files with 676 additions and 179 deletions


@@ -8,7 +8,7 @@ from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -17,6 +17,8 @@ from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegme
from core.workflow.enums import SystemVariableKey
from core.workflow.file.models import File
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
@@ -24,49 +26,46 @@ from models.model import Conversation
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID
from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError
from .exc import InvalidVariableTypeError
def fetch_model_config(
tenant_id: str, node_data_model: ModelConfig
*,
node_data_model: ModelConfig,
credentials_provider: CredentialsProvider,
model_factory: ModelFactory,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
if not node_data_model.mode:
raise LLMModeRequiredError("LLM mode is required.")
model = ModelManager().get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=node_data_model.provider,
credentials = credentials_provider.fetch(node_data_model.provider, node_data_model.name)
model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
provider_model_bundle = model_instance.provider_model_bundle
provider_model = provider_model_bundle.configuration.get_provider_model(
model=node_data_model.name,
model_type=ModelType.LLM,
)
model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
# check model
provider_model = model.provider_model_bundle.configuration.get_provider_model(
model=node_data_model.name, model_type=ModelType.LLM
)
if provider_model is None:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
provider_model.raise_for_status()
# model config
stop: list[str] = []
if "stop" in node_data_model.completion_params:
stop = node_data_model.completion_params.pop("stop")
model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
if not model_schema:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
return model, ModelConfigWithCredentialsEntity(
model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
return model_instance, ModelConfigWithCredentialsEntity(
provider=node_data_model.provider,
model=node_data_model.name,
model_schema=model_schema,
mode=node_data_model.mode,
provider_model_bundle=model.provider_model_bundle,
credentials=model.credentials,
provider_model_bundle=provider_model_bundle,
credentials=credentials,
parameters=node_data_model.completion_params,
stop=stop,
)
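
The new signature is the core of the refactor: fetch_model_config no longer builds a ModelManager itself; its collaborators arrive as structural typing.Protocol ports. A dependency-free sketch of why any conforming object plugs in — the stub class is illustrative, not Dify code:

from typing import Any, Protocol


class CredentialsProvider(Protocol):
    def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: ...


class StaticCredentials:
    """Never subclasses the Protocol; matching the method shape is enough."""

    def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
        return {"api_key": "sk-test"}


def needs_credentials(provider: CredentialsProvider) -> dict[str, Any]:
    # Any object with a matching .fetch() satisfies the port structurally.
    return provider.fetch("openai", "gpt-4o")


print(needs_credentials(StaticCredentials()))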
@@ -131,7 +130,7 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
if quota_unit == QuotaUnit.TOKENS:
used_quota = usage.total_tokens
elif quota_unit == QuotaUnit.CREDITS:
used_quota = dify_config.get_model_credits(model_instance.model)
used_quota = dify_config.get_model_credits(model_instance.model_name)
else:
used_quota = 1
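
The model -> model_name rename above only matters in the credits branch of deduct_llm_quota. A self-contained restatement of that branching, with an assumed in-memory table standing in for dify_config.get_model_credits():

from enum import Enum


class QuotaUnit(Enum):  # members mirror the diff; the values are illustrative
    TOKENS = "tokens"
    CREDITS = "credits"
    TIMES = "times"


MODEL_CREDITS = {"gpt-4o": 10, "gpt-4o-mini": 1}  # hypothetical credits table


def used_quota(quota_unit: QuotaUnit, total_tokens: int, model_name: str) -> int:
    if quota_unit is QuotaUnit.TOKENS:
        return total_tokens
    if quota_unit is QuotaUnit.CREDITS:
        return MODEL_CREDITS.get(model_name, 1)
    return 1  # one deduction per call for time-based quotas


assert used_quota(QuotaUnit.TOKENS, 1234, "gpt-4o") == 1234
assert used_quota(QuotaUnit.CREDITS, 1234, "gpt-4o") == 10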


@@ -16,7 +16,7 @@ from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
ImagePromptMessageContent,
PromptMessage,
@@ -38,11 +38,7 @@ from core.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
ModelFeature,
ModelPropertyKey,
ModelType,
)
from core.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
@@ -76,6 +72,7 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from models.dataset import SegmentAttachmentBinding
@@ -93,7 +90,6 @@ from .exc import (
InvalidVariableTypeError,
LLMNodeError,
MemoryRolePrefixRequiredError,
ModelNotExistError,
NoPromptFoundError,
TemplateTypeNotSupportError,
VariableNotFoundError,
@@ -118,6 +114,8 @@ class LLMNode(Node[LLMNodeData]):
_file_outputs: list[File]
_llm_file_saver: LLMFileSaver
_credentials_provider: CredentialsProvider
_model_factory: ModelFactory
def __init__(
self,
@@ -126,6 +124,8 @@ class LLMNode(Node[LLMNodeData]):
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
*,
credentials_provider: CredentialsProvider,
model_factory: ModelFactory,
llm_file_saver: LLMFileSaver | None = None,
):
super().__init__(
@@ -137,6 +137,9 @@ class LLMNode(Node[LLMNodeData]):
# LLM file outputs, used for MultiModal outputs.
self._file_outputs = []
self._credentials_provider = credentials_provider
self._model_factory = model_factory
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
user_id=graph_init_params.user_id,
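
The constructor change is plain keyword-only dependency injection: both ports are mandatory, while the file saver keeps its optional default. A runnable miniature of the same shape (attribute names mirror the diff; everything else is illustrative):

from typing import Any


class MiniLLMNode:
    """Illustrative stand-in for LLMNode's new constructor."""

    def __init__(
        self,
        id: str,
        *,
        credentials_provider: Any,  # expects .fetch(provider, model) -> dict
        model_factory: Any,  # expects .init_model_instance(provider, model)
        llm_file_saver: Any | None = None,
    ) -> None:
        self.id = id
        self._credentials_provider = credentials_provider
        self._model_factory = model_factory
        # The real node builds a FileSaverImpl here when none is injected.
        self._llm_file_saver = llm_file_saver


node = MiniLLMNode("llm-1", credentials_provider=object(), model_factory=object())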
@@ -199,10 +202,21 @@ class LLMNode(Node[LLMNodeData]):
node_inputs["#context_files#"] = [file.model_dump() for file in context_files]
# fetch model config
model_instance, model_config = LLMNode._fetch_model_config(
model_instance, model_config = self._fetch_model_config(
node_data_model=self.node_data.model,
tenant_id=self.tenant_id,
)
model_name = getattr(model_instance, "model_name", None)
if not isinstance(model_name, str):
model_name = model_config.model
model_provider = getattr(model_instance, "provider", None)
if not isinstance(model_provider, str):
model_provider = model_config.provider
model_schema = model_instance.model_type_instance.get_model_schema(
model_name,
model_instance.credentials,
)
if not model_schema:
raise ValueError(f"Model schema not found for {model_name}")
# fetch memory
memory = llm_utils.fetch_memory(
@@ -225,14 +239,16 @@ class LLMNode(Node[LLMNodeData]):
sys_files=files,
context=context,
memory=memory,
model_config=model_config,
model_instance=model_instance,
model_schema=model_schema,
model_parameters=self.node_data.model.completion_params,
stop=model_config.stop,
prompt_template=self.node_data.prompt_template,
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
context_files=context_files,
)
@@ -286,14 +302,14 @@ class LLMNode(Node[LLMNodeData]):
structured_output = event
process_data = {
"model_mode": model_config.mode,
"model_mode": self.node_data.model.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_config.mode, prompt_messages=prompt_messages
model_mode=self.node_data.model.mode, prompt_messages=prompt_messages
),
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
"model_provider": model_config.provider,
"model_name": model_config.model,
"model_provider": model_provider,
"model_name": model_name,
}
outputs = {
@@ -755,21 +771,18 @@ class LLMNode(Node[LLMNodeData]):
return None
@staticmethod
def _fetch_model_config(
self,
*,
node_data_model: ModelConfig,
tenant_id: str,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
model, model_config_with_cred = llm_utils.fetch_model_config(
tenant_id=tenant_id, node_data_model=node_data_model
node_data_model=node_data_model,
credentials_provider=self._credentials_provider,
model_factory=self._model_factory,
)
completion_params = model_config_with_cred.parameters
model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
if not model_schema:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
model_config_with_cred.parameters = completion_params
# NOTE(-LAN-): This line modifies `self.node_data.model`, which is used in `_invoke_llm()`.
node_data_model.completion_params = completion_params
@@ -782,14 +795,16 @@ class LLMNode(Node[LLMNodeData]):
sys_files: Sequence[File],
context: str | None = None,
memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance,
model_schema: AIModelEntity,
model_parameters: Mapping[str, Any],
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
stop: Sequence[str] | None = None,
memory_config: MemoryConfig | None = None,
vision_enabled: bool = False,
vision_detail: ImagePromptMessageContent.DETAIL,
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
tenant_id: str,
context_files: list[File] | None = None,
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
@@ -810,7 +825,9 @@ class LLMNode(Node[LLMNodeData]):
memory_messages = _handle_memory_chat_mode(
memory=memory,
memory_config=memory_config,
model_config=model_config,
model_instance=model_instance,
model_schema=model_schema,
model_parameters=model_parameters,
)
# Extend prompt_messages with memory messages
prompt_messages.extend(memory_messages)
@@ -847,7 +864,9 @@ class LLMNode(Node[LLMNodeData]):
memory_text = _handle_memory_completion_mode(
memory=memory,
memory_config=memory_config,
model_config=model_config,
model_instance=model_instance,
model_schema=model_schema,
model_parameters=model_parameters,
)
# Insert histories into the prompt
prompt_content = prompt_messages[0].content
@@ -924,7 +943,7 @@ class LLMNode(Node[LLMNodeData]):
prompt_message_content: list[PromptMessageContentUnionTypes] = []
for content_item in prompt_message.content:
# Skip content if features are not defined
if not model_config.model_schema.features:
if not model_schema.features:
if content_item.type != PromptMessageContentType.TEXT:
continue
prompt_message_content.append(content_item)
@@ -934,19 +953,19 @@ class LLMNode(Node[LLMNodeData]):
if (
(
content_item.type == PromptMessageContentType.IMAGE
and ModelFeature.VISION not in model_config.model_schema.features
and ModelFeature.VISION not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.DOCUMENT
and ModelFeature.DOCUMENT not in model_config.model_schema.features
and ModelFeature.DOCUMENT not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.VIDEO
and ModelFeature.VIDEO not in model_config.model_schema.features
and ModelFeature.VIDEO not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.AUDIO
and ModelFeature.AUDIO not in model_config.model_schema.features
and ModelFeature.AUDIO not in model_schema.features
)
):
continue
@@ -965,19 +984,7 @@ class LLMNode(Node[LLMNodeData]):
"Please ensure a prompt is properly configured before proceeding."
)
model = ModelManager().get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.provider,
model=model_config.model,
)
model_schema = model.model_type_instance.get_model_schema(
model=model_config.model,
credentials=model.credentials,
)
if not model_schema:
raise ModelNotExistError(f"Model {model_config.model} not exist.")
return filtered_prompt_messages, model_config.stop
return filtered_prompt_messages, stop
@classmethod
def _extract_variable_selector_to_variable_mapping(
@@ -1306,26 +1313,26 @@ def _render_jinja2_message(
def _calculate_rest_token(
*, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
*,
prompt_messages: list[PromptMessage],
model_instance: ModelInstance,
model_schema: AIModelEntity,
model_parameters: Mapping[str, Any],
) -> int:
rest_tokens = 2000
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
for parameter_rule in model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(str(parameter_rule.use_template))
model_parameters.get(parameter_rule.name)
or model_parameters.get(str(parameter_rule.use_template))
or 0
)
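
The token budget is now computed from the passed-in schema and raw parameters instead of a ModelConfigWithCredentialsEntity. The subtraction itself falls outside this hunk; assuming the usual budget formula (context size minus reserved completion tokens minus current prompt tokens, clamped at zero, with 2000 as the fallback), a worked sketch:

def rest_tokens(context_size: int | None, max_tokens: int, current_tokens: int) -> int:
    if context_size is None:
        return 2000  # fallback when the schema lacks a CONTEXT_SIZE property
    # Assumed formula; the actual subtraction sits outside the diff hunk above.
    return max(context_size - max_tokens - current_tokens, 0)


assert rest_tokens(None, 1024, 500) == 2000
assert rest_tokens(8192, 1024, 500) == 6668  # 8192 - 1024 - 500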
@@ -1339,12 +1346,19 @@ def _handle_memory_chat_mode(
*,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance,
model_schema: AIModelEntity,
model_parameters: Mapping[str, Any],
) -> Sequence[PromptMessage]:
memory_messages: Sequence[PromptMessage] = []
# Get messages from memory for chat model
if memory and memory_config:
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
rest_tokens = _calculate_rest_token(
prompt_messages=[],
model_instance=model_instance,
model_schema=model_schema,
model_parameters=model_parameters,
)
memory_messages = memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
@@ -1356,12 +1370,19 @@ def _handle_memory_completion_mode(
*,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance,
model_schema: AIModelEntity,
model_parameters: Mapping[str, Any],
) -> str:
memory_text = ""
# Get history text from memory for completion model
if memory and memory_config:
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
rest_tokens = _calculate_rest_token(
prompt_messages=[],
model_instance=model_instance,
model_schema=model_schema,
model_parameters=model_parameters,
)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
memory_text = memory.get_history_prompt_text(


@@ -0,0 +1,21 @@
from __future__ import annotations
from typing import Any, Protocol
from core.model_manager import ModelInstance
class CredentialsProvider(Protocol):
"""Port for loading runtime credentials for a provider/model pair."""
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
"""Return credentials for the target provider/model or raise a domain error."""
...
class ModelFactory(Protocol):
"""Port for creating initialized LLM model instances for execution."""
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
"""Create a model instance that is ready for schema lookup and invocation."""
...
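
Because both ports are structural, test doubles need no Dify imports at all. Two illustrative doubles that satisfy these protocols (the sentinel object() stands in for a real ModelInstance):

from typing import Any


class StaticCredentials:
    """Satisfies CredentialsProvider: fixed credentials for every model."""

    def __init__(self, credentials: dict[str, Any]) -> None:
        self._credentials = credentials

    def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
        return dict(self._credentials)  # hand out a defensive copy


class RecordingFactory:
    """Satisfies ModelFactory: records requests, returns a sentinel instance."""

    def __init__(self) -> None:
        self.requests: list[tuple[str, str]] = []

    def init_model_instance(self, provider_name: str, model_name: str) -> Any:
        self.requests.append((provider_name, model_name))
        return object()


factory = RecordingFactory()
factory.init_model_instance("openai", "gpt-4o")
assert factory.requests == [("openai", "gpt-4o")]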


@@ -3,7 +3,7 @@ import json
import logging
import uuid
from collections.abc import Mapping, Sequence
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -60,6 +60,11 @@ from .prompts import (
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import GraphRuntimeState
def extract_json(text):
"""
@@ -92,6 +97,27 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
_model_instance: ModelInstance | None = None
_model_config: ModelConfigWithCredentialsEntity | None = None
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._credentials_provider = credentials_provider
self._model_factory = model_factory
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@@ -806,7 +832,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
"""
if not self._model_instance or not self._model_config:
self._model_instance, self._model_config = llm_utils.fetch_model_config(
tenant_id=self.tenant_id, node_data_model=node_data_model
node_data_model=node_data_model,
credentials_provider=self._credentials_provider,
model_factory=self._model_factory,
)
return self._model_instance, self._model_config
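
Note the cache-on-first-use guard above: the extractor resolves the (instance, config) pair once and reuses it across invoke rounds. The same shape in isolation (illustrative, not Dify code):

from collections.abc import Callable
from typing import Any


class LazyPair:
    """Memoizes a two-tuple, mirroring _fetch_model_config's guard."""

    def __init__(self, fetch: Callable[[], tuple[Any, Any]]) -> None:
        self._fetch = fetch
        self._pair: tuple[Any, Any] | None = None

    def get(self) -> tuple[Any, Any]:
        if self._pair is None:
            self._pair = self._fetch()  # only the first call hits the ports
        return self._pair


calls: list[int] = []


def expensive_fetch() -> tuple[Any, Any]:
    calls.append(1)
    return ("instance", "config")


pair = LazyPair(expensive_fetch)
pair.get()
pair.get()
assert calls == [1]  # fetched once, cached afterwards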


@@ -24,6 +24,7 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from libs.json_in_md_parser import parse_and_check_json_markdown
from .entities import QuestionClassifierNodeData
@@ -49,6 +50,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
def __init__(
self,
@@ -57,6 +60,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
llm_file_saver: LLMFileSaver | None = None,
):
super().__init__(
@@ -68,6 +73,9 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
# LLM file outputs, used for MultiModal outputs.
self._file_outputs = []
self._credentials_provider = credentials_provider
self._model_factory = model_factory
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
user_id=graph_init_params.user_id,
@@ -89,9 +97,16 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
variables = {"query": query}
# fetch model config
model_instance, model_config = llm_utils.fetch_model_config(
tenant_id=self.tenant_id,
node_data_model=node_data.model,
credentials_provider=self._credentials_provider,
model_factory=self._model_factory,
)
model_schema = model_instance.model_type_instance.get_model_schema(
model_instance.model_name,
model_instance.credentials,
)
if not model_schema:
raise ValueError(f"Model schema not found for {model_instance.model_name}")
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
@@ -133,13 +148,15 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
prompt_template=prompt_template,
sys_query="",
memory=memory,
model_config=model_config,
model_instance=model_instance,
model_schema=model_schema,
model_parameters=node_data.model.completion_params,
stop=model_config.stop,
sys_files=files,
vision_enabled=node_data.vision.enabled,
vision_detail=node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=[],
tenant_id=self.tenant_id,
)
result_text = ""
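
As in LLMNode, the classifier now resolves the model schema eagerly and fails fast when the lookup returns nothing. A dependency-free sketch of that guard (FakeLLM is a hypothetical runtime, not a Dify class):

from typing import Any


def resolve_schema(model_type_instance: Any, model_name: str, credentials: dict[str, Any]) -> Any:
    schema = model_type_instance.get_model_schema(model_name, credentials)
    if not schema:
        raise ValueError(f"Model schema not found for {model_name}")
    return schema


class FakeLLM:
    """Hypothetical model runtime that only knows a single model."""

    def get_model_schema(self, model_name: str, credentials: dict[str, Any]) -> Any:
        return {"model": model_name} if model_name == "gpt-4o" else None


assert resolve_schema(FakeLLM(), "gpt-4o", {})["model"] == "gpt-4o"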