mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 02:18:08 +08:00
llm and answer node support inner variable template
This commit is contained in:
@ -4,7 +4,6 @@ from pydantic import BaseModel
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
@ -44,7 +43,6 @@ class LLMNodeData(BaseNodeData):
|
||||
LLM Node Data.
|
||||
"""
|
||||
model: ModelConfig
|
||||
variables: list[VariableSelector] = []
|
||||
prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
|
||||
memory: Optional[MemoryConfig] = None
|
||||
context: ContextConfig
|
||||
|
||||
@ -15,13 +15,14 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
from models.provider import Provider, ProviderType
|
||||
@ -48,9 +49,7 @@ class LLMNode(BaseNode):
|
||||
# fetch variables and fetch values from variable pool
|
||||
inputs = self._fetch_inputs(node_data, variable_pool)
|
||||
|
||||
node_inputs = {
|
||||
**inputs
|
||||
}
|
||||
node_inputs = {}
|
||||
|
||||
# fetch files
|
||||
files: list[FileVar] = self._fetch_files(node_data, variable_pool)
|
||||
@ -192,10 +191,21 @@ class LLMNode(BaseNode):
|
||||
:return:
|
||||
"""
|
||||
inputs = {}
|
||||
for variable_selector in node_data.variables:
|
||||
prompt_template = node_data.prompt_template
|
||||
|
||||
variable_selectors = []
|
||||
if isinstance(prompt_template, list):
|
||||
for prompt in prompt_template:
|
||||
variable_template_parser = VariableTemplateParser(template=prompt.text)
|
||||
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
||||
elif isinstance(prompt_template, CompletionModelPromptTemplate):
|
||||
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
for variable_selector in variable_selectors:
|
||||
variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
|
||||
if variable_value is None:
|
||||
raise ValueError(f'Variable {variable_selector.value_selector} not found')
|
||||
raise ValueError(f'Variable {variable_selector.variable} not found')
|
||||
|
||||
inputs[variable_selector.variable] = variable_value
|
||||
|
||||
@ -411,7 +421,7 @@ class LLMNode(BaseNode):
|
||||
:param model_config: model config
|
||||
:return:
|
||||
"""
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=node_data.prompt_template,
|
||||
inputs=inputs,
|
||||
@ -486,9 +496,6 @@ class LLMNode(BaseNode):
|
||||
node_data = cast(cls._node_data_cls, node_data)
|
||||
|
||||
variable_mapping = {}
|
||||
for variable_selector in node_data.variables:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
if node_data.context.enabled:
|
||||
variable_mapping['#context#'] = node_data.context.variable_selector
|
||||
|
||||
|
||||
Reference in New Issue
Block a user