diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
index 32f29a1c7c..139ad7e7e1 100644
--- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py
+++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
@@ -4,6 +4,8 @@ from typing import Optional, Union, cast
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
+from core.model_runtime.entities.model_entities import ModelPropertyKey
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
@@ -129,7 +131,8 @@ class QuestionClassifierNode(LLMNode):
         :return:
         """
         prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        prompt_template = self._get_prompt_template(node_data, query, memory)
+        rest_token = self._calculate_rest_token(node_data, query, model_config, context)
+        prompt_template = self._get_prompt_template(node_data, query, memory, rest_token)
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=prompt_template,
             inputs={},
@@ -144,8 +147,49 @@ class QuestionClassifierNode(LLMNode):
 
         return prompt_messages, stop
 
+    def _calculate_rest_token(self, node_data: QuestionClassifierNodeData, query: str,
+                              model_config: ModelConfigWithCredentialsEntity,
+                              context: Optional[str]) -> int:
+        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
+        prompt_template = self._get_prompt_template(node_data, query, None, 2000)
+        prompt_messages = prompt_transform.get_prompt(
+            prompt_template=prompt_template,
+            inputs={},
+            query='',
+            files=[],
+            context=context,
+            memory_config=node_data.memory,
+            memory=None,
+            model_config=model_config
+        )
+        rest_tokens = 2000
+
+        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+        if model_context_tokens:
+            model_type_instance = model_config.provider_model_bundle.model_type_instance
+            model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+            curr_message_tokens = model_type_instance.get_num_tokens(
+                model_config.model,
+                model_config.credentials,
+                prompt_messages
+            )
+
+            max_tokens = 0
+            for parameter_rule in model_config.model_schema.parameter_rules:
+                if (parameter_rule.name == 'max_tokens'
+                        or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
+                    max_tokens = (model_config.parameters.get(parameter_rule.name)
+                                  or model_config.parameters.get(parameter_rule.use_template)) or 0
+
+            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
+            rest_tokens = max(rest_tokens, 0)
+
+        return rest_tokens
+
     def _get_prompt_template(self, node_data: QuestionClassifierNodeData, query: str,
-                             memory: Optional[TokenBufferMemory]) \
+                             memory: Optional[TokenBufferMemory],
+                             max_token_limit: int = 2000) \
             -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]:
         model_mode = ModelMode.value_of(node_data.model.mode)
         classes = node_data.classes
@@ -155,7 +199,7 @@ class QuestionClassifierNode(LLMNode):
         input_text = query
         memory_str = ''
         if memory:
-            memory_str = memory.get_history_prompt_text(max_token_limit=2000,
+            memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit,
                                                         message_limit=node_data.memory.window.size)
         prompt_messages = []
         if model_mode == ModelMode.CHAT:
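
The new _calculate_rest_token helper reduces to a simple budget: rest = context_size - max_tokens - prompt_tokens, floored at zero, with 2000 used as the fallback when the model schema reports no context size. Below is a minimal sketch of that arithmetic only; the function name and parameters are hypothetical stand-ins for illustration, not Dify's actual API.

    from typing import Optional

    def calculate_rest_tokens(prompt_tokens: int,
                              context_size: Optional[int],
                              max_tokens: int = 0,
                              default_budget: int = 2000) -> int:
        # Hypothetical stand-in for QuestionClassifierNode._calculate_rest_token.
        # No reported context size: fall back to the 2000-token default used in the diff.
        if not context_size:
            return default_budget
        # Reserve room for the completion (max_tokens) and the prompt already built;
        # whatever remains is the budget passed to the memory history prompt.
        return max(context_size - max_tokens - prompt_tokens, 0)

    # Example: a 16384-token context with max_tokens=1024 and a 300-token prompt
    # leaves 15060 tokens for conversation history.
    assert calculate_rest_tokens(300, 16384, 1024) == 15060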