feat: [backend] vision support (#1510)

Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
This commit is contained in:
takatost
2023-11-13 22:05:46 +08:00
committed by GitHub
parent d0e1ea8f06
commit 41d0a8b295
61 changed files with 1563 additions and 300 deletions

View File

@ -6,8 +6,9 @@ from core.callback_handler.entity.agent_loop import AgentLoop
from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.callback_handler.entity.llm_message import LLMMessage
from core.callback_handler.entity.chain_result import ChainResult
from core.file.file_obj import FileObj
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import to_prompt_messages, MessageType
from core.model_providers.models.entity.message import to_prompt_messages, MessageType, PromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import PromptTemplateParser
@ -16,13 +17,14 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DatasetQuery
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
MessageChain, DatasetRetrieverResource
MessageChain, DatasetRetrieverResource, MessageFile
class ConversationMessageTask:
def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
conversation: Optional[Conversation] = None, is_override: bool = False):
inputs: dict, query: str, files: List[FileObj], streaming: bool,
model_instance: BaseLLM, conversation: Optional[Conversation] = None, is_override: bool = False,
auto_generate_name: bool = True):
self.start_at = time.perf_counter()
self.task_id = task_id
@ -35,6 +37,7 @@ class ConversationMessageTask:
self.user = user
self.inputs = inputs
self.query = query
self.files = files
self.streaming = streaming
self.conversation = conversation
@ -45,6 +48,7 @@ class ConversationMessageTask:
self.message = None
self.retriever_resource = None
self.auto_generate_name = auto_generate_name
self.model_dict = self.app_model_config.model_dict
self.provider_name = self.model_dict.get('provider')
@ -100,7 +104,7 @@ class ConversationMessageTask:
model_id=self.model_name,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=self.mode,
name='',
name='New conversation',
inputs=self.inputs,
introduction=introduction,
system_instruction=system_instruction,
@ -142,6 +146,19 @@ class ConversationMessageTask:
db.session.add(self.message)
db.session.commit()
for file in self.files:
message_file = MessageFile(
message_id=self.message.id,
type=file.type.value,
transfer_method=file.transfer_method.value,
url=file.url,
upload_file_id=file.upload_file_id,
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
created_by=self.user.id
)
db.session.add(message_file)
db.session.commit()
def append_message_text(self, text: str):
if text is not None:
self._pub_handler.pub_text(text)
@ -176,7 +193,8 @@ class ConversationMessageTask:
message_was_created.send(
self.message,
conversation=self.conversation,
is_first_message=self.is_new_conversation
is_first_message=self.is_new_conversation,
auto_generate_name=self.auto_generate_name
)
if not by_stopped: