Mirror of https://github.com/langgenius/dify.git (synced 2026-05-04 17:38:04 +08:00)
refactor(workflow): inject credential/model access ports into LLM nodes (#32569)
Signed-off-by: -LAN- <laipz8200@outlook.com>
@@ -8,7 +8,7 @@ from configs import dify_config
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
 from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_manager import ModelInstance, ModelManager
+from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -17,6 +17,8 @@ from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegme
 from core.workflow.enums import SystemVariableKey
 from core.workflow.file.models import File
 from core.workflow.nodes.llm.entities import ModelConfig
+from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
+from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
 from core.workflow.runtime import VariablePool
 from extensions.ext_database import db
 from libs.datetime_utils import naive_utc_now
@@ -24,49 +26,46 @@ from models.model import Conversation
 from models.provider import Provider, ProviderType
 from models.provider_ids import ModelProviderID
 
-from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError
+from .exc import InvalidVariableTypeError
 
 
 def fetch_model_config(
-    tenant_id: str, node_data_model: ModelConfig
+    *,
+    node_data_model: ModelConfig,
+    credentials_provider: CredentialsProvider,
+    model_factory: ModelFactory,
 ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
     if not node_data_model.mode:
         raise LLMModeRequiredError("LLM mode is required.")
 
-    model = ModelManager().get_model_instance(
-        tenant_id=tenant_id,
-        model_type=ModelType.LLM,
-        provider=node_data_model.provider,
-        model=node_data_model.name,
-    )
-
-    model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
-
-    # check model
-    provider_model = model.provider_model_bundle.configuration.get_provider_model(
-        model=node_data_model.name, model_type=ModelType.LLM
-    )
+    credentials = credentials_provider.fetch(node_data_model.provider, node_data_model.name)
+    model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
+    provider_model_bundle = model_instance.provider_model_bundle
+
+    provider_model = provider_model_bundle.configuration.get_provider_model(
+        model=node_data_model.name,
+        model_type=ModelType.LLM,
+    )
 
     if provider_model is None:
         raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
     provider_model.raise_for_status()
 
     # model config
     stop: list[str] = []
     if "stop" in node_data_model.completion_params:
         stop = node_data_model.completion_params.pop("stop")
 
-    model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
+    model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
     if not model_schema:
         raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
 
-    return model, ModelConfigWithCredentialsEntity(
+    model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
+    return model_instance, ModelConfigWithCredentialsEntity(
         provider=node_data_model.provider,
         model=node_data_model.name,
         model_schema=model_schema,
         mode=node_data_model.mode,
-        provider_model_bundle=model.provider_model_bundle,
-        credentials=model.credentials,
+        provider_model_bundle=provider_model_bundle,
+        credentials=credentials,
         parameters=node_data_model.completion_params,
         stop=stop,
    )
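With both lookups behind the injected ports, fetch_model_config no longer needs a tenant_id, a database session, or ModelManager: anything that structurally satisfies the two protocols can drive it. Below is a minimal sketch (not part of this commit) of calling the new signature with stub ports; the stub classes, names, and the pre-built ModelInstance are illustrative only.

from typing import Any

from core.model_manager import ModelInstance
from core.workflow.nodes.llm import llm_utils
from core.workflow.nodes.llm.entities import ModelConfig


class StaticCredentialsProvider:
    """Illustrative stub: returns one fixed credentials dict for any provider/model pair."""

    def __init__(self, credentials: dict[str, Any]) -> None:
        self._credentials = credentials

    def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
        return self._credentials


class StaticModelFactory:
    """Illustrative stub: hands back a ModelInstance that was prepared elsewhere."""

    def __init__(self, model_instance: ModelInstance) -> None:
        self._model_instance = model_instance

    def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
        return self._model_instance


def fetch_with_stubs(node_model: ModelConfig, prepared: ModelInstance, creds: dict[str, Any]):
    # All arguments are keyword-only in the new signature; no tenant lookup involved.
    return llm_utils.fetch_model_config(
        node_data_model=node_model,
        credentials_provider=StaticCredentialsProvider(creds),
        model_factory=StaticModelFactory(prepared),
    )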
@@ -131,7 +130,7 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
     if quota_unit == QuotaUnit.TOKENS:
         used_quota = usage.total_tokens
     elif quota_unit == QuotaUnit.CREDITS:
-        used_quota = dify_config.get_model_credits(model_instance.model)
+        used_quota = dify_config.get_model_credits(model_instance.model_name)
     else:
         used_quota = 1
 
@@ -16,7 +16,7 @@ from core.helper.code_executor import CodeExecutor, CodeLanguage
 from core.llm_generator.output_parser.errors import OutputParserError
 from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
 from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_manager import ModelInstance, ModelManager
+from core.model_manager import ModelInstance
 from core.model_runtime.entities import (
     ImagePromptMessageContent,
     PromptMessage,
@@ -38,11 +38,7 @@ from core.model_runtime.entities.message_entities import (
     SystemPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.entities.model_entities import (
-    ModelFeature,
-    ModelPropertyKey,
-    ModelType,
-)
+from core.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
@@ -76,6 +72,7 @@ from core.workflow.node_events import (
 from core.workflow.nodes.base.entities import VariableSelector
 from core.workflow.nodes.base.node import Node
 from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
+from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
 from core.workflow.runtime import VariablePool
 from extensions.ext_database import db
 from models.dataset import SegmentAttachmentBinding
@@ -93,7 +90,6 @@ from .exc import (
     InvalidVariableTypeError,
     LLMNodeError,
     MemoryRolePrefixRequiredError,
-    ModelNotExistError,
     NoPromptFoundError,
     TemplateTypeNotSupportError,
     VariableNotFoundError,
@@ -118,6 +114,8 @@ class LLMNode(Node[LLMNodeData]):
     _file_outputs: list[File]
 
     _llm_file_saver: LLMFileSaver
+    _credentials_provider: CredentialsProvider
+    _model_factory: ModelFactory
 
     def __init__(
         self,
@@ -126,6 +124,8 @@ class LLMNode(Node[LLMNodeData]):
         graph_init_params: GraphInitParams,
         graph_runtime_state: GraphRuntimeState,
         *,
+        credentials_provider: CredentialsProvider,
+        model_factory: ModelFactory,
         llm_file_saver: LLMFileSaver | None = None,
     ):
         super().__init__(
@@ -137,6 +137,9 @@ class LLMNode(Node[LLMNodeData]):
         # LLM file outputs, used for MultiModal outputs.
         self._file_outputs = []
 
+        self._credentials_provider = credentials_provider
+        self._model_factory = model_factory
+
         if llm_file_saver is None:
             llm_file_saver = FileSaverImpl(
                 user_id=graph_init_params.user_id,
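The constructor change above means every LLMNode (and, further down, ParameterExtractorNode and QuestionClassifierNode) must now receive the two ports as keyword-only arguments; how concrete ports are produced is not shown in this diff. A sketch of the construction contract, assuming LLMNode shares the base Node parameters (id, config) visible in the ParameterExtractorNode changes below:

from collections.abc import Mapping
from typing import Any

from core.workflow.entities import GraphInitParams
from core.workflow.nodes.llm import LLMNode
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import GraphRuntimeState


def build_llm_node(
    node_id: str,
    node_config: Mapping[str, Any],
    init_params: GraphInitParams,
    runtime_state: GraphRuntimeState,
    credentials_provider: CredentialsProvider,
    model_factory: ModelFactory,
) -> LLMNode:
    # credentials_provider and model_factory are now required keyword-only arguments;
    # llm_file_saver still defaults to a FileSaverImpl derived from the init params.
    return LLMNode(
        id=node_id,
        config=node_config,
        graph_init_params=init_params,
        graph_runtime_state=runtime_state,
        credentials_provider=credentials_provider,
        model_factory=model_factory,
    )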
@@ -199,10 +202,21 @@ class LLMNode(Node[LLMNodeData]):
             node_inputs["#context_files#"] = [file.model_dump() for file in context_files]
 
         # fetch model config
-        model_instance, model_config = LLMNode._fetch_model_config(
+        model_instance, model_config = self._fetch_model_config(
             node_data_model=self.node_data.model,
-            tenant_id=self.tenant_id,
         )
+        model_name = getattr(model_instance, "model_name", None)
+        if not isinstance(model_name, str):
+            model_name = model_config.model
+        model_provider = getattr(model_instance, "provider", None)
+        if not isinstance(model_provider, str):
+            model_provider = model_config.provider
+        model_schema = model_instance.model_type_instance.get_model_schema(
+            model_name,
+            model_instance.credentials,
+        )
+        if not model_schema:
+            raise ValueError(f"Model schema not found for {model_name}")
 
         # fetch memory
         memory = llm_utils.fetch_memory(
@@ -225,14 +239,16 @@ class LLMNode(Node[LLMNodeData]):
             sys_files=files,
             context=context,
             memory=memory,
-            model_config=model_config,
+            model_instance=model_instance,
+            model_schema=model_schema,
+            model_parameters=self.node_data.model.completion_params,
+            stop=model_config.stop,
             prompt_template=self.node_data.prompt_template,
             memory_config=self.node_data.memory,
             vision_enabled=self.node_data.vision.enabled,
             vision_detail=self.node_data.vision.configs.detail,
             variable_pool=variable_pool,
             jinja2_variables=self.node_data.prompt_config.jinja2_variables,
-            tenant_id=self.tenant_id,
             context_files=context_files,
         )
 
@@ -286,14 +302,14 @@ class LLMNode(Node[LLMNodeData]):
                 structured_output = event
 
         process_data = {
-            "model_mode": model_config.mode,
+            "model_mode": self.node_data.model.mode,
             "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
-                model_mode=model_config.mode, prompt_messages=prompt_messages
+                model_mode=self.node_data.model.mode, prompt_messages=prompt_messages
             ),
             "usage": jsonable_encoder(usage),
             "finish_reason": finish_reason,
-            "model_provider": model_config.provider,
-            "model_name": model_config.model,
+            "model_provider": model_provider,
+            "model_name": model_name,
         }
 
         outputs = {
@@ -755,21 +771,18 @@ class LLMNode(Node[LLMNodeData]):
 
         return None
 
-    @staticmethod
     def _fetch_model_config(
+        self,
         *,
         node_data_model: ModelConfig,
-        tenant_id: str,
     ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
         model, model_config_with_cred = llm_utils.fetch_model_config(
-            tenant_id=tenant_id, node_data_model=node_data_model
+            node_data_model=node_data_model,
+            credentials_provider=self._credentials_provider,
+            model_factory=self._model_factory,
         )
         completion_params = model_config_with_cred.parameters
 
-        model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
-        if not model_schema:
-            raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
-
         model_config_with_cred.parameters = completion_params
         # NOTE(-LAN-): This line modify the `self.node_data.model`, which is used in `_invoke_llm()`.
         node_data_model.completion_params = completion_params
@@ -782,14 +795,16 @@ class LLMNode(Node[LLMNodeData]):
         sys_files: Sequence[File],
         context: str | None = None,
         memory: TokenBufferMemory | None = None,
-        model_config: ModelConfigWithCredentialsEntity,
+        model_instance: ModelInstance,
+        model_schema: AIModelEntity,
+        model_parameters: Mapping[str, Any],
         prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
+        stop: Sequence[str] | None = None,
         memory_config: MemoryConfig | None = None,
         vision_enabled: bool = False,
         vision_detail: ImagePromptMessageContent.DETAIL,
         variable_pool: VariablePool,
         jinja2_variables: Sequence[VariableSelector],
-        tenant_id: str,
         context_files: list[File] | None = None,
     ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
         prompt_messages: list[PromptMessage] = []
@@ -810,7 +825,9 @@ class LLMNode(Node[LLMNodeData]):
             memory_messages = _handle_memory_chat_mode(
                 memory=memory,
                 memory_config=memory_config,
-                model_config=model_config,
+                model_instance=model_instance,
+                model_schema=model_schema,
+                model_parameters=model_parameters,
             )
             # Extend prompt_messages with memory messages
             prompt_messages.extend(memory_messages)
@@ -847,7 +864,9 @@ class LLMNode(Node[LLMNodeData]):
             memory_text = _handle_memory_completion_mode(
                 memory=memory,
                 memory_config=memory_config,
-                model_config=model_config,
+                model_instance=model_instance,
+                model_schema=model_schema,
+                model_parameters=model_parameters,
             )
             # Insert histories into the prompt
             prompt_content = prompt_messages[0].content
@@ -924,7 +943,7 @@ class LLMNode(Node[LLMNodeData]):
             prompt_message_content: list[PromptMessageContentUnionTypes] = []
             for content_item in prompt_message.content:
                 # Skip content if features are not defined
-                if not model_config.model_schema.features:
+                if not model_schema.features:
                     if content_item.type != PromptMessageContentType.TEXT:
                         continue
                     prompt_message_content.append(content_item)
@@ -934,19 +953,19 @@ class LLMNode(Node[LLMNodeData]):
                 if (
                     (
                         content_item.type == PromptMessageContentType.IMAGE
-                        and ModelFeature.VISION not in model_config.model_schema.features
+                        and ModelFeature.VISION not in model_schema.features
                     )
                     or (
                         content_item.type == PromptMessageContentType.DOCUMENT
-                        and ModelFeature.DOCUMENT not in model_config.model_schema.features
+                        and ModelFeature.DOCUMENT not in model_schema.features
                     )
                     or (
                         content_item.type == PromptMessageContentType.VIDEO
-                        and ModelFeature.VIDEO not in model_config.model_schema.features
+                        and ModelFeature.VIDEO not in model_schema.features
                     )
                     or (
                         content_item.type == PromptMessageContentType.AUDIO
-                        and ModelFeature.AUDIO not in model_config.model_schema.features
+                        and ModelFeature.AUDIO not in model_schema.features
                     )
                 ):
                     continue
@@ -965,19 +984,7 @@ class LLMNode(Node[LLMNodeData]):
                 "Please ensure a prompt is properly configured before proceeding."
             )
 
-        model = ModelManager().get_model_instance(
-            tenant_id=tenant_id,
-            model_type=ModelType.LLM,
-            provider=model_config.provider,
-            model=model_config.model,
-        )
-        model_schema = model.model_type_instance.get_model_schema(
-            model=model_config.model,
-            credentials=model.credentials,
-        )
-        if not model_schema:
-            raise ModelNotExistError(f"Model {model_config.model} not exist.")
-        return filtered_prompt_messages, model_config.stop
+        return filtered_prompt_messages, stop
 
     @classmethod
     def _extract_variable_selector_to_variable_mapping(
@@ -1306,26 +1313,26 @@ def _render_jinja2_message(
 
 
 def _calculate_rest_token(
-    *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
+    *,
+    prompt_messages: list[PromptMessage],
+    model_instance: ModelInstance,
+    model_schema: AIModelEntity,
+    model_parameters: Mapping[str, Any],
 ) -> int:
     rest_tokens = 2000
 
-    model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+    model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
     if model_context_tokens:
-        model_instance = ModelInstance(
-            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
-        )
-
         curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
 
         max_tokens = 0
-        for parameter_rule in model_config.model_schema.parameter_rules:
+        for parameter_rule in model_schema.parameter_rules:
             if parameter_rule.name == "max_tokens" or (
                 parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
             ):
                 max_tokens = (
-                    model_config.parameters.get(parameter_rule.name)
-                    or model_config.parameters.get(str(parameter_rule.use_template))
+                    model_parameters.get(parameter_rule.name)
+                    or model_parameters.get(str(parameter_rule.use_template))
                     or 0
                 )
 
@@ -1339,12 +1346,19 @@ def _handle_memory_chat_mode(
     *,
     memory: TokenBufferMemory | None,
     memory_config: MemoryConfig | None,
-    model_config: ModelConfigWithCredentialsEntity,
+    model_instance: ModelInstance,
+    model_schema: AIModelEntity,
+    model_parameters: Mapping[str, Any],
 ) -> Sequence[PromptMessage]:
     memory_messages: Sequence[PromptMessage] = []
     # Get messages from memory for chat model
     if memory and memory_config:
-        rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
+        rest_tokens = _calculate_rest_token(
+            prompt_messages=[],
+            model_instance=model_instance,
+            model_schema=model_schema,
+            model_parameters=model_parameters,
+        )
         memory_messages = memory.get_history_prompt_messages(
             max_token_limit=rest_tokens,
             message_limit=memory_config.window.size if memory_config.window.enabled else None,
@@ -1356,12 +1370,19 @@ def _handle_memory_completion_mode(
     *,
     memory: TokenBufferMemory | None,
     memory_config: MemoryConfig | None,
-    model_config: ModelConfigWithCredentialsEntity,
+    model_instance: ModelInstance,
+    model_schema: AIModelEntity,
+    model_parameters: Mapping[str, Any],
 ) -> str:
     memory_text = ""
     # Get history text from memory for completion model
     if memory and memory_config:
-        rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
+        rest_tokens = _calculate_rest_token(
+            prompt_messages=[],
+            model_instance=model_instance,
+            model_schema=model_schema,
+            model_parameters=model_parameters,
+        )
         if not memory_config.role_prefix:
             raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
         memory_text = memory.get_history_prompt_text(
api/core/workflow/nodes/llm/protocols.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+from __future__ import annotations
+
+from typing import Any, Protocol
+
+from core.model_manager import ModelInstance
+
+
+class CredentialsProvider(Protocol):
+    """Port for loading runtime credentials for a provider/model pair."""
+
+    def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
+        """Return credentials for the target provider/model or raise a domain error."""
+        ...
+
+
+class ModelFactory(Protocol):
+    """Port for creating initialized LLM model instances for execution."""
+
+    def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
+        """Create a model instance that is ready for schema lookup and invocation."""
+        ...
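These Protocol classes are structural, so any object with matching method signatures satisfies them. One plausible production-side adapter, not part of this commit and shown only as a sketch, would put the previous ModelManager lookup behind both ports at once; the class name and constructor below are illustrative.

from typing import Any

from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType


class ModelManagerBackedPorts:
    """Illustrative adapter satisfying both CredentialsProvider and ModelFactory
    by delegating to ModelManager, much like the removed inline code did."""

    def __init__(self, tenant_id: str) -> None:
        self._tenant_id = tenant_id
        self._model_manager = ModelManager()

    def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
        return self._model_manager.get_model_instance(
            tenant_id=self._tenant_id,
            model_type=ModelType.LLM,
            provider=provider_name,
            model=model_name,
        )

    def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
        # The removed code read credentials off the fetched ModelInstance.
        return self.init_model_instance(provider_name, model_name).credentials

Passing one such object as both credentials_provider and model_factory would reproduce the pre-refactor lookup path, while tests can substitute in-memory stubs instead.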
@@ -3,7 +3,7 @@ import json
 import logging
 import uuid
 from collections.abc import Mapping, Sequence
-from typing import Any, cast
+from typing import TYPE_CHECKING, Any, cast
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -60,6 +60,11 @@ from .prompts import (
 
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    from core.workflow.entities import GraphInitParams
+    from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
+    from core.workflow.runtime import GraphRuntimeState
+
 
 def extract_json(text):
     """
@@ -92,6 +97,27 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
 
     _model_instance: ModelInstance | None = None
     _model_config: ModelConfigWithCredentialsEntity | None = None
+    _credentials_provider: "CredentialsProvider"
+    _model_factory: "ModelFactory"
+
+    def __init__(
+        self,
+        id: str,
+        config: Mapping[str, Any],
+        graph_init_params: "GraphInitParams",
+        graph_runtime_state: "GraphRuntimeState",
+        *,
+        credentials_provider: "CredentialsProvider",
+        model_factory: "ModelFactory",
+    ) -> None:
+        super().__init__(
+            id=id,
+            config=config,
+            graph_init_params=graph_init_params,
+            graph_runtime_state=graph_runtime_state,
+        )
+        self._credentials_provider = credentials_provider
+        self._model_factory = model_factory
 
     @classmethod
     def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@@ -806,7 +832,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         """
         if not self._model_instance or not self._model_config:
             self._model_instance, self._model_config = llm_utils.fetch_model_config(
-                tenant_id=self.tenant_id, node_data_model=node_data_model
+                node_data_model=node_data_model,
+                credentials_provider=self._credentials_provider,
+                model_factory=self._model_factory,
             )
 
         return self._model_instance, self._model_config
@@ -24,6 +24,7 @@ from core.workflow.nodes.base.node import Node
 from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
 from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils
 from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
+from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
 from libs.json_in_md_parser import parse_and_check_json_markdown
 
 from .entities import QuestionClassifierNodeData
@@ -49,6 +50,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
 
     _file_outputs: list["File"]
     _llm_file_saver: LLMFileSaver
+    _credentials_provider: "CredentialsProvider"
+    _model_factory: "ModelFactory"
 
     def __init__(
         self,
@@ -57,6 +60,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         graph_init_params: "GraphInitParams",
         graph_runtime_state: "GraphRuntimeState",
         *,
+        credentials_provider: "CredentialsProvider",
+        model_factory: "ModelFactory",
         llm_file_saver: LLMFileSaver | None = None,
     ):
         super().__init__(
@@ -68,6 +73,9 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         # LLM file outputs, used for MultiModal outputs.
         self._file_outputs = []
 
+        self._credentials_provider = credentials_provider
+        self._model_factory = model_factory
+
         if llm_file_saver is None:
             llm_file_saver = FileSaverImpl(
                 user_id=graph_init_params.user_id,
@@ -89,9 +97,16 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         variables = {"query": query}
         # fetch model config
         model_instance, model_config = llm_utils.fetch_model_config(
-            tenant_id=self.tenant_id,
             node_data_model=node_data.model,
+            credentials_provider=self._credentials_provider,
+            model_factory=self._model_factory,
         )
+        model_schema = model_instance.model_type_instance.get_model_schema(
+            model_instance.model_name,
+            model_instance.credentials,
+        )
+        if not model_schema:
+            raise ValueError(f"Model schema not found for {model_instance.model_name}")
         # fetch memory
         memory = llm_utils.fetch_memory(
             variable_pool=variable_pool,
@@ -133,13 +148,15 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
             prompt_template=prompt_template,
             sys_query="",
             memory=memory,
-            model_config=model_config,
+            model_instance=model_instance,
+            model_schema=model_schema,
+            model_parameters=node_data.model.completion_params,
+            stop=model_config.stop,
             sys_files=files,
             vision_enabled=node_data.vision.enabled,
             vision_detail=node_data.vision.configs.detail,
             variable_pool=variable_pool,
             jinja2_variables=[],
-            tenant_id=self.tenant_id,
         )
 
         result_text = ""