mirror of https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
fix bugs and add unit tests

0    api/tests/unit_tests/core/__init__.py    Normal file
0    api/tests/unit_tests/core/prompt/__init__.py    Normal file
47   api/tests/unit_tests/core/prompt/test_prompt_transform.py    Normal file

@@ -0,0 +1,47 @@
from unittest.mock import MagicMock

from core.entities.application_entities import ModelConfigEntity
from core.entities.provider_configuration import ProviderModelBundle
from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey, AIModelEntity, ParameterRule
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_transform import PromptTransform


def test__calculate_rest_token():
    # Mock a model schema with a 62-token context window and a max_tokens parameter rule.
    model_schema_mock = MagicMock(spec=AIModelEntity)
    parameter_rule_mock = MagicMock(spec=ParameterRule)
    parameter_rule_mock.name = 'max_tokens'
    model_schema_mock.parameter_rules = [
        parameter_rule_mock
    ]
    model_schema_mock.model_properties = {
        ModelPropertyKey.CONTEXT_SIZE: 62
    }

    # Mock the LLM so the prompt messages are counted as 6 tokens.
    large_language_model_mock = MagicMock(spec=LargeLanguageModel)
    large_language_model_mock.get_num_tokens.return_value = 6

    provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
    provider_model_bundle_mock.model_type_instance = large_language_model_mock

    model_config_mock = MagicMock(spec=ModelConfigEntity)
    model_config_mock.model = 'gpt-4'
    model_config_mock.credentials = {}
    model_config_mock.parameters = {
        'max_tokens': 50
    }
    model_config_mock.model_schema = model_schema_mock
    model_config_mock.provider_model_bundle = provider_model_bundle_mock

    prompt_transform = PromptTransform()

    prompt_messages = [UserPromptMessage(content="Hello, how are you?")]
    rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock)

    # Validate based on the mock configuration and expected logic
    expected_rest_tokens = (model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE]
                            - model_config_mock.parameters['max_tokens']
                            - large_language_model_mock.get_num_tokens.return_value)
    assert rest_tokens == expected_rest_tokens
    assert rest_tokens == 6
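
Given the mocks above, _calculate_rest_token is expected to return the model's context window minus the tokens reserved for the completion minus the tokens already consumed by the prompt. A minimal standalone sketch of that arithmetic, with the three quantities as assumed inputs (the real method derives them from the model schema, the max_tokens parameter, and the model's tokenizer):

def rest_tokens(context_size: int, max_tokens: int, prompt_tokens: int) -> int:
    # remaining budget = context window - reserved completion tokens - prompt tokens
    return context_size - max_tokens - prompt_tokens

# With the test's mock values: 62 - 50 - 6 == 6, matching the final assertion.
assert rest_tokens(62, 50, 6) == 6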

216  api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py    Normal file

@@ -0,0 +1,216 @@
from unittest.mock import MagicMock

from core.entities.application_entities import ModelConfigEntity
from core.prompt.simple_prompt_transform import SimplePromptTransform
from models.model import AppMode
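
# Naming note: the with_* suffixes below appear to encode which sections the
# template is built from: p = pre_prompt, c = context, q = query, m = memory.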


def test_get_common_chat_app_prompt_template_with_pcqm():
    prompt_transform = SimplePromptTransform()
    pre_prompt = "You are a helpful assistant."
    prompt_template = prompt_transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider="openai",
        model="gpt-4",
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=True,
    )
    prompt_rules = prompt_template['prompt_rules']
    assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
                                                           + pre_prompt + '\n'
                                                           + prompt_rules['histories_prompt']
                                                           + prompt_rules['query_prompt'])
    assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#']
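
# The assembly order asserted here (context, then pre_prompt + '\n', then
# histories, then query) also holds for the baichuan variant below; only the
# provider-specific rule strings differ.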


def test_get_baichuan_chat_app_prompt_template_with_pcqm():
    prompt_transform = SimplePromptTransform()
    pre_prompt = "You are a helpful assistant."
    prompt_template = prompt_transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider="baichuan",
        model="Baichuan2-53B",
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=True,
    )
    prompt_rules = prompt_template['prompt_rules']
    assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
                                                           + pre_prompt + '\n'
                                                           + prompt_rules['histories_prompt']
                                                           + prompt_rules['query_prompt'])
    assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#']


def test_get_common_completion_app_prompt_template_with_pcq():
    prompt_transform = SimplePromptTransform()
    pre_prompt = "You are a helpful assistant."
    prompt_template = prompt_transform.get_prompt_template(
        app_mode=AppMode.WORKFLOW,
        provider="openai",
        model="gpt-4",
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=False,
    )
    prompt_rules = prompt_template['prompt_rules']
    assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
                                                           + pre_prompt + '\n'
                                                           + prompt_rules['query_prompt'])
    assert prompt_template['special_variable_keys'] == ['#context#', '#query#']


def test_get_baichuan_completion_app_prompt_template_with_pcq():
    prompt_transform = SimplePromptTransform()
    pre_prompt = "You are a helpful assistant."
    prompt_template = prompt_transform.get_prompt_template(
        app_mode=AppMode.WORKFLOW,
        provider="baichuan",
        model="Baichuan2-53B",
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=False,
    )
    prompt_rules = prompt_template['prompt_rules']
    assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
                                                           + pre_prompt + '\n'
                                                           + prompt_rules['query_prompt'])
    assert prompt_template['special_variable_keys'] == ['#context#', '#query#']


def test_get_common_chat_app_prompt_template_with_q():
    prompt_transform = SimplePromptTransform()
    pre_prompt = ""
    prompt_template = prompt_transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider="openai",
        model="gpt-4",
        pre_prompt=pre_prompt,
        has_context=False,
        query_in_prompt=True,
        with_memory_prompt=False,
    )
    prompt_rules = prompt_template['prompt_rules']
    assert prompt_template['prompt_template'].template == prompt_rules['query_prompt']
    assert prompt_template['special_variable_keys'] == ['#query#']


def test_get_common_chat_app_prompt_template_with_cq():
    prompt_transform = SimplePromptTransform()
    pre_prompt = ""
    prompt_template = prompt_transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider="openai",
        model="gpt-4",
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=False,
    )
    prompt_rules = prompt_template['prompt_rules']
    assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
                                                           + prompt_rules['query_prompt'])
    assert prompt_template['special_variable_keys'] == ['#context#', '#query#']


def test_get_common_chat_app_prompt_template_with_p():
    prompt_transform = SimplePromptTransform()
    pre_prompt = "you are {{name}}"
    prompt_template = prompt_transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider="openai",
        model="gpt-4",
        pre_prompt=pre_prompt,
        has_context=False,
        query_in_prompt=False,
        with_memory_prompt=False,
    )
    assert prompt_template['prompt_template'].template == pre_prompt + '\n'
    assert prompt_template['custom_variable_keys'] == ['name']
    assert prompt_template['special_variable_keys'] == []
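
# With only a pre_prompt, double-brace placeholders such as {{name}} surface in
# custom_variable_keys, while the '#...#' special keys stay empty.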


def test__get_chat_model_prompt_messages():
    model_config_mock = MagicMock(spec=ModelConfigEntity)
    model_config_mock.provider = 'openai'
    model_config_mock.model = 'gpt-4'

    prompt_transform = SimplePromptTransform()
    pre_prompt = "You are a helpful assistant {{name}}."
    inputs = {
        "name": "John"
    }
    context = "yes or no."
    query = "How are you?"
    prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages(
        pre_prompt=pre_prompt,
        inputs=inputs,
        query=query,
        files=[],
        context=context,
        memory=None,
        model_config=model_config_mock
    )

    prompt_template = prompt_transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider=model_config_mock.provider,
        model=model_config_mock.model,
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=False,
        with_memory_prompt=False,
    )

    full_inputs = {**inputs, '#context#': context}
    real_system_prompt = prompt_template['prompt_template'].format(full_inputs)

    assert len(prompt_messages) == 2
    assert prompt_messages[0].content == real_system_prompt
    assert prompt_messages[1].content == query
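
# The chat path above yields two messages (system prompt + user query); the
# completion path below flattens everything into a single prompt message and
# also returns the stop words defined by the prompt rules.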


def test__get_completion_model_prompt_messages():
    model_config_mock = MagicMock(spec=ModelConfigEntity)
    model_config_mock.provider = 'openai'
    model_config_mock.model = 'gpt-3.5-turbo-instruct'

    prompt_transform = SimplePromptTransform()
    pre_prompt = "You are a helpful assistant {{name}}."
    inputs = {
        "name": "John"
    }
    context = "yes or no."
    query = "How are you?"
    prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages(
        pre_prompt=pre_prompt,
        inputs=inputs,
        query=query,
        files=[],
        context=context,
        memory=None,
        model_config=model_config_mock
    )

    prompt_template = prompt_transform.get_prompt_template(
        app_mode=AppMode.CHAT,
        provider=model_config_mock.provider,
        model=model_config_mock.model,
        pre_prompt=pre_prompt,
        has_context=True,
        query_in_prompt=True,
        with_memory_prompt=False,
    )

    full_inputs = {**inputs, '#context#': context, '#query#': query}
    real_prompt = prompt_template['prompt_template'].format(full_inputs)

    assert len(prompt_messages) == 1
    assert stops == prompt_template['prompt_rules'].get('stops')
    assert prompt_messages[0].content == real_prompt
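
Read together, the template assertions in this file imply a simple concatenation scheme: each enabled section is appended in a fixed order. A minimal sketch of that scheme, assuming a prompt_rules dict with the keys used above (the real SimplePromptTransform resolves these rules per provider and model; this sketch is an illustration, not its actual implementation):

def build_template(prompt_rules: dict, pre_prompt: str, has_context: bool,
                   query_in_prompt: bool, with_memory_prompt: bool) -> str:
    # Append each enabled section in the order the tests assert:
    # context, then pre_prompt, then histories, then query.
    template = ''
    if has_context:
        template += prompt_rules['context_prompt']
    if pre_prompt:
        template += pre_prompt + '\n'
    if with_memory_prompt:
        template += prompt_rules['histories_prompt']
    if query_in_prompt:
        template += prompt_rules['query_prompt']
    return template

With has_context=True, query_in_prompt=True and with_memory_prompt=False this reduces to context_prompt + pre_prompt + '\n' + query_prompt, which is exactly what the pcq tests assert.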