mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 18:08:07 +08:00
fix bugs and add unit tests
This commit is contained in:
@ -45,6 +45,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
"""
|
||||
Simple Prompt Transform for Chatbot App Basic Mode.
|
||||
"""
|
||||
|
||||
def get_prompt(self,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict,
|
||||
@ -154,12 +155,12 @@ class SimplePromptTransform(PromptTransform):
|
||||
}
|
||||
|
||||
def _get_chat_model_prompt_messages(self, pre_prompt: str,
|
||||
inputs: dict,
|
||||
query: str,
|
||||
context: Optional[str],
|
||||
files: list[FileObj],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigEntity) \
|
||||
inputs: dict,
|
||||
query: str,
|
||||
context: Optional[str],
|
||||
files: list[FileObj],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigEntity) \
|
||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
prompt_messages = []
|
||||
|
||||
@ -169,7 +170,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
model_config=model_config,
|
||||
pre_prompt=pre_prompt,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
query=None,
|
||||
context=context
|
||||
)
|
||||
|
||||
@ -187,12 +188,12 @@ class SimplePromptTransform(PromptTransform):
|
||||
return prompt_messages, None
|
||||
|
||||
def _get_completion_model_prompt_messages(self, pre_prompt: str,
|
||||
inputs: dict,
|
||||
query: str,
|
||||
context: Optional[str],
|
||||
files: list[FileObj],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigEntity) \
|
||||
inputs: dict,
|
||||
query: str,
|
||||
context: Optional[str],
|
||||
files: list[FileObj],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigEntity) \
|
||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
# get prompt
|
||||
prompt, prompt_rules = self.get_prompt_str_and_rules(
|
||||
@ -259,7 +260,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
provider=provider,
|
||||
model=model
|
||||
)
|
||||
|
||||
|
||||
# Check if the prompt file is already loaded
|
||||
if prompt_file_name in prompt_file_contents:
|
||||
return prompt_file_contents[prompt_file_name]
|
||||
@ -267,14 +268,16 @@ class SimplePromptTransform(PromptTransform):
|
||||
# Get the absolute path of the subdirectory
|
||||
prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'generate_prompts')
|
||||
json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json')
|
||||
|
||||
|
||||
# Open the JSON file and read its content
|
||||
with open(json_file_path, encoding='utf-8') as json_file:
|
||||
content = json.load(json_file)
|
||||
|
||||
|
||||
# Store the content of the prompt file
|
||||
prompt_file_contents[prompt_file_name] = content
|
||||
|
||||
return content
|
||||
|
||||
def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
|
||||
# baichuan
|
||||
is_baichuan = False
|
||||
|
||||
Reference in New Issue
Block a user