This commit is contained in:
takatost
2024-03-19 15:32:10 +08:00
parent 24ac4996c0
commit 133d52deb9
5 changed files with 39 additions and 23 deletions

View File

@ -2,6 +2,7 @@ import json
from typing import Optional, Union
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.llm_generator import LLMGenerator
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelManager
@ -18,6 +19,7 @@ from services.errors.message import (
MessageNotExistsError,
SuggestedQuestionsAfterAnswerDisabledError,
)
from services.workflow_service import WorkflowService
class MessageService:
@ -177,7 +179,7 @@ class MessageService:
@classmethod
def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Union[Account, EndUser]],
message_id: str) -> list[Message]:
message_id: str, invoke_from: InvokeFrom) -> list[Message]:
if not user:
raise ValueError('user cannot be None')
@ -201,8 +203,13 @@ class MessageService:
model_manager = ModelManager()
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
workflow = app_model.workflow
if app_model.mode == AppMode.ADVANCED_CHAT.value:
workflow_service = WorkflowService()
if invoke_from == InvokeFrom.DEBUGGER:
workflow = workflow_service.get_draft_workflow(app_model=app_model)
else:
workflow = workflow_service.get_published_workflow(app_model=app_model)
if workflow is None:
return []
@ -233,24 +240,17 @@ class MessageService:
app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
if app_model_config:
model_instance = model_manager.get_model_instance(
tenant_id=app_model.tenant_id,
provider=app_model_config.model_dict['provider'],
model_type=ModelType.LLM,
model=app_model_config.model_dict['name']
)
else:
model_instance = model_manager.get_default_model_instance(
tenant_id=app_model.tenant_id,
model_type=ModelType.LLM
)
suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict
if suggested_questions_after_answer.get("enabled", False) is False:
raise SuggestedQuestionsAfterAnswerDisabledError()
model_instance = model_manager.get_model_instance(
tenant_id=app_model.tenant_id,
provider=app_model_config.model_dict['provider'],
model_type=ModelType.LLM,
model=app_model_config.model_dict['name']
)
# get memory of conversation (read-only)
memory = TokenBufferMemory(
conversation=conversation,