Merge branch 'fix/chore-fix' into dev/plugin-deploy

This commit is contained in:
Yeuoly
2024-11-26 18:04:34 +08:00
26 changed files with 241 additions and 103 deletions

View File

@ -20,6 +20,7 @@ from core.model_runtime.entities import (
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageContent,
PromptMessageRole,
SystemPromptMessage,
UserPromptMessage,
@ -66,7 +67,6 @@ from .entities import (
ModelConfig,
)
from .exc import (
FileTypeNotSupportError,
InvalidContextStructureError,
InvalidVariableTypeError,
LLMModeRequiredError,
@ -137,12 +137,12 @@ class LLMNode(BaseNode[LLMNodeData]):
query = None
if self.node_data.memory:
query = self.node_data.memory.query_prompt_template
if query is None and (
query_variable := self.graph_runtime_state.variable_pool.get(
(SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)
)
):
query = query_variable.text
if not query and (
query_variable := self.graph_runtime_state.variable_pool.get(
(SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)
)
):
query = query_variable.text
prompt_messages, stop = self._fetch_prompt_messages(
user_query=query,
@ -675,7 +675,7 @@ class LLMNode(BaseNode[LLMNodeData]):
and ModelFeature.AUDIO not in model_config.model_schema.features
)
):
raise FileTypeNotSupportError(type_name=content_item.type)
continue
prompt_message_content.append(content_item)
if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
prompt_message.content = prompt_message_content[0].data
@ -828,14 +828,14 @@ class LLMNode(BaseNode[LLMNodeData]):
}
def _combine_text_message_with_role(*, text: str, role: PromptMessageRole):
def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
match role:
case PromptMessageRole.USER:
return UserPromptMessage(content=[TextPromptMessageContent(data=text)])
return UserPromptMessage(content=contents)
case PromptMessageRole.ASSISTANT:
return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)])
return AssistantPromptMessage(content=contents)
case PromptMessageRole.SYSTEM:
return SystemPromptMessage(content=[TextPromptMessageContent(data=text)])
return SystemPromptMessage(content=contents)
raise NotImplementedError(f"Role {role} is not supported")
@ -877,7 +877,9 @@ def _handle_list_messages(
jinjia2_variables=jinja2_variables,
variable_pool=variable_pool,
)
prompt_message = _combine_text_message_with_role(text=result_text, role=message.role)
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)], role=message.role
)
prompt_messages.append(prompt_message)
else:
# Get segment group from basic message
@ -908,12 +910,14 @@ def _handle_list_messages(
# Create message with text from all segments
plain_text = segment_group.text
if plain_text:
prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role)
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=plain_text)], role=message.role
)
prompt_messages.append(prompt_message)
if file_contents:
# Create message with image contents
prompt_message = UserPromptMessage(content=file_contents)
prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
prompt_messages.append(prompt_message)
return prompt_messages
@ -1018,6 +1022,8 @@ def _handle_completion_template(
else:
template_text = template.text
result_text = variable_pool.convert_template(template_text).text
prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER
)
prompt_messages.append(prompt_message)
return prompt_messages