add llm node

takatost
2024-03-12 22:12:03 +08:00
parent 4f5c052dc8
commit 3f59a579d7
17 changed files with 697 additions and 182 deletions
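
The diff below migrates the advanced prompt transform unit tests from the old app-config prompt entities (PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity, AdvancedChatPromptTemplateEntity) to the new prompt entities used by the LLM node. As a rough sketch of how the new entities fit together, based only on the constructor calls visible in this diff (field names, nesting, and import paths are taken from the diff; variable names are illustrative and the real classes may carry extra fields or validation):

from core.model_runtime.entities.message_entities import PromptMessageRole
from core.prompt.entities.advanced_prompt_entities import (
    ChatModelMessage,
    CompletionModelPromptTemplate,
    MemoryConfig,
)

# Completion-style prompt: a single template string.
completion_template = CompletionModelPromptTemplate(
    text="Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}."
)

# Memory options (role prefixes, window) now live in a separate MemoryConfig
# instead of being nested inside the prompt template entity.
memory_config = MemoryConfig(
    role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
    window=MemoryConfig.WindowConfig(enabled=False),
)

# Chat-style prompt: a plain list of role-tagged messages.
chat_prompt = [
    ChatModelMessage(text="You are a helpful assistant named {{name}}.", role=PromptMessageRole.SYSTEM),
    ChatModelMessage(text="Hi.", role=PromptMessageRole.USER),
    ChatModelMessage(text="Hello!", role=PromptMessageRole.ASSISTANT),
]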


@@ -2,12 +2,12 @@ from unittest.mock import MagicMock
 import pytest
-from core.app.app_config.entities import PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity, \
-    ModelConfigEntity, AdvancedChatPromptTemplateEntity, AdvancedChatMessageEntity, FileUploadEntity
+from core.app.app_config.entities import ModelConfigEntity, FileUploadEntity
 from core.file.file_obj import FileObj, FileType, FileTransferMethod
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
+from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig, ChatModelMessage
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from models.model import Conversation
@@ -18,16 +18,20 @@ def test__get_completion_model_prompt_messages():
     model_config_mock.model = 'gpt-3.5-turbo-instruct'

     prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}."
-    prompt_template_entity = PromptTemplateEntity(
-        prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
-        advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
-            prompt=prompt_template,
-            role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(
-                user="Human",
-                assistant="Assistant"
-            )
+    prompt_template_config = CompletionModelPromptTemplate(
+        text=prompt_template
+    )
+
+    memory_config = MemoryConfig(
+        role_prefix=MemoryConfig.RolePrefix(
+            user="Human",
+            assistant="Assistant"
+        ),
+        window=MemoryConfig.WindowConfig(
+            enabled=False
         )
     )

     inputs = {
         "name": "John"
     }
@@ -48,11 +52,12 @@ def test__get_completion_model_prompt_messages():
     prompt_transform = AdvancedPromptTransform()
     prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
     prompt_messages = prompt_transform._get_completion_model_prompt_messages(
-        prompt_template_entity=prompt_template_entity,
+        prompt_template=prompt_template_config,
         inputs=inputs,
         query=None,
         files=files,
         context=context,
+        memory_config=memory_config,
         memory=memory,
         model_config=model_config_mock
     )
@@ -67,7 +72,7 @@ def test__get_completion_model_prompt_messages():
 def test__get_chat_model_prompt_messages(get_chat_model_args):
-    model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
+    model_config_mock, memory_config, messages, inputs, context = get_chat_model_args

     files = []
     query = "Hi2."
@@ -86,11 +91,12 @@ def test__get_chat_model_prompt_messages(get_chat_model_args):
     prompt_transform = AdvancedPromptTransform()
     prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
     prompt_messages = prompt_transform._get_chat_model_prompt_messages(
-        prompt_template_entity=prompt_template_entity,
+        prompt_template=messages,
         inputs=inputs,
         query=query,
         files=files,
         context=context,
+        memory_config=memory_config,
         memory=memory,
         model_config=model_config_mock
     )
@@ -98,24 +104,25 @@ def test__get_chat_model_prompt_messages(get_chat_model_args):
     assert len(prompt_messages) == 6
     assert prompt_messages[0].role == PromptMessageRole.SYSTEM
     assert prompt_messages[0].content == PromptTemplateParser(
-        template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
+        template=messages[0].text
     ).format({**inputs, "#context#": context})
     assert prompt_messages[5].content == query


 def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
-    model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
+    model_config_mock, _, messages, inputs, context = get_chat_model_args

     files = []

     prompt_transform = AdvancedPromptTransform()
     prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
     prompt_messages = prompt_transform._get_chat_model_prompt_messages(
-        prompt_template_entity=prompt_template_entity,
+        prompt_template=messages,
         inputs=inputs,
         query=None,
         files=files,
         context=context,
+        memory_config=None,
         memory=None,
         model_config=model_config_mock
     )
@@ -123,12 +130,12 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
     assert len(prompt_messages) == 3
     assert prompt_messages[0].role == PromptMessageRole.SYSTEM
     assert prompt_messages[0].content == PromptTemplateParser(
-        template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
+        template=messages[0].text
     ).format({**inputs, "#context#": context})


 def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
-    model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
+    model_config_mock, _, messages, inputs, context = get_chat_model_args

     files = [
         FileObj(
@@ -148,11 +155,12 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
     prompt_transform = AdvancedPromptTransform()
     prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
     prompt_messages = prompt_transform._get_chat_model_prompt_messages(
-        prompt_template_entity=prompt_template_entity,
+        prompt_template=messages,
         inputs=inputs,
         query=None,
         files=files,
         context=context,
+        memory_config=None,
         memory=None,
         model_config=model_config_mock
     )
@@ -160,7 +168,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
     assert len(prompt_messages) == 4
     assert prompt_messages[0].role == PromptMessageRole.SYSTEM
     assert prompt_messages[0].content == PromptTemplateParser(
-        template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
+        template=messages[0].text
     ).format({**inputs, "#context#": context})
     assert isinstance(prompt_messages[3].content, list)
     assert len(prompt_messages[3].content) == 2
@@ -173,22 +181,31 @@ def get_chat_model_args():
     model_config_mock.provider = 'openai'
     model_config_mock.model = 'gpt-4'

-    prompt_template_entity = PromptTemplateEntity(
-        prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
-        advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
-            messages=[
-                AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
-                                          role=PromptMessageRole.SYSTEM),
-                AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER),
-                AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT),
-            ]
+    memory_config = MemoryConfig(
+        window=MemoryConfig.WindowConfig(
+            enabled=False
         )
     )

+    prompt_messages = [
+        ChatModelMessage(
+            text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
+            role=PromptMessageRole.SYSTEM
+        ),
+        ChatModelMessage(
+            text="Hi.",
+            role=PromptMessageRole.USER
+        ),
+        ChatModelMessage(
+            text="Hello!",
+            role=PromptMessageRole.ASSISTANT
+        )
+    ]
+
     inputs = {
         "name": "John"
     }

     context = "I am superman."

-    return model_config_mock, prompt_template_entity, inputs, context
+    return model_config_mock, memory_config, prompt_messages, inputs, context
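
For reference, a minimal sketch of the no-memory call path these fixtures feed, mirroring test__get_chat_model_prompt_messages_no_memory above (the ChatModelMessage list and the mocked ModelConfigEntity are set up exactly as in the fixture, and _calculate_rest_token is stubbed out just as the tests do; this only restates the call shape, not the transform internals):

from unittest.mock import MagicMock

from core.app.app_config.entities import ModelConfigEntity
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage

model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = 'openai'
model_config_mock.model = 'gpt-4'

messages = [
    ChatModelMessage(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
                     role=PromptMessageRole.SYSTEM),
    ChatModelMessage(text="Hi.", role=PromptMessageRole.USER),
    ChatModelMessage(text="Hello!", role=PromptMessageRole.ASSISTANT),
]

prompt_transform = AdvancedPromptTransform()
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)  # stubbed as in the tests
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
    prompt_template=messages,
    inputs={"name": "John"},
    query=None,
    files=[],
    context="I am superman.",
    memory_config=None,   # no conversation memory on this path
    memory=None,
    model_config=model_config_mock,
)

# The no-memory test above asserts len(prompt_messages) == 3 and that the first
# message is the system prompt with {{name}} and {{#context#}} substituted.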