llm and answer node support inner variable template

This commit is contained in:
takatost
2024-03-29 18:44:21 +08:00
parent 8a2d04b305
commit 971436d935
13 changed files with 172 additions and 135 deletions

View File

@ -4,7 +4,6 @@ from pydantic import BaseModel
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
class ModelConfig(BaseModel):
@ -44,7 +43,6 @@ class LLMNodeData(BaseNodeData):
LLM Node Data.
"""
model: ModelConfig
variables: list[VariableSelector] = []
prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
memory: Optional[MemoryConfig] = None
context: ContextConfig

View File

@ -15,13 +15,14 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import Conversation
from models.provider import Provider, ProviderType
@ -48,9 +49,7 @@ class LLMNode(BaseNode):
# fetch variables and fetch values from variable pool
inputs = self._fetch_inputs(node_data, variable_pool)
node_inputs = {
**inputs
}
node_inputs = {}
# fetch files
files: list[FileVar] = self._fetch_files(node_data, variable_pool)
@ -192,10 +191,21 @@ class LLMNode(BaseNode):
:return:
"""
inputs = {}
for variable_selector in node_data.variables:
prompt_template = node_data.prompt_template
variable_selectors = []
if isinstance(prompt_template, list):
for prompt in prompt_template:
variable_template_parser = VariableTemplateParser(template=prompt.text)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
elif isinstance(prompt_template, CompletionModelPromptTemplate):
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
variable_selectors = variable_template_parser.extract_variable_selectors()
for variable_selector in variable_selectors:
variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
if variable_value is None:
raise ValueError(f'Variable {variable_selector.value_selector} not found')
raise ValueError(f'Variable {variable_selector.variable} not found')
inputs[variable_selector.variable] = variable_value
@ -411,7 +421,7 @@ class LLMNode(BaseNode):
:param model_config: model config
:return:
"""
prompt_transform = AdvancedPromptTransform()
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
prompt_messages = prompt_transform.get_prompt(
prompt_template=node_data.prompt_template,
inputs=inputs,
@ -486,9 +496,6 @@ class LLMNode(BaseNode):
node_data = cast(cls._node_data_cls, node_data)
variable_mapping = {}
for variable_selector in node_data.variables:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
if node_data.context.enabled:
variable_mapping['#context#'] = node_data.context.variable_selector