restore completion app

This commit is contained in:
takatost
2024-02-25 21:30:36 +08:00
parent 9820dcb201
commit 55c31eec31
14 changed files with 224 additions and 30 deletions

View File

@@ -22,8 +22,9 @@ from core.model_runtime.entities.message_entities import AssistantPromptMessage,
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.simple_prompt_transform import SimplePromptTransform
from models.model import App, Message, MessageAnnotation
from models.model import App, Message, MessageAnnotation, AppMode
class AppRunner:
@@ -140,11 +141,11 @@ class AppRunner:
:param memory: memory
:return:
"""
prompt_transform = SimplePromptTransform()
# get prompt without memory and context
if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
prompt_transform = SimplePromptTransform()
prompt_messages, stop = prompt_transform.get_prompt(
app_mode=AppMode.value_of(app_record.mode),
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=query if query else '',
@@ -154,7 +155,17 @@
model_config=model_config
)
else:
raise NotImplementedError("Advanced prompt is not supported yet.")
prompt_transform = AdvancedPromptTransform()
prompt_messages = prompt_transform.get_prompt(
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=query if query else '',
files=files,
context=context,
memory=memory,
model_config=model_config
)
stop = model_config.stop
return prompt_messages, stop

View File

@@ -11,10 +11,9 @@ class PromptTransform:
def _append_chat_histories(self, memory: TokenBufferMemory,
prompt_messages: list[PromptMessage],
model_config: ModelConfigEntity) -> list[PromptMessage]:
if memory:
rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
prompt_messages.extend(histories)
rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
prompt_messages.extend(histories)
return prompt_messages

View File

@@ -47,6 +47,7 @@ class SimplePromptTransform(PromptTransform):
"""
def get_prompt(self,
app_mode: AppMode,
prompt_template_entity: PromptTemplateEntity,
inputs: dict,
query: str,
@@ -58,6 +59,7 @@ class SimplePromptTransform(PromptTransform):
model_mode = ModelMode.value_of(model_config.mode)
if model_mode == ModelMode.CHAT:
prompt_messages, stops = self._get_chat_model_prompt_messages(
app_mode=app_mode,
pre_prompt=prompt_template_entity.simple_prompt_template,
inputs=inputs,
query=query,
@@ -68,6 +70,7 @@ class SimplePromptTransform(PromptTransform):
)
else:
prompt_messages, stops = self._get_completion_model_prompt_messages(
app_mode=app_mode,
pre_prompt=prompt_template_entity.simple_prompt_template,
inputs=inputs,
query=query,
@@ -154,7 +157,8 @@ class SimplePromptTransform(PromptTransform):
"prompt_rules": prompt_rules
}
def _get_chat_model_prompt_messages(self, pre_prompt: str,
def _get_chat_model_prompt_messages(self, app_mode: AppMode,
pre_prompt: str,
inputs: dict,
query: str,
context: Optional[str],
@@ -166,7 +170,7 @@ class SimplePromptTransform(PromptTransform):
# get prompt
prompt, _ = self.get_prompt_str_and_rules(
app_mode=AppMode.CHAT,
app_mode=app_mode,
model_config=model_config,
pre_prompt=pre_prompt,
inputs=inputs,
@@ -175,19 +179,25 @@
)
if prompt:
prompt_messages.append(SystemPromptMessage(content=prompt))
if query:
prompt_messages.append(SystemPromptMessage(content=prompt))
else:
prompt_messages.append(UserPromptMessage(content=prompt))
prompt_messages = self._append_chat_histories(
memory=memory,
prompt_messages=prompt_messages,
model_config=model_config
)
if memory:
prompt_messages = self._append_chat_histories(
memory=memory,
prompt_messages=prompt_messages,
model_config=model_config
)
prompt_messages.append(self.get_last_user_message(query, files))
if query:
prompt_messages.append(self.get_last_user_message(query, files))
return prompt_messages, None
def _get_completion_model_prompt_messages(self, pre_prompt: str,
def _get_completion_model_prompt_messages(self, app_mode: AppMode,
pre_prompt: str,
inputs: dict,
query: str,
context: Optional[str],
@@ -197,7 +207,7 @@
-> tuple[list[PromptMessage], Optional[list[str]]]:
# get prompt
prompt, prompt_rules = self.get_prompt_str_and_rules(
app_mode=AppMode.CHAT,
app_mode=app_mode,
model_config=model_config,
pre_prompt=pre_prompt,
inputs=inputs,
@@ -220,7 +230,7 @@
# get prompt
prompt, prompt_rules = self.get_prompt_str_and_rules(
app_mode=AppMode.CHAT,
app_mode=app_mode,
model_config=model_config,
pre_prompt=pre_prompt,
inputs=inputs,
@@ -289,13 +299,13 @@
is_baichuan = True
if is_baichuan:
if app_mode == AppMode.WORKFLOW:
if app_mode == AppMode.COMPLETION:
return 'baichuan_completion'
else:
return 'baichuan_chat'
# common
if app_mode == AppMode.WORKFLOW:
if app_mode == AppMode.COMPLETION:
return 'common_completion'
else:
return 'common_chat'