Merge remote-tracking branch 'origin/feat/workflow' into feat/workflow

# Conflicts:
#	api/core/workflow/nodes/question_classifier/question_classifier_node.py
This commit is contained in:
jyong
2024-03-29 19:00:21 +08:00
33 changed files with 518 additions and 419 deletions

View File

@ -21,6 +21,8 @@ class AdvancedPromptTransform(PromptTransform):
"""
Advanced Prompt Transform for Workflow LLM Node.
"""
def __init__(self, with_variable_tmpl: bool = False) -> None:
self.with_variable_tmpl = with_variable_tmpl
def get_prompt(self, prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate],
inputs: dict,
@ -74,7 +76,7 @@ class AdvancedPromptTransform(PromptTransform):
prompt_messages = []
prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
@ -128,7 +130,7 @@ class AdvancedPromptTransform(PromptTransform):
for prompt_item in raw_prompt_list:
raw_prompt = prompt_item.text
prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
@ -211,7 +213,7 @@ class AdvancedPromptTransform(PromptTransform):
if '#histories#' in prompt_template.variable_keys:
if memory:
inputs = {'#histories#': '', **prompt_inputs}
prompt_template = PromptTemplateParser(raw_prompt)
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
tmp_human_message = UserPromptMessage(
content=prompt_template.format(prompt_inputs)

View File

@ -1,6 +1,9 @@
import re
REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#histories#|#query#|#context#)\}\}")
WITH_VARIABLE_TMPL_REGEX = re.compile(
r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#[a-zA-Z0-9_]{1,50}\.[a-zA-Z0-9_\.]{1,100}#|#histories#|#query#|#context#)\}\}"
)
class PromptTemplateParser:
@ -15,13 +18,15 @@ class PromptTemplateParser:
`{{#histories#}}` `{{#query#}}` `{{#context#}}`. No other `{{##}}` template variables are allowed.
"""
def __init__(self, template: str):
def __init__(self, template: str, with_variable_tmpl: bool = False):
self.template = template
self.with_variable_tmpl = with_variable_tmpl
self.regex = WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX
self.variable_keys = self.extract()
def extract(self) -> list:
# Regular expression to match the template rules
return re.findall(REGEX, self.template)
return re.findall(self.regex, self.template)
def format(self, inputs: dict, remove_template_variables: bool = True) -> str:
def replacer(match):
@ -29,12 +34,12 @@ class PromptTemplateParser:
value = inputs.get(key, match.group(0)) # return original matched string if key not found
if remove_template_variables:
return PromptTemplateParser.remove_template_variables(value)
return PromptTemplateParser.remove_template_variables(value, self.with_variable_tmpl)
return value
prompt = re.sub(REGEX, replacer, self.template)
prompt = re.sub(self.regex, replacer, self.template)
return re.sub(r'<\|.*?\|>', '', prompt)
@classmethod
def remove_template_variables(cls, text: str):
return re.sub(REGEX, r'{\1}', text)
def remove_template_variables(cls, text: str, with_variable_tmpl: bool = False):
return re.sub(WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX, r'{\1}', text)

View File

@ -13,6 +13,7 @@ from core.workflow.nodes.answer.entities import (
VarGenerateRouteChunk,
)
from core.workflow.nodes.base_node import BaseNode
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models.workflow import WorkflowNodeExecutionStatus
@ -66,32 +67,8 @@ class AnswerNode(BaseNode):
part = cast(TextGenerateRouteChunk, part)
answer += part.text
# re-fetch variable values
variable_values = {}
for variable_selector in node_data.variables:
value = variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector
)
if isinstance(value, str | int | float):
value = str(value)
elif isinstance(value, FileVar):
value = value.to_dict()
elif isinstance(value, list):
new_value = []
for item in value:
if isinstance(item, FileVar):
new_value.append(item.to_dict())
else:
new_value.append(item)
value = new_value
variable_values[variable_selector.variable] = value
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variable_values,
outputs={
"answer": answer
}
@ -116,15 +93,18 @@ class AnswerNode(BaseNode):
:param node_data: node data object
:return:
"""
variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()
value_selector_mapping = {
variable_selector.variable: variable_selector.value_selector
for variable_selector in node_data.variables
for variable_selector in variable_selectors
}
variable_keys = list(value_selector_mapping.keys())
# format answer template
template_parser = PromptTemplateParser(node_data.answer)
template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
template_variable_keys = template_parser.variable_keys
# Take the intersection of variable_keys and template_variable_keys
@ -164,8 +144,11 @@ class AnswerNode(BaseNode):
"""
node_data = cast(cls._node_data_cls, node_data)
variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()
variable_mapping = {}
for variable_selector in node_data.variables:
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
return variable_mapping

View File

@ -2,14 +2,12 @@
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
class AnswerNodeData(BaseNodeData):
"""
Answer Node Data.
"""
variables: list[VariableSelector] = []
answer: str

View File

@ -4,7 +4,6 @@ from pydantic import BaseModel
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
class ModelConfig(BaseModel):
@ -44,7 +43,6 @@ class LLMNodeData(BaseNodeData):
LLM Node Data.
"""
model: ModelConfig
variables: list[VariableSelector] = []
prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
memory: Optional[MemoryConfig] = None
context: ContextConfig

View File

@ -15,13 +15,14 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import Conversation
from models.provider import Provider, ProviderType
@ -48,9 +49,7 @@ class LLMNode(BaseNode):
# fetch variables and fetch values from variable pool
inputs = self._fetch_inputs(node_data, variable_pool)
node_inputs = {
**inputs
}
node_inputs = {}
# fetch files
files: list[FileVar] = self._fetch_files(node_data, variable_pool)
@ -192,10 +191,21 @@ class LLMNode(BaseNode):
:return:
"""
inputs = {}
for variable_selector in node_data.variables:
prompt_template = node_data.prompt_template
variable_selectors = []
if isinstance(prompt_template, list):
for prompt in prompt_template:
variable_template_parser = VariableTemplateParser(template=prompt.text)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
elif isinstance(prompt_template, CompletionModelPromptTemplate):
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
variable_selectors = variable_template_parser.extract_variable_selectors()
for variable_selector in variable_selectors:
variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
if variable_value is None:
raise ValueError(f'Variable {variable_selector.value_selector} not found')
raise ValueError(f'Variable {variable_selector.variable} not found')
inputs[variable_selector.variable] = variable_value
@ -411,7 +421,7 @@ class LLMNode(BaseNode):
:param model_config: model config
:return:
"""
prompt_transform = AdvancedPromptTransform()
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
prompt_messages = prompt_transform.get_prompt(
prompt_template=node_data.prompt_template,
inputs=inputs,
@ -486,9 +496,6 @@ class LLMNode(BaseNode):
node_data = cast(cls._node_data_cls, node_data)
variable_mapping = {}
for variable_selector in node_data.variables:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
if node_data.context.enabled:
variable_mapping['#context#'] = node_data.context.variable_selector

View File

@ -128,7 +128,7 @@ class QuestionClassifierNode(LLMNode):
:param model_config: model config
:return:
"""
prompt_transform = AdvancedPromptTransform()
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
prompt_template = self._get_prompt_template(node_data, query, memory)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,

View File

View File

@ -0,0 +1,58 @@
import re
from core.workflow.entities.variable_entities import VariableSelector
REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}")
class VariableTemplateParser:
"""
Rules:
1. Template variables must be enclosed in `{{}}`.
2. The template variable Key can only be: #node_id.var1.var2#.
3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2.
"""
def __init__(self, template: str):
self.template = template
self.variable_keys = self.extract()
def extract(self) -> list:
# Regular expression to match the template rules
matches = re.findall(REGEX, self.template)
first_group_matches = [match[0] for match in matches]
return list(set(first_group_matches))
def extract_variable_selectors(self) -> list[VariableSelector]:
variable_selectors = []
for variable_key in self.variable_keys:
remove_hash = variable_key.replace('#', '')
split_result = remove_hash.split('.')
if len(split_result) < 2:
continue
variable_selectors.append(VariableSelector(
variable=variable_key,
value_selector=split_result
))
return variable_selectors
def format(self, inputs: dict, remove_template_variables: bool = True) -> str:
def replacer(match):
key = match.group(1)
value = inputs.get(key, match.group(0)) # return original matched string if key not found
if remove_template_variables:
return VariableTemplateParser.remove_template_variables(value)
return value
prompt = re.sub(REGEX, replacer, self.template)
return re.sub(r'<\|.*?\|>', '', prompt)
@classmethod
def remove_template_variables(cls, text: str):
return re.sub(REGEX, r'{\1}', text)