mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 08:58:09 +08:00
restore completion app
This commit is contained in:
@ -22,8 +22,9 @@ from core.model_runtime.entities.message_entities import AssistantPromptMessage,
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.simple_prompt_transform import SimplePromptTransform
|
||||
from models.model import App, Message, MessageAnnotation
|
||||
from models.model import App, Message, MessageAnnotation, AppMode
|
||||
|
||||
|
||||
class AppRunner:
|
||||
@ -140,11 +141,11 @@ class AppRunner:
|
||||
:param memory: memory
|
||||
:return:
|
||||
"""
|
||||
prompt_transform = SimplePromptTransform()
|
||||
|
||||
# get prompt without memory and context
|
||||
if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
|
||||
prompt_transform = SimplePromptTransform()
|
||||
prompt_messages, stop = prompt_transform.get_prompt(
|
||||
app_mode=AppMode.value_of(app_record.mode),
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
inputs=inputs,
|
||||
query=query if query else '',
|
||||
@ -154,7 +155,17 @@ class AppRunner:
|
||||
model_config=model_config
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Advanced prompt is not supported yet.")
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
inputs=inputs,
|
||||
query=query if query else '',
|
||||
files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config
|
||||
)
|
||||
stop = model_config.stop
|
||||
|
||||
return prompt_messages, stop
|
||||
|
||||
|
||||
@ -11,10 +11,9 @@ class PromptTransform:
|
||||
def _append_chat_histories(self, memory: TokenBufferMemory,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_config: ModelConfigEntity) -> list[PromptMessage]:
|
||||
if memory:
|
||||
rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
|
||||
histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
|
||||
prompt_messages.extend(histories)
|
||||
rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
|
||||
histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
|
||||
prompt_messages.extend(histories)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
|
||||
@ -47,6 +47,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
"""
|
||||
|
||||
def get_prompt(self,
|
||||
app_mode: AppMode,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict,
|
||||
query: str,
|
||||
@ -58,6 +59,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
model_mode = ModelMode.value_of(model_config.mode)
|
||||
if model_mode == ModelMode.CHAT:
|
||||
prompt_messages, stops = self._get_chat_model_prompt_messages(
|
||||
app_mode=app_mode,
|
||||
pre_prompt=prompt_template_entity.simple_prompt_template,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
@ -68,6 +70,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
)
|
||||
else:
|
||||
prompt_messages, stops = self._get_completion_model_prompt_messages(
|
||||
app_mode=app_mode,
|
||||
pre_prompt=prompt_template_entity.simple_prompt_template,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
@ -154,7 +157,8 @@ class SimplePromptTransform(PromptTransform):
|
||||
"prompt_rules": prompt_rules
|
||||
}
|
||||
|
||||
def _get_chat_model_prompt_messages(self, pre_prompt: str,
|
||||
def _get_chat_model_prompt_messages(self, app_mode: AppMode,
|
||||
pre_prompt: str,
|
||||
inputs: dict,
|
||||
query: str,
|
||||
context: Optional[str],
|
||||
@ -166,7 +170,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
|
||||
# get prompt
|
||||
prompt, _ = self.get_prompt_str_and_rules(
|
||||
app_mode=AppMode.CHAT,
|
||||
app_mode=app_mode,
|
||||
model_config=model_config,
|
||||
pre_prompt=pre_prompt,
|
||||
inputs=inputs,
|
||||
@ -175,19 +179,25 @@ class SimplePromptTransform(PromptTransform):
|
||||
)
|
||||
|
||||
if prompt:
|
||||
prompt_messages.append(SystemPromptMessage(content=prompt))
|
||||
if query:
|
||||
prompt_messages.append(SystemPromptMessage(content=prompt))
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=prompt))
|
||||
|
||||
prompt_messages = self._append_chat_histories(
|
||||
memory=memory,
|
||||
prompt_messages=prompt_messages,
|
||||
model_config=model_config
|
||||
)
|
||||
if memory:
|
||||
prompt_messages = self._append_chat_histories(
|
||||
memory=memory,
|
||||
prompt_messages=prompt_messages,
|
||||
model_config=model_config
|
||||
)
|
||||
|
||||
prompt_messages.append(self.get_last_user_message(query, files))
|
||||
if query:
|
||||
prompt_messages.append(self.get_last_user_message(query, files))
|
||||
|
||||
return prompt_messages, None
|
||||
|
||||
def _get_completion_model_prompt_messages(self, pre_prompt: str,
|
||||
def _get_completion_model_prompt_messages(self, app_mode: AppMode,
|
||||
pre_prompt: str,
|
||||
inputs: dict,
|
||||
query: str,
|
||||
context: Optional[str],
|
||||
@ -197,7 +207,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
# get prompt
|
||||
prompt, prompt_rules = self.get_prompt_str_and_rules(
|
||||
app_mode=AppMode.CHAT,
|
||||
app_mode=app_mode,
|
||||
model_config=model_config,
|
||||
pre_prompt=pre_prompt,
|
||||
inputs=inputs,
|
||||
@ -220,7 +230,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
|
||||
# get prompt
|
||||
prompt, prompt_rules = self.get_prompt_str_and_rules(
|
||||
app_mode=AppMode.CHAT,
|
||||
app_mode=app_mode,
|
||||
model_config=model_config,
|
||||
pre_prompt=pre_prompt,
|
||||
inputs=inputs,
|
||||
@ -289,13 +299,13 @@ class SimplePromptTransform(PromptTransform):
|
||||
is_baichuan = True
|
||||
|
||||
if is_baichuan:
|
||||
if app_mode == AppMode.WORKFLOW:
|
||||
if app_mode == AppMode.COMPLETION:
|
||||
return 'baichuan_completion'
|
||||
else:
|
||||
return 'baichuan_chat'
|
||||
|
||||
# common
|
||||
if app_mode == AppMode.WORKFLOW:
|
||||
if app_mode == AppMode.COMPLETION:
|
||||
return 'common_completion'
|
||||
else:
|
||||
return 'common_chat'
|
||||
|
||||
Reference in New Issue
Block a user