mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 09:58:04 +08:00
add llm node
This commit is contained in:
@ -2,12 +2,12 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity, \
|
||||
ModelConfigEntity, AdvancedChatPromptTemplateEntity, AdvancedChatMessageEntity, FileUploadEntity
|
||||
from core.app.app_config.entities import ModelConfigEntity, FileUploadEntity
|
||||
from core.file.file_obj import FileObj, FileType, FileTransferMethod
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig, ChatModelMessage
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from models.model import Conversation
|
||||
|
||||
@ -18,16 +18,20 @@ def test__get_completion_model_prompt_messages():
|
||||
model_config_mock.model = 'gpt-3.5-turbo-instruct'
|
||||
|
||||
prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}."
|
||||
prompt_template_entity = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
|
||||
advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
|
||||
prompt=prompt_template,
|
||||
role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(
|
||||
user="Human",
|
||||
assistant="Assistant"
|
||||
)
|
||||
prompt_template_config = CompletionModelPromptTemplate(
|
||||
text=prompt_template
|
||||
)
|
||||
|
||||
memory_config = MemoryConfig(
|
||||
role_prefix=MemoryConfig.RolePrefix(
|
||||
user="Human",
|
||||
assistant="Assistant"
|
||||
),
|
||||
window=MemoryConfig.WindowConfig(
|
||||
enabled=False
|
||||
)
|
||||
)
|
||||
|
||||
inputs = {
|
||||
"name": "John"
|
||||
}
|
||||
@ -48,11 +52,12 @@ def test__get_completion_model_prompt_messages():
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
|
||||
prompt_messages = prompt_transform._get_completion_model_prompt_messages(
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
prompt_template=prompt_template_config,
|
||||
inputs=inputs,
|
||||
query=None,
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config_mock
|
||||
)
|
||||
@ -67,7 +72,7 @@ def test__get_completion_model_prompt_messages():
|
||||
|
||||
|
||||
def test__get_chat_model_prompt_messages(get_chat_model_args):
|
||||
model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
|
||||
model_config_mock, memory_config, messages, inputs, context = get_chat_model_args
|
||||
|
||||
files = []
|
||||
query = "Hi2."
|
||||
@ -86,11 +91,12 @@ def test__get_chat_model_prompt_messages(get_chat_model_args):
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
|
||||
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
prompt_template=messages,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config_mock
|
||||
)
|
||||
@ -98,24 +104,25 @@ def test__get_chat_model_prompt_messages(get_chat_model_args):
|
||||
assert len(prompt_messages) == 6
|
||||
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
|
||||
assert prompt_messages[0].content == PromptTemplateParser(
|
||||
template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
|
||||
template=messages[0].text
|
||||
).format({**inputs, "#context#": context})
|
||||
assert prompt_messages[5].content == query
|
||||
|
||||
|
||||
def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
|
||||
model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
|
||||
model_config_mock, _, messages, inputs, context = get_chat_model_args
|
||||
|
||||
files = []
|
||||
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
|
||||
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
prompt_template=messages,
|
||||
inputs=inputs,
|
||||
query=None,
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock
|
||||
)
|
||||
@ -123,12 +130,12 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
|
||||
assert len(prompt_messages) == 3
|
||||
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
|
||||
assert prompt_messages[0].content == PromptTemplateParser(
|
||||
template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
|
||||
template=messages[0].text
|
||||
).format({**inputs, "#context#": context})
|
||||
|
||||
|
||||
def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
|
||||
model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
|
||||
model_config_mock, _, messages, inputs, context = get_chat_model_args
|
||||
|
||||
files = [
|
||||
FileObj(
|
||||
@ -148,11 +155,12 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
|
||||
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
prompt_template=messages,
|
||||
inputs=inputs,
|
||||
query=None,
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock
|
||||
)
|
||||
@ -160,7 +168,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
||||
assert len(prompt_messages) == 4
|
||||
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
|
||||
assert prompt_messages[0].content == PromptTemplateParser(
|
||||
template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
|
||||
template=messages[0].text
|
||||
).format({**inputs, "#context#": context})
|
||||
assert isinstance(prompt_messages[3].content, list)
|
||||
assert len(prompt_messages[3].content) == 2
|
||||
@ -173,22 +181,31 @@ def get_chat_model_args():
|
||||
model_config_mock.provider = 'openai'
|
||||
model_config_mock.model = 'gpt-4'
|
||||
|
||||
prompt_template_entity = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
|
||||
advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
|
||||
messages=[
|
||||
AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
|
||||
role=PromptMessageRole.SYSTEM),
|
||||
AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER),
|
||||
AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT),
|
||||
]
|
||||
memory_config = MemoryConfig(
|
||||
window=MemoryConfig.WindowConfig(
|
||||
enabled=False
|
||||
)
|
||||
)
|
||||
|
||||
prompt_messages = [
|
||||
ChatModelMessage(
|
||||
text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
|
||||
role=PromptMessageRole.SYSTEM
|
||||
),
|
||||
ChatModelMessage(
|
||||
text="Hi.",
|
||||
role=PromptMessageRole.USER
|
||||
),
|
||||
ChatModelMessage(
|
||||
text="Hello!",
|
||||
role=PromptMessageRole.ASSISTANT
|
||||
)
|
||||
]
|
||||
|
||||
inputs = {
|
||||
"name": "John"
|
||||
}
|
||||
|
||||
context = "I am superman."
|
||||
|
||||
return model_config_mock, prompt_template_entity, inputs, context
|
||||
return model_config_mock, memory_config, prompt_messages, inputs, context
|
||||
|
||||
Reference in New Issue
Block a user