This commit is contained in:
takatost
2024-03-21 15:02:55 +08:00
parent a05fcedd61
commit d71eae8f93
3 changed files with 35 additions and 223 deletions

View File

@ -15,12 +15,13 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.llm.entities import LLMNodeData
from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
from extensions.ext_database import db
from models.model import Conversation
from models.provider import Provider, ProviderType
@ -64,10 +65,10 @@ class LLMNode(BaseNode):
node_inputs['#context#'] = context
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data)
model_instance, model_config = self._fetch_model_config(node_data.model)
# fetch memory
memory = self._fetch_memory(node_data, variable_pool, model_instance)
memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
# fetch prompt messages
prompt_messages, stop = self._fetch_prompt_messages(
@ -89,7 +90,7 @@ class LLMNode(BaseNode):
# handle invoke result
result_text, usage = self._invoke_llm(
node_data=node_data,
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop
@ -119,13 +120,13 @@ class LLMNode(BaseNode):
}
)
def _invoke_llm(self, node_data: LLMNodeData,
def _invoke_llm(self, node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
stop: list[str]) -> tuple[str, LLMUsage]:
"""
Invoke large language model
:param node_data: node data
:param node_data_model: node data model
:param model_instance: model instance
:param prompt_messages: prompt messages
:param stop: stop
@ -135,7 +136,7 @@ class LLMNode(BaseNode):
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=node_data.model.completion_params,
model_parameters=node_data_model.completion_params,
stop=stop,
stream=True,
user=self.user_id,
@ -286,14 +287,14 @@ class LLMNode(BaseNode):
return None
def _fetch_model_config(self, node_data: LLMNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
"""
Fetch model config
:param node_data: node data
:param node_data_model: node data model
:return:
"""
model_name = node_data.model.name
provider_name = node_data.model.provider
model_name = node_data_model.name
provider_name = node_data_model.provider
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
@ -326,14 +327,14 @@ class LLMNode(BaseNode):
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
# model config
completion_params = node_data.model.completion_params
completion_params = node_data_model.completion_params
stop = []
if 'stop' in completion_params:
stop = completion_params['stop']
del completion_params['stop']
# get model mode
model_mode = node_data.model.mode
model_mode = node_data_model.mode
if not model_mode:
raise ValueError("LLM mode is required.")
@ -356,26 +357,25 @@ class LLMNode(BaseNode):
stop=stop,
)
def _fetch_memory(self, node_data: LLMNodeData,
def _fetch_memory(self, node_data_memory: Optional[MemoryConfig],
variable_pool: VariablePool,
model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
"""
Fetch memory
:param node_data: node data
:param node_data_memory: node data memory
:param variable_pool: variable pool
:return:
"""
if not node_data.memory:
if not node_data_memory:
return None
# get conversation id
conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION])
conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION.value])
if conversation_id is None:
return None
# get conversation
conversation = db.session.query(Conversation).filter(
Conversation.tenant_id == self.tenant_id,
Conversation.app_id == self.app_id,
Conversation.id == conversation_id
).first()