merge feat/plugins

This commit is contained in:
Joel
2024-11-08 13:57:34 +08:00
133 changed files with 3602 additions and 1774 deletions

View File

@ -49,11 +49,13 @@ class QuestionClassifierNode(LLMNode):
variable_pool = self.graph_runtime_state.variable_pool
# extract variables
variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None
variable = variable_pool.get(
node_data.query_variable_selector) if node_data.query_variable_selector else None
query = variable.value if variable else None
variables = {"query": query}
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.model)
model_instance, model_config = self._fetch_model_config(
node_data.model)
# fetch memory
memory = self._fetch_memory(
node_data_memory=node_data.memory,
@ -61,7 +63,8 @@ class QuestionClassifierNode(LLMNode):
)
# fetch instruction
node_data.instruction = node_data.instruction or ""
node_data.instruction = variable_pool.convert_template(node_data.instruction).text
node_data.instruction = variable_pool.convert_template(
node_data.instruction).text
files: Sequence[File] = (
self._fetch_files(
@ -184,12 +187,15 @@ class QuestionClassifierNode(LLMNode):
variable_mapping = {"query": node_data.query_variable_selector}
variable_selectors = []
if node_data.instruction:
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
variable_template_parser = VariableTemplateParser(
template=node_data.instruction)
variable_selectors.extend(
variable_template_parser.extract_variable_selectors())
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
variable_mapping = {node_id + "." + key: value for key,
value in variable_mapping.items()}
return variable_mapping
@ -210,7 +216,8 @@ class QuestionClassifierNode(LLMNode):
context: Optional[str],
) -> int:
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
prompt_template = self._get_prompt_template(
node_data, query, None, 2000)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
@ -223,13 +230,15 @@ class QuestionClassifierNode(LLMNode):
)
rest_tokens = 2000
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
model_context_tokens = model_config.model_schema.model_properties.get(
ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
curr_message_tokens = model_instance.get_llm_num_tokens(
prompt_messages)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
@ -270,7 +279,8 @@ class QuestionClassifierNode(LLMNode):
prompt_messages: list[LLMNodeChatModelMessage] = []
if model_mode == ModelMode.CHAT:
system_prompt_messages = LLMNodeChatModelMessage(
role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(
histories=memory_str)
)
prompt_messages.append(system_prompt_messages)
user_prompt_message_1 = LLMNodeChatModelMessage(
@ -311,4 +321,5 @@ class QuestionClassifierNode(LLMNode):
)
else:
raise InvalidModelTypeError(f"Model mode {model_mode} not support.")
raise InvalidModelTypeError(
f"Model mode {model_mode} not support.")