Mirror of https://github.com/langgenius/dify.git (synced 2026-05-05 09:58:04 +08:00)
refactor: consolidate LLM runtime model state on ModelInstance (#32746)
Signed-off-by: -LAN- <laipz8200@outlook.com>
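At a high level, the change moves runtime model state — name, provider, credentials, completion parameters, stop sequences — onto the ModelInstance the node already holds, instead of re-deriving it on every run from node data plus a separate ModelConfigWithCredentialsEntity. A minimal, self-contained sketch of that shape (the class below is an illustrative stand-in, not dify's actual ModelInstance):

from dataclasses import dataclass, field


@dataclass
class ConsolidatedModelInstance:
    """Illustrative stand-in: one object owns all runtime LLM state."""
    model_name: str
    provider: str
    credentials: dict = field(default_factory=dict)
    parameters: dict = field(default_factory=dict)  # completion params
    stop: list = field(default_factory=list)        # stop sequences


# A node holding this instance reads everything from one place:
instance = ConsolidatedModelInstance(
    model_name="gpt-4o",
    provider="openai",
    parameters={"temperature": 0.7},
    stop=["Human:"],
)
print(instance.model_name, instance.provider, instance.stop)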
@@ -5,20 +5,16 @@ from sqlalchemy import select, update
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.enums import SystemVariableKey
from core.workflow.file.models import File
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
@@ -29,46 +25,14 @@ from models.provider_ids import ModelProviderID

from .exc import InvalidVariableTypeError


-def fetch_model_config(
-    *,
-    node_data_model: ModelConfig,
-    credentials_provider: CredentialsProvider,
-    model_factory: ModelFactory,
-) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
-    if not node_data_model.mode:
-        raise LLMModeRequiredError("LLM mode is required.")
-
-    credentials = credentials_provider.fetch(node_data_model.provider, node_data_model.name)
-    model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
-    provider_model_bundle = model_instance.provider_model_bundle
-
-    provider_model = provider_model_bundle.configuration.get_provider_model(
-        model=node_data_model.name,
-        model_type=ModelType.LLM,
-    )
-    if provider_model is None:
-        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
-    provider_model.raise_for_status()
-
-    stop: list[str] = []
-    if "stop" in node_data_model.completion_params:
-        stop = node_data_model.completion_params.pop("stop")
-
-    model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
-    if not model_schema:
-        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
-
-    model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
-    return model_instance, ModelConfigWithCredentialsEntity(
-        provider=node_data_model.provider,
-        model=node_data_model.name,
-        model_schema=model_schema,
-        mode=node_data_model.mode,
-        provider_model_bundle=provider_model_bundle,
-        credentials=credentials,
-        parameters=node_data_model.completion_params,
-        stop=stop,
-    )
+def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
+    model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema(
+        model_instance.model_name,
+        model_instance.credentials,
+    )
+    if not model_schema:
+        raise ValueError(f"Model schema not found for {model_instance.model_name}")
+    return model_schema


def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:
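The new fetch_model_schema helper above resolves the schema purely from state the instance carries, so callers no longer thread the model name and credentials in separately. A runnable toy version of the same lookup pattern, with invented types and a fake schema registry standing in for dify's model_type_instance:

from dataclasses import dataclass, field


@dataclass
class FakeModelInstance:
    """Stand-in for ModelInstance: runtime state lives on the instance."""
    model_name: str
    credentials: dict = field(default_factory=dict)
    # Toy registry in place of model_type_instance.get_model_schema().
    schemas: dict = field(default_factory=lambda: {"gpt-4o": {"context_size": 128000}})

    def get_model_schema(self, model, credentials):
        return self.schemas.get(model)


def fetch_model_schema(*, model_instance: FakeModelInstance) -> dict:
    # Mirrors the helper: resolve the schema from state already on the instance.
    model_schema = model_instance.get_model_schema(
        model_instance.model_name,
        model_instance.credentials,
    )
    if not model_schema:
        raise ValueError(f"Model schema not found for {model_instance.model_name}")
    return model_schema


print(fetch_model_schema(model_instance=FakeModelInstance(model_name="gpt-4o")))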
@@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Any, Literal

from sqlalchemy import select

-from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
@@ -38,7 +37,7 @@ from core.model_runtime.entities.message_entities import (
    SystemPromptMessage,
    UserPromptMessage,
)
-from core.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey
+from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
@@ -83,7 +82,6 @@ from .entities import (
    LLMNodeChatModelMessage,
    LLMNodeCompletionModelPromptTemplate,
    LLMNodeData,
-    ModelConfig,
)
from .exc import (
    InvalidContextStructureError,
@@ -116,6 +114,7 @@ class LLMNode(Node[LLMNodeData]):
    _llm_file_saver: LLMFileSaver
    _credentials_provider: CredentialsProvider
    _model_factory: ModelFactory
+    _model_instance: ModelInstance

    def __init__(
        self,
@@ -126,6 +125,7 @@ class LLMNode(Node[LLMNodeData]):
        *,
        credentials_provider: CredentialsProvider,
        model_factory: ModelFactory,
+        model_instance: ModelInstance,
        llm_file_saver: LLMFileSaver | None = None,
    ):
        super().__init__(
@@ -139,6 +139,7 @@ class LLMNode(Node[LLMNodeData]):

        self._credentials_provider = credentials_provider
        self._model_factory = model_factory
+        self._model_instance = model_instance

        if llm_file_saver is None:
            llm_file_saver = FileSaverImpl(
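Construction-side, the pattern is plain dependency injection (sketched below with stand-in names): the caller builds the ModelInstance once and hands it to the node, which stores it for use during the run.

class FakeLLMNode:
    """Stand-in for LLMNode: collaborators are injected, not re-derived per run."""

    def __init__(self, *, credentials_provider, model_factory, model_instance):
        self._credentials_provider = credentials_provider
        self._model_factory = model_factory
        # The instance now carries name, provider, credentials, parameters, stop.
        self._model_instance = model_instance

    def run(self):
        # Runtime state is read straight off the injected instance.
        return self._model_instance.model_name, self._model_instance.provider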
@@ -202,21 +203,10 @@ class LLMNode(Node[LLMNodeData]):
            node_inputs["#context_files#"] = [file.model_dump() for file in context_files]

        # fetch model config
-        model_instance, model_config = self._fetch_model_config(
-            node_data_model=self.node_data.model,
-        )
-        model_name = getattr(model_instance, "model_name", None)
-        if not isinstance(model_name, str):
-            model_name = model_config.model
-        model_provider = getattr(model_instance, "provider", None)
-        if not isinstance(model_provider, str):
-            model_provider = model_config.provider
-        model_schema = model_instance.model_type_instance.get_model_schema(
-            model_name,
-            model_instance.credentials,
-        )
-        if not model_schema:
-            raise ValueError(f"Model schema not found for {model_name}")
+        model_instance = self._model_instance
+        model_name = model_instance.model_name
+        model_provider = model_instance.provider
+        model_stop = model_instance.stop

        # fetch memory
        memory = llm_utils.fetch_memory(
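The _run change above, in miniature (illustrative names only): the defensive getattr/isinstance probing disappears because the injected instance is now guaranteed to carry these fields.

from dataclasses import dataclass, field


@dataclass
class FakeModelInstance:
    """Stand-in: the refactor guarantees these fields exist on the instance."""
    model_name: str
    provider: str
    parameters: dict = field(default_factory=dict)
    stop: list = field(default_factory=list)


def run_step(model_instance: FakeModelInstance):
    # After the refactor: plain attribute reads, no getattr() fallbacks needed.
    return model_instance.model_name, model_instance.provider, model_instance.stop


print(run_step(FakeModelInstance(model_name="gpt-4o", provider="openai", stop=["\n\n"])))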
@@ -240,9 +230,7 @@ class LLMNode(Node[LLMNodeData]):
            context=context,
            memory=memory,
            model_instance=model_instance,
-            model_schema=model_schema,
-            model_parameters=self.node_data.model.completion_params,
-            stop=model_config.stop,
+            stop=model_stop,
            prompt_template=self.node_data.prompt_template,
            memory_config=self.node_data.memory,
            vision_enabled=self.node_data.vision.enabled,
@@ -254,7 +242,6 @@ class LLMNode(Node[LLMNodeData]):

        # handle invoke result
        generator = LLMNode.invoke_llm(
-            node_data_model=self.node_data.model,
            model_instance=model_instance,
            prompt_messages=prompt_messages,
            stop=stop,
@@ -371,7 +358,6 @@ class LLMNode(Node[LLMNodeData]):
    @staticmethod
    def invoke_llm(
        *,
-        node_data_model: ModelConfig,
        model_instance: ModelInstance,
        prompt_messages: Sequence[PromptMessage],
        stop: Sequence[str] | None = None,
@@ -384,11 +370,10 @@ class LLMNode(Node[LLMNodeData]):
        node_type: NodeType,
        reasoning_format: Literal["separated", "tagged"] = "tagged",
    ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
-        model_schema = model_instance.model_type_instance.get_model_schema(
-            node_data_model.name, model_instance.credentials
-        )
-        if not model_schema:
-            raise ValueError(f"Model schema not found for {node_data_model.name}")
+        model_parameters = model_instance.parameters
+        invoke_model_parameters = dict(model_parameters)
+
+        model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)

        if structured_output_enabled:
            output_schema = LLMNode.fetch_structured_output_schema(
@@ -402,7 +387,7 @@ class LLMNode(Node[LLMNodeData]):
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                json_schema=output_schema,
-                model_parameters=node_data_model.completion_params,
+                model_parameters=invoke_model_parameters,
                stop=list(stop or []),
                stream=True,
                user=user_id,
@@ -412,7 +397,7 @@ class LLMNode(Node[LLMNodeData]):

        invoke_result = model_instance.invoke_llm(
            prompt_messages=list(prompt_messages),
-            model_parameters=node_data_model.completion_params,
+            model_parameters=invoke_model_parameters,
            stop=list(stop or []),
            stream=True,
            user=user_id,
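A small runnable sketch of the parameter-handling detail above (invented names): invoke_model_parameters = dict(model_parameters) takes a per-call copy, so anything a call site mutates cannot leak back into the shared instance state.

class FakeModelInstance:
    def __init__(self, parameters):
        self.parameters = parameters  # shared, long-lived runtime state

    def invoke_llm(self, *, model_parameters, **kwargs):
        # A real call would hit the provider; here we just echo the params.
        return model_parameters


instance = FakeModelInstance({"temperature": 0.7, "max_tokens": 512})

# Copy before invoking, mirroring invoke_model_parameters = dict(model_parameters).
invoke_model_parameters = dict(instance.parameters)
invoke_model_parameters.pop("max_tokens")  # per-call tweak

instance.invoke_llm(model_parameters=invoke_model_parameters)
assert instance.parameters == {"temperature": 0.7, "max_tokens": 512}  # untouched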
@@ -771,23 +756,6 @@ class LLMNode(Node[LLMNodeData]):

        return None

-    def _fetch_model_config(
-        self,
-        *,
-        node_data_model: ModelConfig,
-    ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
-        model, model_config_with_cred = llm_utils.fetch_model_config(
-            node_data_model=node_data_model,
-            credentials_provider=self._credentials_provider,
-            model_factory=self._model_factory,
-        )
-        completion_params = model_config_with_cred.parameters
-
-        model_config_with_cred.parameters = completion_params
-        # NOTE(-LAN-): This line modify the `self.node_data.model`, which is used in `_invoke_llm()`.
-        node_data_model.completion_params = completion_params
-        return model, model_config_with_cred
-
    @staticmethod
    def fetch_prompt_messages(
        *,
@@ -796,8 +764,6 @@ class LLMNode(Node[LLMNodeData]):
        context: str | None = None,
        memory: TokenBufferMemory | None = None,
        model_instance: ModelInstance,
-        model_schema: AIModelEntity,
-        model_parameters: Mapping[str, Any],
        prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
        stop: Sequence[str] | None = None,
        memory_config: MemoryConfig | None = None,
@@ -808,6 +774,7 @@ class LLMNode(Node[LLMNodeData]):
        context_files: list[File] | None = None,
    ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
        prompt_messages: list[PromptMessage] = []
+        model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)

        if isinstance(prompt_template, list):
            # For chat model
@@ -826,8 +793,6 @@ class LLMNode(Node[LLMNodeData]):
                memory=memory,
                memory_config=memory_config,
                model_instance=model_instance,
-                model_schema=model_schema,
-                model_parameters=model_parameters,
            )
            # Extend prompt_messages with memory messages
            prompt_messages.extend(memory_messages)
@@ -865,8 +830,6 @@ class LLMNode(Node[LLMNodeData]):
                memory=memory,
                memory_config=memory_config,
                model_instance=model_instance,
-                model_schema=model_schema,
-                model_parameters=model_parameters,
            )
            # Insert histories into the prompt
            prompt_content = prompt_messages[0].content
@@ -1316,23 +1279,23 @@ def _calculate_rest_token(
    *,
    prompt_messages: list[PromptMessage],
    model_instance: ModelInstance,
-    model_schema: AIModelEntity,
-    model_parameters: Mapping[str, Any],
) -> int:
    rest_tokens = 2000
+    runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
+    runtime_model_parameters = model_instance.parameters

-    model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+    model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
    if model_context_tokens:
        curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)

        max_tokens = 0
-        for parameter_rule in model_schema.parameter_rules:
+        for parameter_rule in runtime_model_schema.parameter_rules:
            if parameter_rule.name == "max_tokens" or (
                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
            ):
                max_tokens = (
-                    model_parameters.get(parameter_rule.name)
-                    or model_parameters.get(str(parameter_rule.use_template))
+                    runtime_model_parameters.get(parameter_rule.name)
+                    or runtime_model_parameters.get(str(parameter_rule.use_template))
                    or 0
                )
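The token-budget arithmetic in _calculate_rest_token, reduced to a runnable toy. The hunk ends before the function's tail, so the final subtraction and the floor at zero are assumptions inferred from the variables it gathers; the 2000 fallback is taken directly from the diff.

def calculate_rest_tokens(*, context_size, max_tokens, prompt_tokens):
    # Fallback mirrors rest_tokens = 2000 when the schema lacks CONTEXT_SIZE.
    rest_tokens = 2000
    if context_size:
        # Assumed tail: budget = context window minus reserved output
        # tokens minus tokens already consumed by the prompt, floored at 0.
        rest_tokens = max(context_size - max_tokens - prompt_tokens, 0)
    return rest_tokens


# 128k-context model, 512 reserved for output, 1200 already used by the prompt:
print(calculate_rest_tokens(context_size=128000, max_tokens=512, prompt_tokens=1200))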
@@ -1347,8 +1310,6 @@ def _handle_memory_chat_mode(
    memory: TokenBufferMemory | None,
    memory_config: MemoryConfig | None,
    model_instance: ModelInstance,
-    model_schema: AIModelEntity,
-    model_parameters: Mapping[str, Any],
) -> Sequence[PromptMessage]:
    memory_messages: Sequence[PromptMessage] = []
    # Get messages from memory for chat model
@@ -1356,8 +1317,6 @@ def _handle_memory_chat_mode(
    rest_tokens = _calculate_rest_token(
        prompt_messages=[],
        model_instance=model_instance,
-        model_schema=model_schema,
-        model_parameters=model_parameters,
    )
    memory_messages = memory.get_history_prompt_messages(
        max_token_limit=rest_tokens,
@@ -1371,8 +1330,6 @@ def _handle_memory_completion_mode(
    memory: TokenBufferMemory | None,
    memory_config: MemoryConfig | None,
    model_instance: ModelInstance,
-    model_schema: AIModelEntity,
-    model_parameters: Mapping[str, Any],
) -> str:
    memory_text = ""
    # Get history text from memory for completion model
@@ -1380,8 +1337,6 @@ def _handle_memory_completion_mode(
    rest_tokens = _calculate_rest_token(
        prompt_messages=[],
        model_instance=model_instance,
-        model_schema=model_schema,
-        model_parameters=model_parameters,
    )
    if not memory_config.role_prefix:
        raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")