mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 08:58:09 +08:00
feat(datasource): change datasource result type to event-stream
This commit is contained in:
@ -14,7 +14,7 @@ from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.enums import CreatedByRole
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import MessageFile, UploadFile
|
||||
from models.tools import ToolFile
|
||||
|
||||
@ -86,7 +86,7 @@ class DatasourceFileManager:
|
||||
size=len(file_binary),
|
||||
extension=extension,
|
||||
mime_type=mimetype,
|
||||
created_by_role=CreatedByRole.ACCOUNT,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=user_id,
|
||||
used=False,
|
||||
hash=hashlib.sha3_256(file_binary).hexdigest(),
|
||||
@ -133,7 +133,7 @@ class DatasourceFileManager:
|
||||
size=len(blob),
|
||||
extension=extension,
|
||||
mime_type=mimetype,
|
||||
created_by_role=CreatedByRole.ACCOUNT,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=user_id,
|
||||
used=False,
|
||||
hash=hashlib.sha3_256(blob).hexdigest(),
|
||||
|
||||
@ -298,28 +298,3 @@ class WebsiteCrawlMessage(BaseModel):
|
||||
class DatasourceMessage(ToolInvokeMessage):
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceInvokeMessage(ToolInvokeMessage):
|
||||
"""
|
||||
Datasource Invoke Message.
|
||||
"""
|
||||
|
||||
class WebsiteCrawlMessage(BaseModel):
|
||||
"""
|
||||
Website crawl message
|
||||
"""
|
||||
|
||||
job_id: str = Field(..., description="The job id")
|
||||
status: str = Field(..., description="The status of the job")
|
||||
web_info_list: Optional[list[WebSiteInfoDetail]] = []
|
||||
|
||||
class OnlineDocumentMessage(BaseModel):
|
||||
"""
|
||||
Online document message
|
||||
"""
|
||||
|
||||
workspace_id: str = Field(..., description="The workspace id")
|
||||
workspace_name: str = Field(..., description="The workspace name")
|
||||
workspace_icon: str = Field(..., description="The workspace icon")
|
||||
total: int = Field(..., description="The total number of documents")
|
||||
pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")
|
||||
|
||||
@ -5,7 +5,7 @@ from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceInvokeMessage,
|
||||
DatasourceMessage,
|
||||
DatasourceProviderType,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
OnlineDocumentPagesMessage,
|
||||
@ -33,7 +33,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def _get_online_document_pages(
|
||||
def get_online_document_pages(
|
||||
self,
|
||||
user_id: str,
|
||||
datasource_parameters: Mapping[str, Any],
|
||||
@ -51,12 +51,12 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def _get_online_document_page_content(
|
||||
def get_online_document_page_content(
|
||||
self,
|
||||
user_id: str,
|
||||
datasource_parameters: GetOnlineDocumentPageContentRequest,
|
||||
provider_type: str,
|
||||
) -> Generator[DatasourceInvokeMessage, None, None]:
|
||||
) -> Generator[DatasourceMessage, None, None]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
return manager.get_online_document_page_content(
|
||||
|
||||
@ -4,7 +4,7 @@ from mimetypes import guess_extension
|
||||
from typing import Optional
|
||||
|
||||
from core.datasource.datasource_file_manager import DatasourceFileManager
|
||||
from core.datasource.entities.datasource_entities import DatasourceInvokeMessage
|
||||
from core.datasource.entities.datasource_entities import DatasourceMessage
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -14,23 +14,23 @@ class DatasourceFileMessageTransformer:
|
||||
@classmethod
|
||||
def transform_datasource_invoke_messages(
|
||||
cls,
|
||||
messages: Generator[DatasourceInvokeMessage, None, None],
|
||||
messages: Generator[DatasourceMessage, None, None],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> Generator[DatasourceInvokeMessage, None, None]:
|
||||
) -> Generator[DatasourceMessage, None, None]:
|
||||
"""
|
||||
Transform datasource message and handle file download
|
||||
"""
|
||||
for message in messages:
|
||||
if message.type in {DatasourceInvokeMessage.MessageType.TEXT, DatasourceInvokeMessage.MessageType.LINK}:
|
||||
if message.type in {DatasourceMessage.MessageType.TEXT, DatasourceMessage.MessageType.LINK}:
|
||||
yield message
|
||||
elif message.type == DatasourceInvokeMessage.MessageType.IMAGE and isinstance(
|
||||
message.message, DatasourceInvokeMessage.TextMessage
|
||||
elif message.type == DatasourceMessage.MessageType.IMAGE and isinstance(
|
||||
message.message, DatasourceMessage.TextMessage
|
||||
):
|
||||
# try to download image
|
||||
try:
|
||||
assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
||||
file = DatasourceFileManager.create_file_by_url(
|
||||
user_id=user_id,
|
||||
@ -41,20 +41,20 @@ class DatasourceFileMessageTransformer:
|
||||
|
||||
url = f"/files/datasources/{file.id}{guess_extension(file.mime_type) or '.png'}"
|
||||
|
||||
yield DatasourceInvokeMessage(
|
||||
type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceInvokeMessage.TextMessage(text=url),
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
except Exception as e:
|
||||
yield DatasourceInvokeMessage(
|
||||
type=DatasourceInvokeMessage.MessageType.TEXT,
|
||||
message=DatasourceInvokeMessage.TextMessage(
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.TEXT,
|
||||
message=DatasourceMessage.TextMessage(
|
||||
text=f"Failed to download image: {message.message.text}: {e}"
|
||||
),
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
elif message.type == DatasourceInvokeMessage.MessageType.BLOB:
|
||||
elif message.type == DatasourceMessage.MessageType.BLOB:
|
||||
# get mime type and save blob to storage
|
||||
meta = message.meta or {}
|
||||
|
||||
@ -63,7 +63,7 @@ class DatasourceFileMessageTransformer:
|
||||
filename = meta.get("file_name", None)
|
||||
# if message is str, encode it to bytes
|
||||
|
||||
if not isinstance(message.message, DatasourceInvokeMessage.BlobMessage):
|
||||
if not isinstance(message.message, DatasourceMessage.BlobMessage):
|
||||
raise ValueError("unexpected message type")
|
||||
|
||||
# FIXME: should do a type check here.
|
||||
@ -81,18 +81,18 @@ class DatasourceFileMessageTransformer:
|
||||
|
||||
# check if file is image
|
||||
if "image" in mimetype:
|
||||
yield DatasourceInvokeMessage(
|
||||
type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceInvokeMessage.TextMessage(text=url),
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
yield DatasourceInvokeMessage(
|
||||
type=DatasourceInvokeMessage.MessageType.BINARY_LINK,
|
||||
message=DatasourceInvokeMessage.TextMessage(text=url),
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.BINARY_LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
elif message.type == DatasourceInvokeMessage.MessageType.FILE:
|
||||
elif message.type == DatasourceMessage.MessageType.FILE:
|
||||
meta = message.meta or {}
|
||||
file = meta.get("file", None)
|
||||
if isinstance(file, File):
|
||||
@ -100,15 +100,15 @@ class DatasourceFileMessageTransformer:
|
||||
assert file.related_id is not None
|
||||
url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension)
|
||||
if file.type == FileType.IMAGE:
|
||||
yield DatasourceInvokeMessage(
|
||||
type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceInvokeMessage.TextMessage(text=url),
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
yield DatasourceInvokeMessage(
|
||||
type=DatasourceInvokeMessage.MessageType.LINK,
|
||||
message=DatasourceInvokeMessage.TextMessage(text=url),
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
|
||||
@ -5,7 +5,6 @@ from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceInvokeMessage,
|
||||
DatasourceProviderType,
|
||||
WebsiteCrawlMessage,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user