refactor: consolidate LLM runtime model state on ModelInstance (#32746)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Author: -LAN-
Date: 2026-03-01 02:29:32 +08:00
Committed by: GitHub
Parent: 48d8667c4f
Commit: 962df17a15
20 changed files with 375 additions and 324 deletions
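For orientation before the per-file hunks: the runtime state that _fetch_model_config previously assembled into a ModelConfigWithCredentialsEntity (provider, model name, completion parameters, stop words, schema lookup) now lives directly on the ModelInstance injected into LLMNode, and schema resolution moves into the new llm_utils.fetch_model_schema helper. The sketch below is illustrative only and is not code from this commit; it assumes the attribute names visible in the hunks (model_name, provider, parameters, stop) and an llm_utils module path inferred from the imports in the first file.

# Illustrative sketch (not part of this commit): collecting the state that
# _fetch_model_config used to assemble, read straight off a ModelInstance.
# The llm_utils import path is assumed from the package shown in the first file.
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.workflow.nodes.llm import llm_utils


def describe_runtime_state(model_instance) -> dict:
    schema = llm_utils.fetch_model_schema(model_instance=model_instance)
    return {
        "provider": model_instance.provider,            # was model_config.provider
        "model": model_instance.model_name,             # was model_config.model
        "parameters": dict(model_instance.parameters),  # was node_data_model.completion_params
        "stop": list(model_instance.stop or []),        # was model_config.stop
        "context_size": schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE),
    }

Call sites such as invoke_llm then receive dict(model_instance.parameters) as model_parameters, as the LLMNode hunks below show.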


@@ -5,20 +5,16 @@ from sqlalchemy import select, update
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.enums import SystemVariableKey
from core.workflow.file.models import File
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
@@ -29,46 +25,14 @@ from models.provider_ids import ModelProviderID
from .exc import InvalidVariableTypeError
def fetch_model_config(
*,
node_data_model: ModelConfig,
credentials_provider: CredentialsProvider,
model_factory: ModelFactory,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
if not node_data_model.mode:
raise LLMModeRequiredError("LLM mode is required.")
credentials = credentials_provider.fetch(node_data_model.provider, node_data_model.name)
model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
provider_model_bundle = model_instance.provider_model_bundle
provider_model = provider_model_bundle.configuration.get_provider_model(
model=node_data_model.name,
model_type=ModelType.LLM,
def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema(
model_instance.model_name,
model_instance.credentials,
)
if provider_model is None:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
provider_model.raise_for_status()
stop: list[str] = []
if "stop" in node_data_model.completion_params:
stop = node_data_model.completion_params.pop("stop")
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
if not model_schema:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
return model_instance, ModelConfigWithCredentialsEntity(
provider=node_data_model.provider,
model=node_data_model.name,
model_schema=model_schema,
mode=node_data_model.mode,
provider_model_bundle=provider_model_bundle,
credentials=credentials,
parameters=node_data_model.completion_params,
stop=stop,
)
raise ValueError(f"Model schema not found for {model_instance.model_name}")
return model_schema
def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:

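The LLMNode diff that follows reuses fetch_model_schema for prompt assembly and for token budgeting. As a rough sketch of the budgeting path, mirroring the _calculate_rest_token hunk further down (the closing subtraction falls outside the visible diff context and is an assumption here):

# Rough sketch of the post-refactor rest-token budgeting; names mirror the
# _calculate_rest_token hunk below, and the final subtraction is assumed since
# it is not visible in the diff context.
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.workflow.nodes.llm import llm_utils  # assumed module path


def estimate_rest_tokens(model_instance, prompt_messages) -> int:
    rest_tokens = 2000  # fallback when the model reports no context size
    schema = llm_utils.fetch_model_schema(model_instance=model_instance)
    parameters = model_instance.parameters
    context_size = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
    if context_size:
        prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
        max_tokens = 0
        for rule in schema.parameter_rules:
            if rule.name == "max_tokens" or (rule.use_template and rule.use_template == "max_tokens"):
                max_tokens = parameters.get(rule.name) or parameters.get(str(rule.use_template)) or 0
        rest_tokens = max(context_size - max_tokens - prompt_tokens, 0)  # assumed final step
    return rest_tokens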

@@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import select
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
@@ -38,7 +37,7 @@ from core.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
@@ -83,7 +82,6 @@ from .entities import (
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
ModelConfig,
)
from .exc import (
InvalidContextStructureError,
@@ -116,6 +114,7 @@ class LLMNode(Node[LLMNodeData]):
_llm_file_saver: LLMFileSaver
_credentials_provider: CredentialsProvider
_model_factory: ModelFactory
_model_instance: ModelInstance
def __init__(
self,
@@ -126,6 +125,7 @@ class LLMNode(Node[LLMNodeData]):
*,
credentials_provider: CredentialsProvider,
model_factory: ModelFactory,
model_instance: ModelInstance,
llm_file_saver: LLMFileSaver | None = None,
):
super().__init__(
@@ -139,6 +139,7 @@ class LLMNode(Node[LLMNodeData]):
self._credentials_provider = credentials_provider
self._model_factory = model_factory
self._model_instance = model_instance
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
@@ -202,21 +203,10 @@ class LLMNode(Node[LLMNodeData]):
node_inputs["#context_files#"] = [file.model_dump() for file in context_files]
# fetch model config
model_instance, model_config = self._fetch_model_config(
node_data_model=self.node_data.model,
)
model_name = getattr(model_instance, "model_name", None)
if not isinstance(model_name, str):
model_name = model_config.model
model_provider = getattr(model_instance, "provider", None)
if not isinstance(model_provider, str):
model_provider = model_config.provider
model_schema = model_instance.model_type_instance.get_model_schema(
model_name,
model_instance.credentials,
)
if not model_schema:
raise ValueError(f"Model schema not found for {model_name}")
model_instance = self._model_instance
model_name = model_instance.model_name
model_provider = model_instance.provider
model_stop = model_instance.stop
# fetch memory
memory = llm_utils.fetch_memory(
@@ -240,9 +230,7 @@ class LLMNode(Node[LLMNodeData]):
context=context,
memory=memory,
model_instance=model_instance,
model_schema=model_schema,
model_parameters=self.node_data.model.completion_params,
stop=model_config.stop,
stop=model_stop,
prompt_template=self.node_data.prompt_template,
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
@@ -254,7 +242,6 @@ class LLMNode(Node[LLMNodeData]):
# handle invoke result
generator = LLMNode.invoke_llm(
node_data_model=self.node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
@@ -371,7 +358,6 @@ class LLMNode(Node[LLMNodeData]):
@staticmethod
def invoke_llm(
*,
node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None = None,
@@ -384,11 +370,10 @@ class LLMNode(Node[LLMNodeData]):
node_type: NodeType,
reasoning_format: Literal["separated", "tagged"] = "tagged",
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
model_schema = model_instance.model_type_instance.get_model_schema(
node_data_model.name, model_instance.credentials
)
if not model_schema:
raise ValueError(f"Model schema not found for {node_data_model.name}")
model_parameters = model_instance.parameters
invoke_model_parameters = dict(model_parameters)
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
if structured_output_enabled:
output_schema = LLMNode.fetch_structured_output_schema(
@@ -402,7 +387,7 @@ class LLMNode(Node[LLMNodeData]):
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=output_schema,
model_parameters=node_data_model.completion_params,
model_parameters=invoke_model_parameters,
stop=list(stop or []),
stream=True,
user=user_id,
@@ -412,7 +397,7 @@ class LLMNode(Node[LLMNodeData]):
invoke_result = model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
model_parameters=node_data_model.completion_params,
model_parameters=invoke_model_parameters,
stop=list(stop or []),
stream=True,
user=user_id,
@@ -771,23 +756,6 @@ class LLMNode(Node[LLMNodeData]):
return None
def _fetch_model_config(
self,
*,
node_data_model: ModelConfig,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
model, model_config_with_cred = llm_utils.fetch_model_config(
node_data_model=node_data_model,
credentials_provider=self._credentials_provider,
model_factory=self._model_factory,
)
completion_params = model_config_with_cred.parameters
model_config_with_cred.parameters = completion_params
# NOTE(-LAN-): This line modifies `self.node_data.model`, which is used in `_invoke_llm()`.
node_data_model.completion_params = completion_params
return model, model_config_with_cred
@staticmethod
def fetch_prompt_messages(
*,
@@ -796,8 +764,6 @@ class LLMNode(Node[LLMNodeData]):
context: str | None = None,
memory: TokenBufferMemory | None = None,
model_instance: ModelInstance,
model_schema: AIModelEntity,
model_parameters: Mapping[str, Any],
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
stop: Sequence[str] | None = None,
memory_config: MemoryConfig | None = None,
@@ -808,6 +774,7 @@ class LLMNode(Node[LLMNodeData]):
context_files: list[File] | None = None,
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
if isinstance(prompt_template, list):
# For chat model
@@ -826,8 +793,6 @@ class LLMNode(Node[LLMNodeData]):
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
model_schema=model_schema,
model_parameters=model_parameters,
)
# Extend prompt_messages with memory messages
prompt_messages.extend(memory_messages)
@@ -865,8 +830,6 @@ class LLMNode(Node[LLMNodeData]):
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
model_schema=model_schema,
model_parameters=model_parameters,
)
# Insert histories into the prompt
prompt_content = prompt_messages[0].content
@@ -1316,23 +1279,23 @@ def _calculate_rest_token(
*,
prompt_messages: list[PromptMessage],
model_instance: ModelInstance,
model_schema: AIModelEntity,
model_parameters: Mapping[str, Any],
) -> int:
rest_tokens = 2000
runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
runtime_model_parameters = model_instance.parameters
model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in model_schema.parameter_rules:
for parameter_rule in runtime_model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_parameters.get(parameter_rule.name)
or model_parameters.get(str(parameter_rule.use_template))
runtime_model_parameters.get(parameter_rule.name)
or runtime_model_parameters.get(str(parameter_rule.use_template))
or 0
)
@@ -1347,8 +1310,6 @@ def _handle_memory_chat_mode(
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_instance: ModelInstance,
model_schema: AIModelEntity,
model_parameters: Mapping[str, Any],
) -> Sequence[PromptMessage]:
memory_messages: Sequence[PromptMessage] = []
# Get messages from memory for chat model
@@ -1356,8 +1317,6 @@
rest_tokens = _calculate_rest_token(
prompt_messages=[],
model_instance=model_instance,
model_schema=model_schema,
model_parameters=model_parameters,
)
memory_messages = memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
@@ -1371,8 +1330,6 @@ def _handle_memory_completion_mode(
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_instance: ModelInstance,
model_schema: AIModelEntity,
model_parameters: Mapping[str, Any],
) -> str:
memory_text = ""
# Get history text from memory for completion model
@@ -1380,8 +1337,6 @@
rest_tokens = _calculate_rest_token(
prompt_messages=[],
model_instance=model_instance,
model_schema=model_schema,
model_parameters=model_parameters,
)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")