Mirror of https://github.com/langgenius/dify.git (synced 2026-05-04 01:18:05 +08:00)
Commit: fix qc
@@ -15,12 +15,13 @@ from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
-from core.workflow.nodes.llm.entities import LLMNodeData
+from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
 from extensions.ext_database import db
 from models.model import Conversation
 from models.provider import Provider, ProviderType
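
The two new imports back the narrowed signatures further down: _fetch_model_config now receives only the node's model configuration, and _fetch_memory an optional memory configuration, instead of the whole LLMNodeData. As a rough sketch of the shapes involved (the real classes are pydantic models defined in core.workflow.nodes.llm.entities and core.prompt.entities.advanced_prompt_entities; only the fields this diff actually reads are shown, and the MemoryConfig field is an assumed placeholder):

    # Illustrative sketch only -- not the real definitions.
    from typing import Any, Optional
    from pydantic import BaseModel

    class ModelConfig(BaseModel):
        provider: str                             # read in _fetch_model_config
        name: str                                 # read in _fetch_model_config
        mode: Optional[str] = None                # validated before prompt building
        completion_params: dict[str, Any] = {}    # forwarded to invoke_llm

    class MemoryConfig(BaseModel):
        window_size: Optional[int] = None         # assumed field, for illustration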
@@ -64,10 +65,10 @@ class LLMNode(BaseNode):
         node_inputs['#context#'] = context
 
         # fetch model config
-        model_instance, model_config = self._fetch_model_config(node_data)
+        model_instance, model_config = self._fetch_model_config(node_data.model)
 
         # fetch memory
-        memory = self._fetch_memory(node_data, variable_pool, model_instance)
+        memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
 
         # fetch prompt messages
         prompt_messages, stop = self._fetch_prompt_messages(
@@ -89,7 +90,7 @@ class LLMNode(BaseNode):
 
         # handle invoke result
         result_text, usage = self._invoke_llm(
-            node_data=node_data,
+            node_data_model=node_data.model,
             model_instance=model_instance,
             prompt_messages=prompt_messages,
             stop=stop
@@ -119,13 +120,13 @@ class LLMNode(BaseNode):
             }
         )
 
-    def _invoke_llm(self, node_data: LLMNodeData,
+    def _invoke_llm(self, node_data_model: ModelConfig,
                     model_instance: ModelInstance,
                     prompt_messages: list[PromptMessage],
                     stop: list[str]) -> tuple[str, LLMUsage]:
         """
         Invoke large language model
-        :param node_data: node data
+        :param node_data_model: node data model
         :param model_instance: model instance
         :param prompt_messages: prompt messages
         :param stop: stop
@@ -135,7 +136,7 @@ class LLMNode(BaseNode):
 
         invoke_result = model_instance.invoke_llm(
             prompt_messages=prompt_messages,
-            model_parameters=node_data.model.completion_params,
+            model_parameters=node_data_model.completion_params,
             stop=stop,
             stream=True,
             user=self.user_id,
@@ -286,14 +287,14 @@ class LLMNode(BaseNode):
 
         return None
 
-    def _fetch_model_config(self, node_data: LLMNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+    def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
         """
         Fetch model config
-        :param node_data: node data
+        :param node_data_model: node data model
         :return:
         """
-        model_name = node_data.model.name
-        provider_name = node_data.model.provider
+        model_name = node_data_model.name
+        provider_name = node_data_model.provider
 
         model_manager = ModelManager()
         model_instance = model_manager.get_model_instance(
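
A practical payoff of the narrowed parameter: the helper can now be exercised with just a ModelConfig rather than a fully populated LLMNodeData. A hypothetical stub, assuming the sketch above (the call itself is commented out because it needs a live ModelManager and database):

    # Hypothetical usage, assuming the ModelConfig sketch above:
    config = ModelConfig(
        provider='openai',               # assumed provider name
        name='gpt-4',
        mode='chat',
        completion_params={'temperature': 0.7, 'stop': ['\n\n']},
    )
    # model_instance, model_config = llm_node._fetch_model_config(config)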
@@ -326,14 +327,14 @@ class LLMNode(BaseNode):
             raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
 
         # model config
-        completion_params = node_data.model.completion_params
+        completion_params = node_data_model.completion_params
         stop = []
         if 'stop' in completion_params:
             stop = completion_params['stop']
             del completion_params['stop']
 
         # get model mode
-        model_mode = node_data.model.mode
+        model_mode = node_data_model.mode
         if not model_mode:
             raise ValueError("LLM mode is required.")
 
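
Aside from the receiver rename, the stop-sequence extraction above is untouched by this commit; for reference, the three statements collapse to a single dict.pop (an equivalent form, not part of the diff):

    # Equivalent to the stop = [] / if 'stop' in ... / del dance above:
    stop = completion_params.pop('stop', [])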
@@ -356,26 +357,25 @@ class LLMNode(BaseNode):
             stop=stop,
         )
 
-    def _fetch_memory(self, node_data: LLMNodeData,
+    def _fetch_memory(self, node_data_memory: Optional[MemoryConfig],
                       variable_pool: VariablePool,
                       model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
         """
         Fetch memory
-        :param node_data: node data
+        :param node_data_memory: node data memory
         :param variable_pool: variable pool
         :return:
         """
-        if not node_data.memory:
+        if not node_data_memory:
             return None
 
         # get conversation id
-        conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION])
+        conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION.value])
         if conversation_id is None:
             return None
 
         # get conversation
         conversation = db.session.query(Conversation).filter(
             Conversation.tenant_id == self.tenant_id,
             Conversation.app_id == self.app_id,
             Conversation.id == conversation_id
         ).first()
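
The last hunk also carries what looks like the substantive fix: the variable pool is keyed by plain strings, so looking it up with the enum member itself never matches. A toy reproduction, assuming SystemVariable is a standard Enum (not a str subclass) whose values are the string keys the pool uses; the member value shown is illustrative:

    from enum import Enum

    class SystemVariable(Enum):              # stand-in for the real enum
        CONVERSATION = 'conversation'        # assumed value, for illustration

    pool = {('sys', 'conversation'): 'conv-123'}   # string-keyed, like the pool

    assert pool.get(('sys', SystemVariable.CONVERSATION)) is None         # enum member != str
    assert pool.get(('sys', SystemVariable.CONVERSATION.value)) == 'conv-123'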