fix bugs and add unit tests

This commit is contained in:
takatost
2024-02-22 15:15:42 +08:00
parent 297b33aa41
commit a44d3c3eda
11 changed files with 295 additions and 21 deletions

View File

@ -45,6 +45,7 @@ class SimplePromptTransform(PromptTransform):
"""
Simple Prompt Transform for Chatbot App Basic Mode.
"""
def get_prompt(self,
prompt_template_entity: PromptTemplateEntity,
inputs: dict,
@ -154,12 +155,12 @@ class SimplePromptTransform(PromptTransform):
}
def _get_chat_model_prompt_messages(self, pre_prompt: str,
inputs: dict,
query: str,
context: Optional[str],
files: list[FileObj],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigEntity) \
inputs: dict,
query: str,
context: Optional[str],
files: list[FileObj],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigEntity) \
-> tuple[list[PromptMessage], Optional[list[str]]]:
prompt_messages = []
@ -169,7 +170,7 @@ class SimplePromptTransform(PromptTransform):
model_config=model_config,
pre_prompt=pre_prompt,
inputs=inputs,
query=query,
query=None,
context=context
)
@ -187,12 +188,12 @@ class SimplePromptTransform(PromptTransform):
return prompt_messages, None
def _get_completion_model_prompt_messages(self, pre_prompt: str,
inputs: dict,
query: str,
context: Optional[str],
files: list[FileObj],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigEntity) \
inputs: dict,
query: str,
context: Optional[str],
files: list[FileObj],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigEntity) \
-> tuple[list[PromptMessage], Optional[list[str]]]:
# get prompt
prompt, prompt_rules = self.get_prompt_str_and_rules(
@ -259,7 +260,7 @@ class SimplePromptTransform(PromptTransform):
provider=provider,
model=model
)
# Check if the prompt file is already loaded
if prompt_file_name in prompt_file_contents:
return prompt_file_contents[prompt_file_name]
@ -267,14 +268,16 @@ class SimplePromptTransform(PromptTransform):
# Get the absolute path of the subdirectory
prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'generate_prompts')
json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json')
# Open the JSON file and read its content
with open(json_file_path, encoding='utf-8') as json_file:
content = json.load(json_file)
# Store the content of the prompt file
prompt_file_contents[prompt_file_name] = content
return content
def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
# baichuan
is_baichuan = False