Mirror of https://github.com/langgenius/dify.git (synced 2026-01-29 08:16:15 +08:00)

Compare commits: feat/node-... → 0.14.2 (76 commits)
SHA1 of the 76 commits in this comparison:

e88ea71aef e0f1410b48 c3c85276d1 af2888d394 d0dd8b7955 75bce2822e 425cc1ea85 2bf33c4dd2
dc19cd5d9d dfc25dbdd0 c4091c4c66 ef95b1268e e068bbec73 9cfd1c67b6 8978a6a3ff 4e3d732934
03548cdfbc 70dd69d533 453f324f54 c1aa55f3ea 4b1e13e982 4584eb3058 02a7ae15f9 74b1b60125
39df994ff9 d9875fe232 26c10b9931 750662eb08 6b49889041 03ddee3663 10caab1729 c6a72def88
21a31d7f8b 2c4df108e5 5db8addcc6 dd0e81d094 90f093eb67 a056a9d601 2ad2a402fb 3d07a94bd7
366857cd26 9578246bbb 9ee9e9c6de e22cc28114 a227af3664 599d410d99 5e37ab60d8 0b06235527
b8d42cdea7 455791b710 90323cd355 c07d9e96ce 810adb8a94 606aadb891 8f73670925 8c559d6231
786cb6859b de8800f41a 7a00798027 6ded06c6d9 f53741c5b9 2681bafb76 ac635c70cd ef7e47d162
4211b9abbd 0c0120ef27 dacd457478 7b03a0316d 52201d95b1 e2cde628bb 3335fa78fc 7abc7fa573
f6247fe67c 996a9135f6 3599751f93 2d186e1e76
@@ -399,6 +399,7 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
 WORKFLOW_MAX_EXECUTION_STEPS=500
 WORKFLOW_MAX_EXECUTION_TIME=1200
 WORKFLOW_CALL_MAX_DEPTH=5
+WORKFLOW_PARALLEL_DEPTH_LIMIT=3
 MAX_VARIABLE_SIZE=204800

 # App configuration
@@ -555,7 +555,8 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
    if language not in languages:
        language = "en-US"

-    name = name.strip()
+    # Validates name encoding for non-Latin characters.
+    name = name.strip().encode("utf-8").decode("utf-8") if name else None

    # generate random password
    new_password = secrets.token_urlsafe(16)
@@ -433,6 +433,11 @@ class WorkflowConfig(BaseSettings):
        default=5,
    )

+    WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field(
+        description="Maximum allowed depth for nested parallel executions",
+        default=3,
+    )
+
    MAX_VARIABLE_SIZE: PositiveInt = Field(
        description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
        default=200 * 1024,
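The new field binds the `WORKFLOW_PARALLEL_DEPTH_LIMIT=3` line added to the .env template above. A trimmed, standalone reproduction of how pydantic-settings picks that value up from the environment (only this one field; the real class declares many more):

import os

from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings

class WorkflowConfig(BaseSettings):
    # pydantic-settings reads WORKFLOW_PARALLEL_DEPTH_LIMIT from the environment
    # (or a .env file) and validates it as a positive integer.
    WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field(
        description="Maximum allowed depth for nested parallel executions",
        default=3,
    )

os.environ["WORKFLOW_PARALLEL_DEPTH_LIMIT"] = "5"
print(WorkflowConfig().WORKFLOW_PARALLEL_DEPTH_LIMIT)  # 5; falls back to 3 when unset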
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):

    CURRENT_VERSION: str = Field(
        description="Dify version",
-        default="0.14.1",
+        default="0.14.2",
    )

    COMMIT_SHA: str = Field(
@@ -4,3 +4,8 @@ from werkzeug.exceptions import HTTPException
 class FilenameNotExistsError(HTTPException):
     code = 400
     description = "The specified filename does not exist."
+
+
+class RemoteFileUploadError(HTTPException):
+    code = 400
+    description = "Error uploading remote file."
@@ -6,6 +6,7 @@ from flask_restful import Resource, marshal_with, reqparse
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

 import services
+from configs import dify_config
 from controllers.console import api
 from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
 from controllers.console.app.wraps import get_app_model

@@ -426,7 +427,21 @@ class ConvertToWorkflowApi(Resource):
         }


+class WorkflowConfigApi(Resource):
+    """Resource for workflow configuration."""
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    def get(self, app_model: App):
+        return {
+            "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
+        }
+
+
 api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
+api.add_resource(WorkflowConfigApi, "/apps/<uuid:app_id>/workflows/draft/config")
 api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
 api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
 api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
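For illustration, the new draft-config route can be queried like any other console endpoint. A minimal client sketch, assuming a local deployment; the base URL, app ID, bearer token, and the console API prefix below are placeholders and may differ per installation:

import requests

BASE = "http://localhost:5001/console/api"  # placeholder console API prefix
APP_ID = "<your-app-uuid>"                  # placeholder

resp = requests.get(
    f"{BASE}/apps/{APP_ID}/workflows/draft/config",
    headers={"Authorization": "Bearer <console-token>"},  # placeholder credential
)
resp.raise_for_status()
# The handler above returns the configured nesting limit for parallel branches.
print(resp.json()["parallel_depth_limit"])  # 3 unless overridden via env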
@@ -5,8 +5,7 @@ from typing import Optional, Union
 from controllers.console.app.error import AppNotFoundError
 from extensions.ext_database import db
 from libs.login import current_user
-from models import App
-from models.model import AppMode
+from models import App, AppMode


 def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):
@@ -76,7 +76,7 @@ class OAuthCallback(Resource):
        try:
            token = oauth_provider.get_access_token(code)
            user_info = oauth_provider.get_user_info(token)
-        except requests.exceptions.HTTPError as e:
+        except requests.exceptions.RequestException as e:
            logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}")
            return {"error": "OAuth process failed"}, 400
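The widened catch matters because `HTTPError` only fires on bad HTTP responses, while `RequestException` is the base class for all `requests` failures, including connection and timeout errors that never produce a response. A quick check of that hierarchy:

import requests

# RequestException is the root of the requests exception hierarchy.
print(issubclass(requests.exceptions.HTTPError, requests.exceptions.RequestException))        # True
print(issubclass(requests.exceptions.ConnectionError, requests.exceptions.RequestException))  # True
print(issubclass(requests.exceptions.Timeout, requests.exceptions.RequestException))          # True

Worth noting: for transport-level failures `e.response` is `None`, so the `e.response.text` interpolation in the log line above can itself raise; a guard such as `getattr(e.response, "text", e)` would be safer.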
@@ -1,12 +1,14 @@
 from flask_login import current_user
 from flask_restful import marshal_with, reqparse
 from flask_restful.inputs import int_range
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound

 from controllers.console import api
 from controllers.console.explore.error import NotChatAppError
 from controllers.console.explore.wraps import InstalledAppResource
 from core.app.entities.app_invoke_entities import InvokeFrom
+from extensions.ext_database import db
 from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
 from libs.helper import uuid_value
 from models.model import AppMode

@@ -34,14 +36,16 @@ class ConversationListApi(InstalledAppResource):
        pinned = True if args["pinned"] == "true" else False

        try:
-            return WebConversationService.pagination_by_last_id(
-                app_model=app_model,
-                user=current_user,
-                last_id=args["last_id"],
-                limit=args["limit"],
-                invoke_from=InvokeFrom.EXPLORE,
-                pinned=pinned,
-            )
+            with Session(db.engine) as session:
+                return WebConversationService.pagination_by_last_id(
+                    session=session,
+                    app_model=app_model,
+                    user=current_user,
+                    last_id=args["last_id"],
+                    limit=args["limit"],
+                    invoke_from=InvokeFrom.EXPLORE,
+                    pinned=pinned,
+                )
        except LastConversationNotExistsError:
            raise NotFound("Last Conversation Not Exists.")
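A recurring change in this release: service calls such as `pagination_by_last_id` now receive an explicit SQLAlchemy `Session` opened by the controller instead of relying on an implicit global session. A standalone sketch of the pattern, assuming nothing from Dify (an in-memory SQLite engine stands in for `db.engine`):

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("sqlite:///:memory:")  # stands in for extensions.ext_database db.engine

with Session(engine) as session:
    # Everything issued through `session` shares one connection and transaction;
    # the context manager closes it even if the block raises.
    value = session.execute(text("SELECT 1")).scalar_one()
    session.commit()  # needed only when the block performed writes
print(value)  # 1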
@@ -70,7 +70,7 @@ class MessageFeedbackApi(InstalledAppResource):
        args = parser.parse_args()

        try:
-            MessageService.create_feedback(app_model, message_id, current_user, args["rating"])
+            MessageService.create_feedback(app_model, message_id, current_user, args["rating"], args["content"])
        except services.errors.message.MessageNotExistsError:
            raise NotFound("Message Not Exists.")
@@ -7,6 +7,7 @@ from flask_restful import Resource, marshal_with, reqparse

 import services
 from controllers.common import helpers
+from controllers.common.errors import RemoteFileUploadError
 from core.file import helpers as file_helpers
 from core.helper import ssrf_proxy
 from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields

@@ -43,10 +44,14 @@ class RemoteFileUploadApi(Resource):

        url = args["url"]

-        resp = ssrf_proxy.head(url=url)
-        if resp.status_code != httpx.codes.OK:
-            resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
-        resp.raise_for_status()
+        try:
+            resp = ssrf_proxy.head(url=url)
+            if resp.status_code != httpx.codes.OK:
+                resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
+            if resp.status_code != httpx.codes.OK:
+                raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
+        except httpx.RequestError as e:
+            raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")

        file_info = helpers.guess_file_info_from_response(resp)
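Both failure modes now surface as the `RemoteFileUploadError` defined earlier: transport errors are caught as `httpx.RequestError`, and non-200 responses are rejected explicitly instead of relying on `raise_for_status()`. A standalone sketch of the HEAD-then-GET probe using plain httpx (Dify routes these calls through its `ssrf_proxy` wrapper):

import httpx

def probe(url: str) -> httpx.Response:
    try:
        # Cheap HEAD first; fall back to GET for servers that reject HEAD.
        resp = httpx.head(url)
        if resp.status_code != httpx.codes.OK:
            resp = httpx.get(url, timeout=3, follow_redirects=True)
        if resp.status_code != httpx.codes.OK:
            raise ValueError(f"Failed to fetch file from {url}: {resp.text}")
        return resp
    except httpx.RequestError as e:
        raise ValueError(f"Failed to fetch file from {url}: {e}")

print(probe("https://example.com").headers.get("content-type"))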
@@ -3,12 +3,14 @@ import io

 from flask import send_file
 from flask_login import current_user
 from flask_restful import Resource, reqparse
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden

 from configs import dify_config
 from controllers.console import api
 from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
 from core.model_runtime.utils.encoders import jsonable_encoder
+from extensions.ext_database import db
 from libs.helper import alphanumeric, uuid_value
 from libs.login import login_required
 from services.tools.api_tools_manage_service import ApiToolManageService

@@ -91,12 +93,16 @@ class ToolBuiltinProviderUpdateApi(Resource):

        args = parser.parse_args()

-        return BuiltinToolManageService.update_builtin_tool_provider(
-            user_id,
-            tenant_id,
-            provider,
-            args["credentials"],
-        )
+        with Session(db.engine) as session:
+            result = BuiltinToolManageService.update_builtin_tool_provider(
+                session=session,
+                user_id=user_id,
+                tenant_id=tenant_id,
+                provider_name=provider,
+                credentials=args["credentials"],
+            )
+            session.commit()
+        return result


 class ToolBuiltinProviderGetCredentialsApi(Resource):

@@ -104,13 +110,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
    @login_required
    @account_initialization_required
    def get(self, provider):
-        user_id = current_user.id
        tenant_id = current_user.current_tenant_id

        return BuiltinToolManageService.get_builtin_tool_provider_credentials(
-            user_id,
-            tenant_id,
-            provider,
+            tenant_id=tenant_id,
+            provider_name=provider,
        )
@@ -1,5 +1,6 @@
 from flask_restful import Resource, marshal_with, reqparse
 from flask_restful.inputs import int_range
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound

 import services

@@ -7,6 +8,7 @@ from controllers.service_api import api
 from controllers.service_api.app.error import NotChatAppError
 from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.app.entities.app_invoke_entities import InvokeFrom
+from extensions.ext_database import db
 from fields.conversation_fields import (
     conversation_delete_fields,
     conversation_infinite_scroll_pagination_fields,

@@ -39,14 +41,16 @@ class ConversationApi(Resource):
        args = parser.parse_args()

        try:
-            return ConversationService.pagination_by_last_id(
-                app_model=app_model,
-                user=end_user,
-                last_id=args["last_id"],
-                limit=args["limit"],
-                invoke_from=InvokeFrom.SERVICE_API,
-                sort_by=args["sort_by"],
-            )
+            with Session(db.engine) as session:
+                return ConversationService.pagination_by_last_id(
+                    session=session,
+                    app_model=app_model,
+                    user=end_user,
+                    last_id=args["last_id"],
+                    limit=args["limit"],
+                    invoke_from=InvokeFrom.SERVICE_API,
+                    sort_by=args["sort_by"],
+                )
        except services.errors.conversation.LastConversationNotExistsError:
            raise NotFound("Last Conversation Not Exists.")
@@ -104,10 +104,11 @@ class MessageFeedbackApi(Resource):

        parser = reqparse.RequestParser()
        parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
+        parser.add_argument("content", type=str, location="json")
        args = parser.parse_args()

        try:
-            MessageService.create_feedback(app_model, message_id, end_user, args["rating"])
+            MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"])
        except services.errors.message.MessageNotExistsError:
            raise NotFound("Message Not Exists.")
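Clients of the Service API can now attach free-form text to a rating. A hedged request sketch; the host, message ID, API key, and `user` identifier are placeholders, and the `/v1/messages/<id>/feedbacks` path shape follows Dify's published Service API layout:

import requests

resp = requests.post(
    "http://localhost/v1/messages/<message-id>/feedbacks",   # placeholders
    headers={"Authorization": "Bearer <service-api-key>"},
    json={
        "rating": "like",                             # "like", "dislike", or null
        "content": "Helpful, well-sourced answer.",   # new optional field in this release
        "user": "end-user-123",
    },
)
resp.raise_for_status()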
@@ -1,11 +1,13 @@
 from flask_restful import marshal_with, reqparse
 from flask_restful.inputs import int_range
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound

 from controllers.web import api
 from controllers.web.error import NotChatAppError
 from controllers.web.wraps import WebApiResource
 from core.app.entities.app_invoke_entities import InvokeFrom
+from extensions.ext_database import db
 from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
 from libs.helper import uuid_value
 from models.model import AppMode

@@ -40,15 +42,17 @@ class ConversationListApi(WebApiResource):
        pinned = True if args["pinned"] == "true" else False

        try:
-            return WebConversationService.pagination_by_last_id(
-                app_model=app_model,
-                user=end_user,
-                last_id=args["last_id"],
-                limit=args["limit"],
-                invoke_from=InvokeFrom.WEB_APP,
-                pinned=pinned,
-                sort_by=args["sort_by"],
-            )
+            with Session(db.engine) as session:
+                return WebConversationService.pagination_by_last_id(
+                    session=session,
+                    app_model=app_model,
+                    user=end_user,
+                    last_id=args["last_id"],
+                    limit=args["limit"],
+                    invoke_from=InvokeFrom.WEB_APP,
+                    pinned=pinned,
+                    sort_by=args["sort_by"],
+                )
        except LastConversationNotExistsError:
            raise NotFound("Last Conversation Not Exists.")
@@ -108,7 +108,7 @@ class MessageFeedbackApi(WebApiResource):
        args = parser.parse_args()

        try:
-            MessageService.create_feedback(app_model, message_id, end_user, args["rating"])
+            MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"])
        except services.errors.message.MessageNotExistsError:
            raise NotFound("Message Not Exists.")
@@ -5,6 +5,7 @@ from flask_restful import marshal_with, reqparse

 import services
 from controllers.common import helpers
+from controllers.common.errors import RemoteFileUploadError
 from controllers.web.wraps import WebApiResource
 from core.file import helpers as file_helpers
 from core.helper import ssrf_proxy

@@ -38,10 +39,14 @@ class RemoteFileUploadApi(WebApiResource):

        url = args["url"]

-        resp = ssrf_proxy.head(url=url)
-        if resp.status_code != httpx.codes.OK:
-            resp = ssrf_proxy.get(url=url, timeout=3)
-        resp.raise_for_status()
+        try:
+            resp = ssrf_proxy.head(url=url)
+            if resp.status_code != httpx.codes.OK:
+                resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
+            if resp.status_code != httpx.codes.OK:
+                raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
+        except httpx.RequestError as e:
+            raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")

        file_info = helpers.guess_file_info_from_response(resp)
@@ -4,14 +4,17 @@ import logging
 import queue
 import re
 import threading
+from collections.abc import Iterable

 from core.app.entities.queue_entities import (
+    MessageQueueMessage,
     QueueAgentMessageEvent,
     QueueLLMChunkEvent,
     QueueNodeSucceededEvent,
     QueueTextChunkEvent,
+    WorkflowQueueMessage,
 )
-from core.model_manager import ModelManager
+from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.model_entities import ModelType

@@ -21,7 +24,7 @@ class AudioTrunk:
        self.status = status


-def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
+def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str):
    if not text_content or text_content.isspace():
        return
    return model_instance.invoke_tts(

@@ -29,13 +32,19 @@ def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str):
    )


-def _process_future(future_queue, audio_queue):
+def _process_future(
+    future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None],
+    audio_queue: queue.Queue[AudioTrunk],
+):
    while True:
        try:
            future = future_queue.get()
            if future is None:
                break
-            for audio in future.result():
+            invoke_result = future.result()
+            if not invoke_result:
+                continue
+            for audio in invoke_result:
                audio_base64 = base64.b64encode(bytes(audio))
                audio_queue.put(AudioTrunk("responding", audio=audio_base64))
        except Exception as e:

@@ -49,8 +58,8 @@ class AppGeneratorTTSPublisher:
        self.logger = logging.getLogger(__name__)
        self.tenant_id = tenant_id
        self.msg_text = ""
-        self._audio_queue = queue.Queue()
-        self._msg_queue = queue.Queue()
+        self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue()
+        self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
        self.match = re.compile(r"[。.!?]")
        self.model_manager = ModelManager()
        self.model_instance = self.model_manager.get_default_model_instance(

@@ -66,14 +75,11 @@ class AppGeneratorTTSPublisher:
        self._runtime_thread = threading.Thread(target=self._runtime).start()
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)

-    def publish(self, message):
-        try:
-            self._msg_queue.put(message)
-        except Exception as e:
-            self.logger.warning(e)
+    def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):
+        self._msg_queue.put(message)

    def _runtime(self):
-        future_queue = queue.Queue()
+        future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None] = queue.Queue()
        threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
        while True:
            try:

@@ -110,7 +116,7 @@ class AppGeneratorTTSPublisher:
                break
        future_queue.put(None)

-    def check_and_get_audio(self) -> AudioTrunk | None:
+    def check_and_get_audio(self):
        try:
            if self._last_audio_event and self._last_audio_event.status == "finish":
                if self.executor:
|
||||
else:
|
||||
continue
|
||||
|
||||
raise Exception("Queue listening stopped unexpectedly.")
|
||||
raise ValueError("queue listening stopped unexpectedly.")
|
||||
|
||||
def _to_stream_response(
|
||||
self, generator: Generator[StreamResponse, None, None]
|
||||
@ -197,11 +197,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
stream_response=stream_response,
|
||||
)
|
||||
|
||||
def _listen_audio_msg(self, publisher, task_id: str):
|
||||
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
|
||||
if not publisher:
|
||||
return None
|
||||
audio_msg: AudioTrunk = publisher.check_and_get_audio()
|
||||
if audio_msg and audio_msg.status != "finish":
|
||||
audio_msg = publisher.check_and_get_audio()
|
||||
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
|
||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||
return None
|
||||
|
||||
@ -222,7 +222,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
|
||||
audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
|
||||
if audio_response:
|
||||
yield audio_response
|
||||
else:
|
||||
@ -291,9 +291,27 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
yield self._workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
elif isinstance(
|
||||
event,
|
||||
QueueNodeRetryEvent,
|
||||
):
|
||||
if not workflow_run:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
workflow_node_execution = self._handle_workflow_node_execution_retried(
|
||||
workflow_run=workflow_run, event=event
|
||||
)
|
||||
|
||||
response = self._workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueNodeStartedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception("Workflow run not initialized.")
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
|
||||
|
||||
@ -331,63 +349,48 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(
|
||||
event,
|
||||
QueueNodeRetryEvent,
|
||||
):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_retried(
|
||||
workflow_run=workflow_run, event=event
|
||||
)
|
||||
|
||||
response = self._workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueParallelBranchRunStartedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception("Workflow run not initialized.")
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
yield self._workflow_parallel_branch_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
|
||||
)
|
||||
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception("Workflow run not initialized.")
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
yield self._workflow_parallel_branch_finished_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
|
||||
)
|
||||
elif isinstance(event, QueueIterationStartEvent):
|
||||
if not workflow_run:
|
||||
raise Exception("Workflow run not initialized.")
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
yield self._workflow_iteration_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
|
||||
)
|
||||
elif isinstance(event, QueueIterationNextEvent):
|
||||
if not workflow_run:
|
||||
raise Exception("Workflow run not initialized.")
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
yield self._workflow_iteration_next_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
|
||||
)
|
||||
elif isinstance(event, QueueIterationCompletedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception("Workflow run not initialized.")
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
yield self._workflow_iteration_completed_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowSucceededEvent):
|
||||
if not workflow_run:
|
||||
raise Exception("Workflow run not initialized.")
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception("Graph runtime state not initialized.")
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
workflow_run=workflow_run,
|
||||
@ -406,10 +409,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
|
||||
if not workflow_run:
|
||||
raise Exception("Workflow run not initialized.")
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception("Graph runtime state not initialized.")
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
workflow_run = self._handle_workflow_run_partial_success(
|
||||
workflow_run=workflow_run,
|
||||
@ -429,10 +432,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
elif isinstance(event, QueueWorkflowFailedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception("Workflow run not initialized.")
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception("Graph runtime state not initialized.")
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
workflow_run=workflow_run,
|
||||
@ -511,7 +514,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
# only publish tts message at text chunk streaming
|
||||
if tts_publisher:
|
||||
tts_publisher.publish(message=queue_message)
|
||||
tts_publisher.publish(queue_message)
|
||||
|
||||
self._task_state.answer += delta_text
|
||||
yield self._message_to_stream_response(
|
||||
@ -522,7 +525,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
yield self._message_replace_to_stream_response(answer=event.text)
|
||||
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
|
||||
if not graph_runtime_state:
|
||||
raise Exception("Graph runtime state not initialized.")
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
|
||||
if output_moderation_answer:
|
||||
|
||||
@@ -1,7 +1,6 @@
 import queue
 import time
 from abc import abstractmethod
-from collections.abc import Generator
 from enum import Enum
 from typing import Any

@@ -11,9 +10,11 @@ from configs import dify_config
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import (
     AppQueueEvent,
+    MessageQueueMessage,
     QueueErrorEvent,
     QueuePingEvent,
     QueueStopEvent,
+    WorkflowQueueMessage,
 )
 from extensions.ext_redis import redis_client

@@ -37,11 +38,11 @@ class AppQueueManager:
            AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
        )

-        q = queue.Queue()
+        q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()

        self._q = q

-    def listen(self) -> Generator:
+    def listen(self):
        """
        Listen to queue
        :return:
@@ -155,7 +155,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
            else:
                continue

-        raise Exception("Queue listening stopped unexpectedly.")
+        raise ValueError("queue listening stopped unexpectedly.")

    def _to_stream_response(
        self, generator: Generator[StreamResponse, None, None]

@@ -171,11 +171,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa

        yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)

-    def _listen_audio_msg(self, publisher, task_id: str):
+    def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
        if not publisher:
            return None
-        audio_msg: AudioTrunk = publisher.check_and_get_audio()
-        if audio_msg and audio_msg.status != "finish":
+        audio_msg = publisher.check_and_get_audio()
+        if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
            return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
        return None

@@ -196,7 +196,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa

        for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
            while True:
-                audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
+                audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
                if audio_response:
                    yield audio_response
                else:

@@ -218,7 +218,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                    break
                else:
                    yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
-            except Exception as e:
+            except Exception:
                logger.exception(f"Fails to get audio trunk, task_id: {task_id}")
                break
        if tts_publisher:

@@ -254,9 +254,27 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                yield self._workflow_start_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                )
+            elif isinstance(
+                event,
+                QueueNodeRetryEvent,
+            ):
+                if not workflow_run:
+                    raise ValueError("workflow run not initialized.")
+                workflow_node_execution = self._handle_workflow_node_execution_retried(
+                    workflow_run=workflow_run, event=event
+                )
+
+                response = self._workflow_node_retry_to_stream_response(
+                    event=event,
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_node_execution=workflow_node_execution,
+                )
+
+                if response:
+                    yield response
            elif isinstance(event, QueueNodeStartedEvent):
                if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)

@@ -289,64 +307,48 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                )
                if node_failed_response:
                    yield node_failed_response
-            elif isinstance(
-                event,
-                QueueNodeRetryEvent,
-            ):
-                workflow_node_execution = self._handle_workflow_node_execution_retried(
-                    workflow_run=workflow_run, event=event
-                )
-
-                response = self._workflow_node_retry_to_stream_response(
-                    event=event,
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_node_execution=workflow_node_execution,
-                )
-
-                if response:
-                    yield response
-
            elif isinstance(event, QueueParallelBranchRunStartedEvent):
                if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                yield self._workflow_parallel_branch_start_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
                if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                yield self._workflow_parallel_branch_finished_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueIterationStartEvent):
                if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                yield self._workflow_iteration_start_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueIterationNextEvent):
                if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                yield self._workflow_iteration_next_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueIterationCompletedEvent):
                if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                yield self._workflow_iteration_completed_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueWorkflowSucceededEvent):
                if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                if not graph_runtime_state:
-                    raise Exception("Graph runtime state not initialized.")
+                    raise ValueError("graph runtime state not initialized.")

                workflow_run = self._handle_workflow_run_success(
                    workflow_run=workflow_run,

@@ -366,10 +368,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                )
            elif isinstance(event, QueueWorkflowPartialSuccessEvent):
                if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                if not graph_runtime_state:
-                    raise Exception("Graph runtime state not initialized.")
+                    raise ValueError("graph runtime state not initialized.")

                workflow_run = self._handle_workflow_run_partial_success(
                    workflow_run=workflow_run,

@@ -390,10 +392,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                )
            elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
                if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                if not graph_runtime_state:
-                    raise Exception("Graph runtime state not initialized.")
+                    raise ValueError("graph runtime state not initialized.")
                workflow_run = self._handle_workflow_run_failed(
                    workflow_run=workflow_run,
                    start_at=graph_runtime_state.start_at,

@@ -421,7 +423,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa

                # only publish tts message at text chunk streaming
                if tts_publisher:
-                    tts_publisher.publish(message=queue_message)
+                    tts_publisher.publish(queue_message)

                self._task_state.answer += delta_text
                yield self._text_chunk_to_stream_response(
@@ -188,6 +188,41 @@ class WorkflowBasedAppRunner(AppRunner):
            )
        elif isinstance(event, GraphRunFailedEvent):
            self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
+        elif isinstance(event, NodeRunRetryEvent):
+            node_run_result = event.route_node_state.node_run_result
+            if node_run_result:
+                inputs = node_run_result.inputs
+                process_data = node_run_result.process_data
+                outputs = node_run_result.outputs
+                execution_metadata = node_run_result.metadata
+            else:
+                inputs = {}
+                process_data = {}
+                outputs = {}
+                execution_metadata = {}
+            self._publish_event(
+                QueueNodeRetryEvent(
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_data=event.node_data,
+                    parallel_id=event.parallel_id,
+                    parallel_start_node_id=event.parallel_start_node_id,
+                    parent_parallel_id=event.parent_parallel_id,
+                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    start_at=event.start_at,
+                    node_run_index=event.route_node_state.index,
+                    predecessor_node_id=event.predecessor_node_id,
+                    in_iteration_id=event.in_iteration_id,
+                    parallel_mode_run_id=event.parallel_mode_run_id,
+                    inputs=inputs,
+                    process_data=process_data,
+                    outputs=outputs,
+                    error=event.error,
+                    execution_metadata=execution_metadata,
+                    retry_index=event.retry_index,
+                )
+            )
        elif isinstance(event, NodeRunStartedEvent):
            self._publish_event(
                QueueNodeStartedEvent(

@@ -207,6 +242,17 @@ class WorkflowBasedAppRunner(AppRunner):
                )
            )
        elif isinstance(event, NodeRunSucceededEvent):
+            node_run_result = event.route_node_state.node_run_result
+            if node_run_result:
+                inputs = node_run_result.inputs
+                process_data = node_run_result.process_data
+                outputs = node_run_result.outputs
+                execution_metadata = node_run_result.metadata
+            else:
+                inputs = {}
+                process_data = {}
+                outputs = {}
+                execution_metadata = {}
            self._publish_event(
                QueueNodeSucceededEvent(
                    node_execution_id=event.id,

@@ -218,18 +264,10 @@ class WorkflowBasedAppRunner(AppRunner):
                    parent_parallel_id=event.parent_parallel_id,
                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                    start_at=event.route_node_state.start_at,
-                    inputs=event.route_node_state.node_run_result.inputs
-                    if event.route_node_state.node_run_result
-                    else {},
-                    process_data=event.route_node_state.node_run_result.process_data
-                    if event.route_node_state.node_run_result
-                    else {},
-                    outputs=event.route_node_state.node_run_result.outputs
-                    if event.route_node_state.node_run_result
-                    else {},
-                    execution_metadata=event.route_node_state.node_run_result.metadata
-                    if event.route_node_state.node_run_result
-                    else {},
+                    inputs=inputs,
+                    process_data=process_data,
+                    outputs=outputs,
+                    execution_metadata=execution_metadata,
                    in_iteration_id=event.in_iteration_id,
                )
            )

@@ -422,36 +460,6 @@ class WorkflowBasedAppRunner(AppRunner):
                    error=event.error if isinstance(event, IterationRunFailedEvent) else None,
                )
            )
-        elif isinstance(event, NodeRunRetryEvent):
-            self._publish_event(
-                QueueNodeRetryEvent(
-                    node_execution_id=event.id,
-                    node_id=event.node_id,
-                    node_type=event.node_type,
-                    node_data=event.node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    start_at=event.start_at,
-                    inputs=event.route_node_state.node_run_result.inputs
-                    if event.route_node_state.node_run_result
-                    else {},
-                    process_data=event.route_node_state.node_run_result.process_data
-                    if event.route_node_state.node_run_result
-                    else {},
-                    outputs=event.route_node_state.node_run_result.outputs
-                    if event.route_node_state.node_run_result
-                    else {},
-                    error=event.error,
-                    execution_metadata=event.route_node_state.node_run_result.metadata
-                    if event.route_node_state.node_run_result
-                    else {},
-                    in_iteration_id=event.in_iteration_id,
-                    retry_index=event.retry_index,
-                    start_index=event.start_index,
-                )
-            )

    def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
        """
@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from datetime import datetime
 from enum import Enum, StrEnum
 from typing import Any, Optional

@@ -85,9 +86,9 @@ class QueueIterationStartEvent(AppQueueEvent):
    start_at: datetime

    node_run_index: int
-    inputs: Optional[dict[str, Any]] = None
+    inputs: Optional[Mapping[str, Any]] = None
    predecessor_node_id: Optional[str] = None
-    metadata: Optional[dict[str, Any]] = None
+    metadata: Optional[Mapping[str, Any]] = None


 class QueueIterationNextEvent(AppQueueEvent):

@@ -139,9 +140,9 @@ class QueueIterationCompletedEvent(AppQueueEvent):
    start_at: datetime

    node_run_index: int
-    inputs: Optional[dict[str, Any]] = None
-    outputs: Optional[dict[str, Any]] = None
-    metadata: Optional[dict[str, Any]] = None
+    inputs: Optional[Mapping[str, Any]] = None
+    outputs: Optional[Mapping[str, Any]] = None
+    metadata: Optional[Mapping[str, Any]] = None
    steps: int = 0

    error: Optional[str] = None

@@ -304,9 +305,9 @@ class QueueNodeSucceededEvent(AppQueueEvent):
    """iteration id if node is in iteration"""
    start_at: datetime

-    inputs: Optional[dict[str, Any]] = None
-    process_data: Optional[dict[str, Any]] = None
-    outputs: Optional[dict[str, Any]] = None
+    inputs: Optional[Mapping[str, Any]] = None
+    process_data: Optional[Mapping[str, Any]] = None
+    outputs: Optional[Mapping[str, Any]] = None
    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None

    error: Optional[str] = None

@@ -314,35 +315,18 @@ class QueueNodeSucceededEvent(AppQueueEvent):
    iteration_duration_map: Optional[dict[str, float]] = None


-class QueueNodeRetryEvent(AppQueueEvent):
+class QueueNodeRetryEvent(QueueNodeStartedEvent):
    """QueueNodeRetryEvent entity"""

    event: QueueEvent = QueueEvent.RETRY

-    node_execution_id: str
-    node_id: str
-    node_type: NodeType
-    node_data: BaseNodeData
-    parallel_id: Optional[str] = None
-    """parallel id if node is in parallel"""
-    parallel_start_node_id: Optional[str] = None
-    """parallel start node id if node is in parallel"""
-    parent_parallel_id: Optional[str] = None
-    """parent parallel id if node is in parallel"""
-    parent_parallel_start_node_id: Optional[str] = None
-    """parent parallel start node id if node is in parallel"""
-    in_iteration_id: Optional[str] = None
-    """iteration id if node is in iteration"""
-    start_at: datetime
-
-    inputs: Optional[dict[str, Any]] = None
-    process_data: Optional[dict[str, Any]] = None
-    outputs: Optional[dict[str, Any]] = None
-    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
+    inputs: Optional[Mapping[str, Any]] = None
+    process_data: Optional[Mapping[str, Any]] = None
+    outputs: Optional[Mapping[str, Any]] = None
+    execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None

    error: str
    retry_index: int  # retry index
-    start_index: int  # start index


 class QueueNodeInIterationFailedEvent(AppQueueEvent):

@@ -368,10 +352,10 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent):
    """iteration id if node is in iteration"""
    start_at: datetime

-    inputs: Optional[dict[str, Any]] = None
-    process_data: Optional[dict[str, Any]] = None
-    outputs: Optional[dict[str, Any]] = None
-    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
+    inputs: Optional[Mapping[str, Any]] = None
+    process_data: Optional[Mapping[str, Any]] = None
+    outputs: Optional[Mapping[str, Any]] = None
+    execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None

    error: str

@@ -399,10 +383,10 @@ class QueueNodeExceptionEvent(AppQueueEvent):
    """iteration id if node is in iteration"""
    start_at: datetime

-    inputs: Optional[dict[str, Any]] = None
-    process_data: Optional[dict[str, Any]] = None
-    outputs: Optional[dict[str, Any]] = None
-    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
+    inputs: Optional[Mapping[str, Any]] = None
+    process_data: Optional[Mapping[str, Any]] = None
+    outputs: Optional[Mapping[str, Any]] = None
+    execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None

    error: str

@@ -430,10 +414,10 @@ class QueueNodeFailedEvent(AppQueueEvent):
    """iteration id if node is in iteration"""
    start_at: datetime

-    inputs: Optional[dict[str, Any]] = None
-    process_data: Optional[dict[str, Any]] = None
-    outputs: Optional[dict[str, Any]] = None
-    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
+    inputs: Optional[Mapping[str, Any]] = None
+    process_data: Optional[Mapping[str, Any]] = None
+    outputs: Optional[Mapping[str, Any]] = None
+    execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None

    error: str
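Two ideas ride together in this file: `QueueNodeRetryEvent` is rebased on `QueueNodeStartedEvent` so it inherits the shared node fields instead of redeclaring them, and payload fields move from `dict` to `collections.abc.Mapping`, which accepts any read-only mapping and signals that consumers should not mutate event payloads. A standalone illustration of the `Mapping` field (not Dify code):

from collections.abc import Mapping
from typing import Any, Optional

from pydantic import BaseModel

class NodeEvent(BaseModel):
    # Mapping accepts dicts (and other mapping types) while advertising
    # read-only intent to consumers of the event; Pydantic still validates it.
    inputs: Optional[Mapping[str, Any]] = None

event = NodeEvent(inputs={"query": "hello"})
print(event.inputs["query"])  # hello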
@@ -201,11 +201,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
            stream_response=stream_response,
        )

-    def _listen_audio_msg(self, publisher, task_id: str):
+    def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
        if publisher is None:
            return None
-        audio_msg: AudioTrunk = publisher.check_and_get_audio()
-        if audio_msg and audio_msg.status != "finish":
+        audio_msg = publisher.check_and_get_audio()
+        if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
            # audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
            return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
        return None
@@ -438,6 +438,16 @@ class WorkflowCycleManage:
        elapsed_time = (finished_at - created_at).total_seconds()
        inputs = WorkflowEntry.handle_special_values(event.inputs)
        outputs = WorkflowEntry.handle_special_values(event.outputs)
+        origin_metadata = {
+            NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
+            NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
+        }
+        merged_metadata = (
+            {**jsonable_encoder(event.execution_metadata), **origin_metadata}
+            if event.execution_metadata is not None
+            else origin_metadata
+        )
+        execution_metadata = json.dumps(merged_metadata)

        workflow_node_execution = WorkflowNodeExecution()
        workflow_node_execution.tenant_id = workflow_run.tenant_id

@@ -445,6 +455,7 @@ class WorkflowCycleManage:
        workflow_node_execution.workflow_id = workflow_run.workflow_id
        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
        workflow_node_execution.workflow_run_id = workflow_run.id
+        workflow_node_execution.predecessor_node_id = event.predecessor_node_id
        workflow_node_execution.node_execution_id = event.node_execution_id
        workflow_node_execution.node_id = event.node_id
        workflow_node_execution.node_type = event.node_type.value

@@ -458,12 +469,8 @@ class WorkflowCycleManage:
        workflow_node_execution.error = event.error
        workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
        workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
-        workflow_node_execution.execution_metadata = json.dumps(
-            {
-                NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
-            }
-        )
-        workflow_node_execution.index = event.start_index
+        workflow_node_execution.execution_metadata = execution_metadata
+        workflow_node_execution.index = event.node_run_index

        db.session.add(workflow_node_execution)
        db.session.commit()

@@ -505,6 +512,12 @@ class WorkflowCycleManage:
        :param workflow_run: workflow run
        :return:
        """
+        # Attach WorkflowRun to an active session so "created_by_role" can be accessed.
+        workflow_run = db.session.merge(workflow_run)
+
+        # Refresh to ensure any expired attributes are fully loaded
+        db.session.refresh(workflow_run)
+
        created_by = None
        if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value:
            created_by_account = workflow_run.created_by_account
@@ -1,7 +1,7 @@
 from typing import Optional


-class LLMError(Exception):
+class LLMError(ValueError):
     """Base class for all LLM exceptions."""

     description: Optional[str] = None

@@ -16,7 +16,7 @@ class LLMBadRequestError(LLMError):
     description = "Bad Request"


-class ProviderTokenNotInitError(Exception):
+class ProviderTokenNotInitError(ValueError):
     """
     Custom exception raised when the provider token is not initialized.
     """

@@ -27,7 +27,7 @@ class ProviderTokenNotInitError(Exception):
         self.description = args[0] if args else self.description


-class QuotaExceededError(Exception):
+class QuotaExceededError(ValueError):
     """
     Custom exception raised when the quota for a provider has been exceeded.
     """

@@ -35,7 +35,7 @@ class QuotaExceededError(Exception):
     description = "Quota Exceeded"


-class AppInvokeQuotaExceededError(Exception):
+class AppInvokeQuotaExceededError(ValueError):
     """
     Custom exception raised when the quota for an app has been exceeded.
     """

@@ -43,7 +43,7 @@ class AppInvokeQuotaExceededError(Exception):
     description = "App Invoke Quota Exceeded"


-class ModelCurrentlyNotSupportError(Exception):
+class ModelCurrentlyNotSupportError(ValueError):
     """
     Custom exception raised when the model not support
     """

@@ -51,7 +51,7 @@ class ModelCurrentlyNotSupportError(Exception):
     description = "Model Currently Not Support"


-class InvokeRateLimitError(Exception):
+class InvokeRateLimitError(ValueError):
     """Raised when the Invoke returns rate limit error."""

     description = "Rate Limit Error"
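Rebasing these domain exceptions on `ValueError` means generic `except ValueError` handlers, like the ones introduced across the task pipelines above, now also catch LLM errors. A quick check of the effect:

class LLMError(ValueError):
    """Base class for all LLM exceptions."""

try:
    raise LLMError("quota exhausted")
except ValueError as e:  # broader handlers now catch the domain error too
    print(type(e).__name__, e)  # LLMError quota exhausted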
@@ -1,15 +1,14 @@
 import base64

 from configs import dify_config
-from core.file import file_repository
 from core.helper import ssrf_proxy
 from core.model_runtime.entities import (
     AudioPromptMessageContent,
     DocumentPromptMessageContent,
     ImagePromptMessageContent,
+    MultiModalPromptMessageContent,
     VideoPromptMessageContent,
 )
-from extensions.ext_database import db
 from extensions.ext_storage import storage

 from . import helpers

@@ -41,7 +40,7 @@ def to_prompt_message_content(
    /,
    *,
    image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
-):
+) -> MultiModalPromptMessageContent:
    if f.extension is None:
        raise ValueError("Missing file extension")
    if f.mime_type is None:

@@ -70,16 +69,13 @@ def to_prompt_message_content(


 def download(f: File, /):
-    if f.transfer_method == FileTransferMethod.TOOL_FILE:
-        tool_file = file_repository.get_tool_file(session=db.session(), file=f)
-        return _download_file_content(tool_file.file_key)
-    elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
-        upload_file = file_repository.get_upload_file(session=db.session(), file=f)
-        return _download_file_content(upload_file.key)
-    # remote file
-    response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
-    response.raise_for_status()
-    return response.content
+    if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE):
+        return _download_file_content(f._storage_key)
+    elif f.transfer_method == FileTransferMethod.REMOTE_URL:
+        response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
+        response.raise_for_status()
+        return response.content
+    raise ValueError(f"unsupported transfer method: {f.transfer_method}")


 def _download_file_content(path: str, /):

@@ -110,11 +106,9 @@ def _get_encoded_string(f: File, /):
            response.raise_for_status()
            data = response.content
        case FileTransferMethod.LOCAL_FILE:
-            upload_file = file_repository.get_upload_file(session=db.session(), file=f)
-            data = _download_file_content(upload_file.key)
+            data = _download_file_content(f._storage_key)
        case FileTransferMethod.TOOL_FILE:
-            tool_file = file_repository.get_tool_file(session=db.session(), file=f)
-            data = _download_file_content(tool_file.file_key)
+            data = _download_file_content(f._storage_key)

    encoded_string = base64.b64encode(data).decode("utf-8")
    return encoded_string
@@ -1,32 +0,0 @@
-from sqlalchemy import select
-from sqlalchemy.orm import Session
-
-from models import ToolFile, UploadFile
-
-from .models import File
-
-
-def get_upload_file(*, session: Session, file: File):
-    if file.related_id is None:
-        raise ValueError("Missing file related_id")
-    stmt = select(UploadFile).filter(
-        UploadFile.id == file.related_id,
-        UploadFile.tenant_id == file.tenant_id,
-    )
-    record = session.scalar(stmt)
-    if not record:
-        raise ValueError(f"upload file {file.related_id} not found")
-    return record
-
-
-def get_tool_file(*, session: Session, file: File):
-    if file.related_id is None:
-        raise ValueError("Missing file related_id")
-    stmt = select(ToolFile).filter(
-        ToolFile.id == file.related_id,
-        ToolFile.tenant_id == file.tenant_id,
-    )
-    record = session.scalar(stmt)
-    if not record:
-        raise ValueError(f"tool file {file.related_id} not found")
-    return record
@@ -47,6 +47,38 @@ class File(BaseModel):
    mime_type: Optional[str] = None
    size: int = -1

+    # Those properties are private, should not be exposed to the outside.
+    _storage_key: str
+
+    def __init__(
+        self,
+        *,
+        id: Optional[str] = None,
+        tenant_id: str,
+        type: FileType,
+        transfer_method: FileTransferMethod,
+        remote_url: Optional[str] = None,
+        related_id: Optional[str] = None,
+        filename: Optional[str] = None,
+        extension: Optional[str] = None,
+        mime_type: Optional[str] = None,
+        size: int = -1,
+        storage_key: str,
+    ):
+        super().__init__(
+            id=id,
+            tenant_id=tenant_id,
+            type=type,
+            transfer_method=transfer_method,
+            remote_url=remote_url,
+            related_id=related_id,
+            filename=filename,
+            extension=extension,
+            mime_type=mime_type,
+            size=size,
+        )
+        self._storage_key = storage_key
+
    def to_dict(self) -> Mapping[str, str | int | None]:
        data = self.model_dump(mode="json")
        return {
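In Pydantic v2 an underscore-prefixed annotation such as `_storage_key: str` declares a private attribute: it is excluded from validation and from `model_dump()`, which is why the custom `__init__` above assigns it only after `super().__init__()` has validated the public fields. A standalone sketch of the same pattern:

from pydantic import BaseModel

class Record(BaseModel):
    name: str
    _secret: str  # private attribute: skipped by validation and serialization

    def __init__(self, *, name: str, secret: str):
        super().__init__(name=name)
        self._secret = secret

r = Record(name="demo", secret="s3cr3t")
print(r.model_dump())  # {'name': 'demo'}; the private attr is not serialized
print(r._secret)       # s3cr3t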
@@ -118,7 +118,7 @@ class CodeExecutor:
        return response.data.stdout or ""

    @classmethod
-    def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]) -> dict:
+    def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]):
        """
        Execute code
        :param language: code language
@ -25,7 +25,7 @@ class TemplateTransformer(ABC):
        return runner_script, preload_script

    @classmethod
    def extract_result_str_from_response(cls, response: str) -> str:
    def extract_result_str_from_response(cls, response: str):
        result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL)
        if not result:
            raise ValueError("Failed to parse result")
@ -33,13 +33,21 @@ class TemplateTransformer(ABC):
        return result

    @classmethod
    def transform_response(cls, response: str) -> dict:
    def transform_response(cls, response: str) -> Mapping[str, Any]:
        """
        Transform response to dict
        :param response: response
        :return:
        """
        return json.loads(cls.extract_result_str_from_response(response))
        try:
            result = json.loads(cls.extract_result_str_from_response(response))
        except json.JSONDecodeError:
            raise ValueError("failed to parse response")
        if not isinstance(result, dict):
            raise ValueError("result must be a dict")
        if not all(isinstance(k, str) for k in result):
            raise ValueError("result keys must be strings")
        return result

    @classmethod
    @abstractmethod
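The hardened transform_response above now fails with a uniform ValueError on malformed sandbox output instead of leaking json.JSONDecodeError or returning a non-dict. A runnable sketch of the same extract-and-validate flow, with an illustrative <<RESULT>> tag standing in for cls._result_tag:

import json
import re

RESULT_TAG = "<<RESULT>>"  # illustrative tag, stands in for cls._result_tag


def parse_sandbox_output(response: str) -> dict:
    # Pull out whatever sits between the two result tags.
    match = re.search(rf"{RESULT_TAG}(.*){RESULT_TAG}", response, re.DOTALL)
    if not match:
        raise ValueError("Failed to parse result")
    try:
        result = json.loads(match.group(1))
    except json.JSONDecodeError:
        raise ValueError("failed to parse response")
    if not isinstance(result, dict):
        raise ValueError("result must be a dict")
    if not all(isinstance(k, str) for k in result):
        raise ValueError("result keys must be strings")
    return result


print(parse_sandbox_output('noise <<RESULT>>{"answer": 42}<<RESULT>> noise'))
# {'answer': 42}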
@ -1,6 +1,5 @@
import base64

from extensions.ext_database import db
from libs import rsa


@ -14,6 +13,7 @@ def obfuscated_token(token: str):

def encrypt_token(tenant_id: str, token: str):
    from models.account import Tenant
    from models.engine import db

    if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
        raise ValueError(f"Tenant with id {tenant_id} not found")
@ -24,7 +24,7 @@ BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]


class MaxRetriesExceededError(Exception):
class MaxRetriesExceededError(ValueError):
    """Raised when the maximum number of retries is exceeded."""

    pass
@ -45,6 +45,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    )

    retries = 0
    stream = kwargs.pop("stream", False)
    while retries <= max_retries:
        try:
            if dify_config.SSRF_PROXY_ALL_URL:
@ -64,11 +65,12 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):

        except httpx.RequestError as e:
            logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}")
            if max_retries == 0:
                raise

            retries += 1
            if retries <= max_retries:
                time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))

    raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")
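For reference, with BACKOFF_FACTOR = 0.5 the sleep before retry n is 0.5 * 2 ** (n - 1), so the schedule doubles from half a second:

BACKOFF_FACTOR = 0.5

# Delay slept before retry attempt n (1-indexed), per the loop above.
delays = [BACKOFF_FACTOR * (2 ** (n - 1)) for n in range(1, 5)]
print(delays)  # [0.5, 1.0, 2.0, 4.0] -- exponential backoff, doubling each attempt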
@ -1,2 +1,2 @@
class OutputParserError(Exception):
class OutputParserError(ValueError):
    pass
@ -4,6 +4,7 @@ from .message_entities import (
    AudioPromptMessageContent,
    DocumentPromptMessageContent,
    ImagePromptMessageContent,
    MultiModalPromptMessageContent,
    PromptMessage,
    PromptMessageContent,
    PromptMessageContentType,
@ -27,6 +28,7 @@ __all__ = [
    "LLMResultChunkDelta",
    "LLMUsage",
    "ModelPropertyKey",
    "MultiModalPromptMessageContent",
    "PromptMessage",
    "PromptMessage",
    "PromptMessageContent",
@ -84,10 +84,10 @@ class MultiModalPromptMessageContent(PromptMessageContent):
    """

    type: PromptMessageContentType
    format: str = Field(..., description="the format of multi-modal file")
    base64_data: str = Field("", description="the base64 data of multi-modal file")
    url: str = Field("", description="the url of multi-modal file")
    mime_type: str = Field(..., description="the mime type of multi-modal file")
    format: str = Field(default=..., description="the format of multi-modal file")
    base64_data: str = Field(default="", description="the base64 data of multi-modal file")
    url: str = Field(default="", description="the url of multi-modal file")
    mime_type: str = Field(default=..., description="the mime type of multi-modal file")

    @computed_field(return_type=str)
    @property
@ -1,7 +1,7 @@
from typing import Optional


class InvokeError(Exception):
class InvokeError(ValueError):
    """Base class for all LLM exceptions."""

    description: Optional[str] = None
@ -1,4 +1,4 @@
class CredentialsValidateFailedError(Exception):
class CredentialsValidateFailedError(ValueError):
    """
    Credentials validate failed error
    """
@ -531,7 +531,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                        "source": {
                            "type": "base64",
                            "media_type": message_content.mime_type,
                            "data": message_content.data,
                            "data": message_content.base64_data,
                        },
                    }
                    sub_messages.append(sub_message_dict)
@ -21,6 +21,7 @@ from core.model_runtime.entities.message_entities import (
    PromptMessageContentType,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
@ -143,7 +144,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
        """

        try:
            ping_message = SystemPromptMessage(content="ping")
            ping_message = UserPromptMessage(content="ping")
            self._generate(model, credentials, [ping_message], {"max_output_tokens": 5})

        except Exception as ex:
@ -187,17 +188,23 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
            config_kwargs["stop_sequences"] = stop

        genai.configure(api_key=credentials["google_api_key"])
        google_model = genai.GenerativeModel(model_name=model)

        history = []
        system_instruction = None

        for msg in prompt_messages:  # makes message roles strictly alternating
            content = self._format_message_to_glm_content(msg)
            if history and history[-1]["role"] == content["role"]:
                history[-1]["parts"].extend(content["parts"])
            elif content["role"] == "system":
                system_instruction = content["parts"][0]
            else:
                history.append(content)

        if not history:
            raise InvokeError("The user prompt message is required. You have only added a system prompt message.")

        google_model = genai.GenerativeModel(model_name=model, system_instruction=system_instruction)
        response = google_model.generate_content(
            contents=history,
            generation_config=genai.types.GenerationConfig(**config_kwargs),
@ -404,7 +411,10 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
            )
            return glm_content
        elif isinstance(message, SystemPromptMessage):
            return {"role": "user", "parts": [to_part(message.content)]}
            if isinstance(message.content, list):
                text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), message.content)
                message.content = "".join(c.data for c in text_contents)
            return {"role": "system", "parts": [to_part(message.content)]}
        elif isinstance(message, ToolPromptMessage):
            return {
                "role": "function",
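The loop above satisfies Gemini's requirement that chat history alternate strictly between roles, while routing system text into system_instruction. The merge step in isolation, runnable with plain dicts:

# Messages as {"role": ..., "parts": [...]} dicts, as in the loop above.
messages = [
    {"role": "system", "parts": ["You are terse."]},
    {"role": "user", "parts": ["Hello"]},
    {"role": "user", "parts": ["Are you there?"]},  # same role twice in a row
    {"role": "model", "parts": ["Yes."]},
]

history: list[dict] = []
system_instruction = None
for content in messages:
    if history and history[-1]["role"] == content["role"]:
        history[-1]["parts"].extend(content["parts"])  # merge consecutive same-role turns
    elif content["role"] == "system":
        system_instruction = content["parts"][0]  # system text travels separately
    else:
        history.append(content)

print(system_instruction)  # You are terse.
print(history)
# [{'role': 'user', 'parts': ['Hello', 'Are you there?']}, {'role': 'model', 'parts': ['Yes.']}]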
@ -421,7 +421,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):

        # text completion model
        response = client.completions.create(
            prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs
            prompt=prompt_messages[0].content,
            model=model,
            stream=stream,
            **model_parameters,
            **extra_model_kwargs,
        )

        if stream:
@ -593,6 +597,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
            else:
                model_parameters["response_format"] = {"type": response_format}
        elif "json_schema" in model_parameters:
            del model_parameters["json_schema"]

        extra_model_kwargs = {}
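For reference, the two response_format shapes this branch produces look like the following; the schema body is a made-up illustration, not from the diff:

import json

# Structured output: "json_schema" mode carries a full schema definition.
schema = {
    "name": "answer",  # example schema, purely illustrative
    "schema": {"type": "object", "properties": {"text": {"type": "string"}}},
}
structured = {"type": "json_schema", "json_schema": schema}

# Plain modes: just the type tag, e.g. "json_object" or "text".
plain = {"type": "json_object"}

print(json.dumps(structured, indent=2))
print(json.dumps(plain))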
@ -4,11 +4,10 @@ import json
import logging
import time
from collections.abc import Generator
from typing import Optional, Union, cast
from typing import TYPE_CHECKING, Optional, Union, cast

import google.auth.transport.requests
import requests
import vertexai.generative_models as glm
from anthropic import AnthropicVertex, Stream
from anthropic.types import (
    ContentBlockDeltaEvent,
@ -19,8 +18,6 @@ from anthropic.types import (
    MessageStreamEvent,
)
from google.api_core import exceptions
from google.cloud import aiplatform
from google.oauth2 import service_account
from PIL import Image

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@ -47,6 +44,9 @@ from core.model_runtime.errors.invoke import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

if TYPE_CHECKING:
    import vertexai.generative_models as glm

logger = logging.getLogger(__name__)
@ -102,6 +102,8 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
        :param stream: is stream response
        :return: full response or stream response chunk generator result
        """
        from google.oauth2 import service_account

        # use Anthropic official SDK references
        # - https://github.com/anthropics/anthropic-sdk-python
        service_account_key = credentials.get("vertex_service_account_key", "")
@ -406,13 +408,15 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):

        return text.rstrip()

    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> "glm.Tool":
        """
        Convert tool messages to glm tools

        :param tools: tool messages
        :return: glm tools
        """
        import vertexai.generative_models as glm

        return glm.Tool(
            function_declarations=[
                glm.FunctionDeclaration(
@ -473,6 +477,10 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        import vertexai.generative_models as glm
        from google.cloud import aiplatform
        from google.oauth2 import service_account

        config_kwargs = model_parameters.copy()
        config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)

@ -522,7 +530,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
        return self._handle_generate_response(model, credentials, response, prompt_messages)

    def _handle_generate_response(
        self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
        self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage]
    ) -> LLMResult:
        """
        Handle llm response
@ -554,7 +562,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
        return result

    def _handle_generate_stream_response(
        self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
        self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage]
    ) -> Generator:
        """
        Handle llm stream response
@ -638,13 +646,15 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):

        return message_text

    def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content:
    def _format_message_to_glm_content(self, message: PromptMessage) -> "glm.Content":
        """
        Format a single message into glm.Content for Google API

        :param message: one PromptMessage
        :return: glm Content representation of message
        """
        import vertexai.generative_models as glm

        if isinstance(message, UserPromptMessage):
            glm_content = glm.Content(role="user", parts=[])
@ -2,12 +2,9 @@ import base64
import json
import time
from decimal import Decimal
from typing import Optional
from typing import TYPE_CHECKING, Optional

import tiktoken
from google.cloud import aiplatform
from google.oauth2 import service_account
from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel

from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
@ -24,6 +21,11 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.vertex_ai._common import _CommonVertexAi

if TYPE_CHECKING:
    from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
else:
    VertexTextEmbeddingModel = None


class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
    """
@ -48,6 +50,10 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
        :param input_type: input type
        :return: embeddings result
        """
        from google.cloud import aiplatform
        from google.oauth2 import service_account
        from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel

        service_account_key = credentials.get("vertex_service_account_key", "")
        project_id = credentials["vertex_project_id"]
        location = credentials["vertex_location"]
@ -100,6 +106,10 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
        :param credentials: model credentials
        :return:
        """
        from google.cloud import aiplatform
        from google.oauth2 import service_account
        from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel

        try:
            service_account_key = credentials.get("vertex_service_account_key", "")
            project_id = credentials["vertex_project_id"]
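The recurring pattern across these Vertex AI changes is import deferral: the heavy SDK module is imported only for type checkers at module load, and again locally inside the methods that need it. A minimal sketch of the idiom with a hypothetical heavy_sdk package:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers; never imported at runtime.
    import heavy_sdk  # hypothetical slow-to-import dependency


def run_job(payload: dict) -> "heavy_sdk.Result":  # string annotation: no runtime import
    import heavy_sdk  # lazy import: the cost is paid only when this is actually called

    return heavy_sdk.process(payload)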
@ -355,7 +355,13 @@ class TraceTask:
    def conversation_trace(self, **kwargs):
        return kwargs

    def workflow_trace(self, workflow_run: WorkflowRun, conversation_id, user_id):
    def workflow_trace(self, workflow_run: WorkflowRun | None, conversation_id, user_id):
        if not workflow_run:
            raise ValueError("Workflow run not found")

        db.session.merge(workflow_run)
        db.session.refresh(workflow_run)

        workflow_id = workflow_run.workflow_id
        tenant_id = workflow_run.tenant_id
        workflow_run_id = workflow_run.id
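Note that SQLAlchemy's merge() returns the session-bound copy of a detached instance; refreshing that returned object is the reliable form. A runnable sketch with an illustrative Run model standing in for WorkflowRun:

from sqlalchemy import Integer, String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Run(Base):  # stand-in for WorkflowRun, illustrative only
    __tablename__ = "runs"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    status: Mapped[str] = mapped_column(String(32))


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as s1:
    s1.add(Run(id=1, status="running"))
    s1.commit()
    detached = s1.get(Run, 1)
# `detached` now lives outside any session, like a WorkflowRun handed to a trace task.

with Session(engine) as s2:
    attached = s2.merge(detached)  # merge returns the session-owned copy
    s2.refresh(attached)           # re-read current DB state onto that copy
    print(attached.status)         # running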
@ -83,11 +83,15 @@ class DataPostProcessor:
        if reranking_model:
            try:
                model_manager = ModelManager()
                reranking_provider_name = reranking_model.get("reranking_provider_name")
                reranking_model_name = reranking_model.get("reranking_model_name")
                if not reranking_provider_name or not reranking_model_name:
                    return None
                rerank_model_instance = model_manager.get_model_instance(
                    tenant_id=tenant_id,
                    provider=reranking_model["reranking_provider_name"],
                    provider=reranking_provider_name,
                    model_type=ModelType.RERANK,
                    model=reranking_model["reranking_model_name"],
                    model=reranking_model_name,
                )
                return rerank_model_instance
            except InvokeAuthorizationError:
@ -1,18 +1,19 @@
import re
from typing import Optional

import jieba
from jieba.analyse import default_tfidf

from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS


class JiebaKeywordTableHandler:
    def __init__(self):
        default_tfidf.stop_words = STOPWORDS
        import jieba.analyse

        from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS

        jieba.analyse.default_tfidf.stop_words = STOPWORDS

    def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
        """Extract keywords with JIEBA tfidf."""
        import jieba

        keywords = jieba.analyse.extract_tags(
            sentence=text,
            topK=max_keywords_per_chunk,
@ -22,6 +23,8 @@ class JiebaKeywordTableHandler:

    def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
        """Get subtokens from a list of tokens, filtering for stopwords."""
        from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS

        results = set()
        for token in tokens:
            results.add(token)
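For context, jieba's TF-IDF keyword extraction is a single call once the stopword table is set; a minimal standalone sketch (requires the jieba package, output will vary):

import jieba.analyse

text = "自然语言处理是人工智能的一个重要方向"
# topK caps the number of keywords, mirroring max_keywords_per_chunk above.
keywords = jieba.analyse.extract_tags(sentence=text, topK=5)
print(keywords)  # e.g. ['自然语言', '人工智能', '处理', '方向', '重要']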
@ -103,7 +103,7 @@ class RetrievalService:

        if exceptions:
            exception_message = ";\n".join(exceptions)
            raise Exception(exception_message)
            raise ValueError(exception_message)

        if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
            data_post_processor = DataPostProcessor(
@ -6,10 +6,8 @@ from contextlib import contextmanager
from typing import Any

import jieba.posseg as pseg
import nltk
import numpy
import oracledb
from nltk.corpus import stopwords
from pydantic import BaseModel, model_validator

from configs import dify_config
@ -202,6 +200,10 @@ class OracleVector(BaseVector):
        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        # lazy import
        import nltk
        from nltk.corpus import stopwords

        top_k = kwargs.get("top_k", 5)
        # fetching by score_threshold is not implemented yet; may be added later
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
@ -62,7 +62,7 @@ class AppToolProviderEntity(ToolProviderController):
            user_input_form_list = app_model_config.user_input_form_list
            for input_form in user_input_form_list:
                # get type
                form_type = input_form.keys()[0]
                form_type = list(input_form.keys())[0]
                default = input_form[form_type]["default"]
                required = input_form[form_type]["required"]
                label = input_form[form_type]["label"]
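The fix above addresses a Python 3 pitfall: dict.keys() returns a view object, which is not subscriptable. In brief:

d = {"text-input": {"label": "Name"}}
# d.keys()[0] raises TypeError: 'dict_keys' object is not subscriptable
print(list(d.keys())[0])  # "text-input" -- materialize the view first
print(next(iter(d)))      # equivalent without building the full list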
api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py (new file, 115 lines)
@ -0,0 +1,115 @@
import json
import operator
from typing import Any, Optional, Union

import boto3

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class BedrockRetrieveTool(BuiltinTool):
    bedrock_client: Any = None
    knowledge_base_id: Optional[str] = None
    topk: Optional[int] = None

    def _bedrock_retrieve(
        self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None
    ):
        try:
            retrieval_query = {"text": query_input}

            retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}

            # If a metadata filter was provided, add it to the retrieval configuration
            if metadata_filter:
                retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter

            response = self.bedrock_client.retrieve(
                knowledgeBaseId=knowledge_base_id,
                retrievalQuery=retrieval_query,
                retrievalConfiguration=retrieval_configuration,
            )

            results = []
            for result in response.get("retrievalResults", []):
                results.append(
                    {
                        "content": result.get("content", {}).get("text", ""),
                        "score": result.get("score", 0.0),
                        "metadata": result.get("metadata", {}),
                    }
                )

            return results
        except Exception as e:
            raise Exception(f"Error retrieving from knowledge base: {str(e)}")

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke tools
        """
        line = 0
        try:
            if not self.bedrock_client:
                aws_region = tool_parameters.get("aws_region")
                if aws_region:
                    self.bedrock_client = boto3.client("bedrock-agent-runtime", region_name=aws_region)
                else:
                    self.bedrock_client = boto3.client("bedrock-agent-runtime")

            line = 1
            if not self.knowledge_base_id:
                self.knowledge_base_id = tool_parameters.get("knowledge_base_id")
                if not self.knowledge_base_id:
                    return self.create_text_message("Please provide knowledge_base_id")

            line = 2
            if not self.topk:
                self.topk = tool_parameters.get("topk", 5)

            line = 3
            query = tool_parameters.get("query", "")
            if not query:
                return self.create_text_message("Please input query")

            # Read the metadata filter, if one was provided
            metadata_filter_str = tool_parameters.get("metadata_filter")
            metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None

            line = 4
            retrieved_docs = self._bedrock_retrieve(
                query_input=query,
                knowledge_base_id=self.knowledge_base_id,
                num_results=self.topk,
                metadata_filter=metadata_filter,  # pass the metadata filter through to the retrieve call
            )

            line = 5
            # Sort results by score in descending order
            sorted_docs = sorted(retrieved_docs, key=operator.itemgetter("score"), reverse=True)

            line = 6
            return [self.create_json_message(res) for res in sorted_docs]

        except Exception as e:
            return self.create_text_message(f"Exception {str(e)}, line : {line}")

    def validate_parameters(self, parameters: dict[str, Any]) -> None:
        """
        Validate the parameters
        """
        if not parameters.get("knowledge_base_id"):
            raise ValueError("knowledge_base_id is required")

        if not parameters.get("query"):
            raise ValueError("query is required")

        # Optional: verify that the metadata filter, if provided, is a valid JSON object
        metadata_filter_str = parameters.get("metadata_filter")
        if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
            raise ValueError("metadata_filter must be a valid JSON object")
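For a sense of the request this tool builds, here is a hedged standalone sketch of the same bedrock-agent-runtime retrieve call with a metadata filter; the knowledge base ID and filter values are made up, and AWS credentials are assumed:

import boto3

client = boto3.client("bedrock-agent-runtime", region_name="us-east-1")

response = client.retrieve(
    knowledgeBaseId="KB123EXAMPLE",  # hypothetical knowledge base ID
    retrievalQuery={"text": "What is our refund policy?"},
    retrievalConfiguration={
        "vectorSearchConfiguration": {
            "numberOfResults": 5,
            # Same filter shape the tool accepts via its metadata_filter parameter.
            "filter": {"greaterThan": {"key": "year", "value": 2020}},
        }
    },
)
for item in response.get("retrievalResults", []):
    print(item.get("score"), item.get("content", {}).get("text", "")[:80])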
@ -0,0 +1,87 @@
identity:
  name: bedrock_retrieve
  author: AWS
  label:
    en_US: Bedrock Retrieve
    zh_Hans: Bedrock检索
    pt_BR: Bedrock Retrieve
  icon: icon.svg

description:
  human:
    en_US: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool
    zh_Hans: Amazon Bedrock知识库检索工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署说明
    pt_BR: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base.
  llm: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool

parameters:
  - name: knowledge_base_id
    type: string
    required: true
    label:
      en_US: Bedrock Knowledge Base ID
      zh_Hans: Bedrock知识库ID
      pt_BR: Bedrock Knowledge Base ID
    human_description:
      en_US: ID of the Bedrock Knowledge Base to retrieve from
      zh_Hans: 用于检索的Bedrock知识库ID
      pt_BR: ID of the Bedrock Knowledge Base to retrieve from
    llm_description: ID of the Bedrock Knowledge Base to retrieve from
    form: form

  - name: query
    type: string
    required: true
    label:
      en_US: Query string
      zh_Hans: 查询语句
      pt_BR: Query string
    human_description:
      en_US: The search query to retrieve relevant information
      zh_Hans: 用于检索相关信息的查询语句
      pt_BR: The search query to retrieve relevant information
    llm_description: The search query to retrieve relevant information
    form: llm

  - name: topk
    type: number
    required: false
    form: form
    label:
      en_US: Limit for results count
      zh_Hans: 返回结果数量限制
      pt_BR: Limit for results count
    human_description:
      en_US: Maximum number of results to return
      zh_Hans: 最大返回结果数量
      pt_BR: Maximum number of results to return
    min: 1
    max: 10
    default: 5

  - name: aws_region
    type: string
    required: false
    label:
      en_US: AWS Region
      zh_Hans: AWS 区域
      pt_BR: AWS Region
    human_description:
      en_US: AWS region where the Bedrock Knowledge Base is located
      zh_Hans: Bedrock知识库所在的AWS区域
      pt_BR: AWS region where the Bedrock Knowledge Base is located
    llm_description: AWS region where the Bedrock Knowledge Base is located
    form: form

  - name: metadata_filter
    type: string
    required: false
    label:
      en_US: Metadata Filter
      zh_Hans: 元数据过滤器
      pt_BR: Metadata Filter
    human_description:
      en_US: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key": "aaa", "value": 10}})'
      zh_Hans: '元数据的JSON格式过滤条件(例如,{"greaterThan": {"key": "aaa", "value": 10}})'
      pt_BR: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key": "aaa", "value": 10}})'
    form: form
api/core/tools/provider/builtin/aws/tools/nova_canvas.py (new file, 357 lines)
@ -0,0 +1,357 @@
import base64
import json
import logging
import re
from datetime import datetime
from typing import Any, Union
from urllib.parse import urlparse

import boto3

from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool.builtin_tool import BuiltinTool

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class NovaCanvasTool(BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Invoke AWS Bedrock Nova Canvas model for image generation
        """
        # Get common parameters
        prompt = tool_parameters.get("prompt", "")
        image_output_s3uri = tool_parameters.get("image_output_s3uri", "").strip()
        if not prompt:
            return self.create_text_message("Please provide a text prompt for image generation.")
        if not image_output_s3uri or urlparse(image_output_s3uri).scheme != "s3":
            return self.create_text_message("Please provide a valid S3 URI for image output.")

        task_type = tool_parameters.get("task_type", "TEXT_IMAGE")
        aws_region = tool_parameters.get("aws_region", "us-east-1")

        # Get common image generation config parameters
        width = tool_parameters.get("width", 1024)
        height = tool_parameters.get("height", 1024)
        cfg_scale = tool_parameters.get("cfg_scale", 8.0)
        negative_prompt = tool_parameters.get("negative_prompt", "")
        seed = tool_parameters.get("seed", 0)
        quality = tool_parameters.get("quality", "standard")

        # Handle S3 image if provided
        image_input_s3uri = tool_parameters.get("image_input_s3uri", "")
        if task_type != "TEXT_IMAGE":
            if not image_input_s3uri or urlparse(image_input_s3uri).scheme != "s3":
                return self.create_text_message("Please provide a valid S3 URI for image to image generation.")

            # Parse S3 URI
            parsed_uri = urlparse(image_input_s3uri)
            bucket = parsed_uri.netloc
            key = parsed_uri.path.lstrip("/")

            # Initialize S3 client and download image
            s3_client = boto3.client("s3")
            response = s3_client.get_object(Bucket=bucket, Key=key)
            image_data = response["Body"].read()

            # Base64 encode the image
            input_image = base64.b64encode(image_data).decode("utf-8")

        try:
            # Initialize Bedrock client
            bedrock = boto3.client(service_name="bedrock-runtime", region_name=aws_region)

            # Base image generation config
            image_generation_config = {
                "width": width,
                "height": height,
                "cfgScale": cfg_scale,
                "seed": seed,
                "numberOfImages": 1,
                "quality": quality,
            }

            # Prepare request body based on task type
            body = {"imageGenerationConfig": image_generation_config}

            if task_type == "TEXT_IMAGE":
                body["taskType"] = "TEXT_IMAGE"
                body["textToImageParams"] = {"text": prompt}
                if negative_prompt:
                    body["textToImageParams"]["negativeText"] = negative_prompt

            elif task_type == "COLOR_GUIDED_GENERATION":
                colors = tool_parameters.get("colors", "#ff8080-#ffb280-#ffe680-#ffe680")
                if not self._validate_color_string(colors):
                    return self.create_text_message("Please provide valid colors in hexadecimal format.")

                body["taskType"] = "COLOR_GUIDED_GENERATION"
                body["colorGuidedGenerationParams"] = {
                    "colors": colors.split("-"),
                    "referenceImage": input_image,
                    "text": prompt,
                }
                if negative_prompt:
                    body["colorGuidedGenerationParams"]["negativeText"] = negative_prompt

            elif task_type == "IMAGE_VARIATION":
                similarity_strength = tool_parameters.get("similarity_strength", 0.5)

                body["taskType"] = "IMAGE_VARIATION"
                body["imageVariationParams"] = {
                    "images": [input_image],
                    "similarityStrength": similarity_strength,
                    "text": prompt,
                }
                if negative_prompt:
                    body["imageVariationParams"]["negativeText"] = negative_prompt

            elif task_type == "INPAINTING":
                mask_prompt = tool_parameters.get("mask_prompt")
                if not mask_prompt:
                    return self.create_text_message("Please provide a mask prompt for image inpainting.")

                body["taskType"] = "INPAINTING"
                body["inPaintingParams"] = {"image": input_image, "maskPrompt": mask_prompt, "text": prompt}
                if negative_prompt:
                    body["inPaintingParams"]["negativeText"] = negative_prompt

            elif task_type == "OUTPAINTING":
                mask_prompt = tool_parameters.get("mask_prompt")
                if not mask_prompt:
                    return self.create_text_message("Please provide a mask prompt for image outpainting.")
                outpainting_mode = tool_parameters.get("outpainting_mode", "DEFAULT")

                body["taskType"] = "OUTPAINTING"
                body["outPaintingParams"] = {
                    "image": input_image,
                    "maskPrompt": mask_prompt,
                    "outPaintingMode": outpainting_mode,
                    "text": prompt,
                }
                if negative_prompt:
                    body["outPaintingParams"]["negativeText"] = negative_prompt

            elif task_type == "BACKGROUND_REMOVAL":
                body["taskType"] = "BACKGROUND_REMOVAL"
                body["backgroundRemovalParams"] = {"image": input_image}

            else:
                return self.create_text_message(f"Unsupported task type: {task_type}")

            # Call Nova Canvas model
            response = bedrock.invoke_model(
                body=json.dumps(body),
                modelId="amazon.nova-canvas-v1:0",
                accept="application/json",
                contentType="application/json",
            )

            # Process response
            response_body = json.loads(response.get("body").read())
            if response_body.get("error"):
                raise Exception(f"Error in model response: {response_body.get('error')}")
            base64_image = response_body.get("images")[0]

            # Upload to S3 if image_output_s3uri is provided
            try:
                # Parse S3 URI for output
                parsed_uri = urlparse(image_output_s3uri)
                output_bucket = parsed_uri.netloc
                output_base_path = parsed_uri.path.lstrip("/")
                # Generate filename with timestamp
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                output_key = f"{output_base_path}/canvas-output-{timestamp}.png"

                # Initialize S3 client if not already done
                s3_client = boto3.client("s3", region_name=aws_region)

                # Decode base64 image and upload to S3
                image_data = base64.b64decode(base64_image)
                s3_client.put_object(Bucket=output_bucket, Key=output_key, Body=image_data, ContentType="image/png")
                logger.info(f"Image uploaded to s3://{output_bucket}/{output_key}")
            except Exception as e:
                logger.exception("Failed to upload image to S3")
            # Return image
            return [
                self.create_text_message(f"Image is available at: s3://{output_bucket}/{output_key}"),
                self.create_blob_message(
                    blob=base64.b64decode(base64_image),
                    meta={"mime_type": "image/png"},
                    save_as=self.VariableKey.IMAGE.value,
                ),
            ]

        except Exception as e:
            return self.create_text_message(f"Failed to generate image: {str(e)}")

    def _validate_color_string(self, color_string) -> bool:
        color_pattern = r"^#[0-9a-fA-F]{6}(?:-#[0-9a-fA-F]{6})*$"

        if re.match(color_pattern, color_string):
            return True
        return False

    def get_runtime_parameters(self) -> list[ToolParameter]:
        parameters = [
            ToolParameter(
                name="prompt",
                label=I18nObject(en_US="Prompt", zh_Hans="提示词"),
                type=ToolParameter.ToolParameterType.STRING,
                required=True,
                form=ToolParameter.ToolParameterForm.LLM,
                human_description=I18nObject(
                    en_US="Text description of the image you want to generate or modify",
                    zh_Hans="您想要生成或修改的图像的文本描述",
                ),
                llm_description="Describe the image you want to generate or how you want to modify the input image",
            ),
            ToolParameter(
                name="image_input_s3uri",
                label=I18nObject(en_US="Input image s3 uri", zh_Hans="输入图片的s3 uri"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                form=ToolParameter.ToolParameterForm.LLM,
                human_description=I18nObject(en_US="Image to be modified", zh_Hans="想要修改的图片"),
            ),
            ToolParameter(
                name="image_output_s3uri",
                label=I18nObject(en_US="Output Image S3 URI", zh_Hans="输出图片的S3 URI目录"),
                type=ToolParameter.ToolParameterType.STRING,
                required=True,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(
                    en_US="S3 URI where the generated image should be uploaded", zh_Hans="生成的图像应该上传到的S3 URI"
                ),
            ),
            ToolParameter(
                name="width",
                label=I18nObject(en_US="Width", zh_Hans="宽度"),
                type=ToolParameter.ToolParameterType.NUMBER,
                required=False,
                default=1024,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(en_US="Width of the generated image", zh_Hans="生成图像的宽度"),
            ),
            ToolParameter(
                name="height",
                label=I18nObject(en_US="Height", zh_Hans="高度"),
                type=ToolParameter.ToolParameterType.NUMBER,
                required=False,
                default=1024,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(en_US="Height of the generated image", zh_Hans="生成图像的高度"),
            ),
            ToolParameter(
                name="cfg_scale",
                label=I18nObject(en_US="CFG Scale", zh_Hans="CFG比例"),
                type=ToolParameter.ToolParameterType.NUMBER,
                required=False,
                default=8.0,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(
                    en_US="How strongly the image should conform to the prompt", zh_Hans="图像应该多大程度上符合提示词"
                ),
            ),
            ToolParameter(
                name="negative_prompt",
                label=I18nObject(en_US="Negative Prompt", zh_Hans="负面提示词"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                default="",
                form=ToolParameter.ToolParameterForm.LLM,
                human_description=I18nObject(
                    en_US="Things you don't want in the generated image", zh_Hans="您不想在生成的图像中出现的内容"
                ),
            ),
            ToolParameter(
                name="seed",
                label=I18nObject(en_US="Seed", zh_Hans="种子值"),
                type=ToolParameter.ToolParameterType.NUMBER,
                required=False,
                default=0,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(en_US="Random seed for image generation", zh_Hans="图像生成的随机种子"),
            ),
            ToolParameter(
                name="aws_region",
                label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                default="us-east-1",
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"),
            ),
            ToolParameter(
                name="task_type",
                label=I18nObject(en_US="Task Type", zh_Hans="任务类型"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                default="TEXT_IMAGE",
                form=ToolParameter.ToolParameterForm.LLM,
                human_description=I18nObject(en_US="Type of image generation task", zh_Hans="图像生成任务的类型"),
            ),
            ToolParameter(
                name="quality",
                label=I18nObject(en_US="Quality", zh_Hans="质量"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                default="standard",
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(
                    en_US="Quality of the generated image (standard or premium)", zh_Hans="生成图像的质量(标准或高级)"
                ),
            ),
            ToolParameter(
                name="colors",
                label=I18nObject(en_US="Colors", zh_Hans="颜色"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(
                    en_US="List of colors for color-guided generation, example: #ff8080-#ffb280-#ffe680-#ffe680",
                    zh_Hans="颜色引导生成的颜色列表, 例子: #ff8080-#ffb280-#ffe680-#ffe680",
                ),
            ),
            ToolParameter(
                name="similarity_strength",
                label=I18nObject(en_US="Similarity Strength", zh_Hans="相似度强度"),
                type=ToolParameter.ToolParameterType.NUMBER,
                required=False,
                default=0.5,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(
                    en_US="How similar the generated image should be to the input image (0.0 to 1.0)",
                    zh_Hans="生成的图像应该与输入图像的相似程度(0.0到1.0)",
                ),
            ),
            ToolParameter(
                name="mask_prompt",
                label=I18nObject(en_US="Mask Prompt", zh_Hans="蒙版提示词"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                form=ToolParameter.ToolParameterForm.LLM,
                human_description=I18nObject(
                    en_US="Text description to generate mask for inpainting/outpainting",
                    zh_Hans="用于生成内补绘制/外补绘制蒙版的文本描述",
                ),
            ),
            ToolParameter(
                name="outpainting_mode",
                label=I18nObject(en_US="Outpainting Mode", zh_Hans="外补绘制模式"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                default="DEFAULT",
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(
                    en_US="Mode for outpainting (DEFAULT or other supported modes)",
                    zh_Hans="外补绘制的模式(DEFAULT或其他支持的模式)",
                ),
            ),
        ]

        return parameters
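Pulled out of the branching above, the request body for the simplest TEXT_IMAGE case looks like this; the prompt text is illustrative:

import json

body = {
    "taskType": "TEXT_IMAGE",
    "textToImageParams": {
        "text": "a lighthouse at dawn, watercolor",  # example prompt
        "negativeText": "blurry, low quality",       # only included when set
    },
    "imageGenerationConfig": {
        "width": 1024,
        "height": 1024,
        "cfgScale": 8.0,
        "seed": 0,
        "numberOfImages": 1,
        "quality": "standard",
    },
}
# The tool sends this as: bedrock.invoke_model(body=json.dumps(body),
#     modelId="amazon.nova-canvas-v1:0", accept="application/json",
#     contentType="application/json")
print(json.dumps(body, indent=2))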
api/core/tools/provider/builtin/aws/tools/nova_canvas.yaml (new file, 175 lines)
@ -0,0 +1,175 @@
identity:
  name: nova_canvas
  author: AWS
  label:
    en_US: AWS Bedrock Nova Canvas
    zh_Hans: AWS Bedrock Nova Canvas
  icon: icon.svg
description:
  human:
    en_US: A tool for generating and modifying images using AWS Bedrock's Nova Canvas model. Supports text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html
    zh_Hans: 使用 AWS Bedrock 的 Nova Canvas 模型生成和修改图像的工具。支持文生图、颜色引导生成、图像变体、内补绘制、外补绘制和背景移除功能, 输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html。
  llm: Generate or modify images using AWS Bedrock's Nova Canvas model with multiple task types including text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal.
parameters:
  - name: task_type
    type: string
    required: false
    default: TEXT_IMAGE
    label:
      en_US: Task Type
      zh_Hans: 任务类型
    human_description:
      en_US: Type of image generation task (TEXT_IMAGE, COLOR_GUIDED_GENERATION, IMAGE_VARIATION, INPAINTING, OUTPAINTING, BACKGROUND_REMOVAL)
      zh_Hans: 图像生成任务的类型(文生图、颜色引导生成、图像变体、内补绘制、外补绘制、背景移除)
    form: llm
  - name: prompt
    type: string
    required: true
    label:
      en_US: Prompt
      zh_Hans: 提示词
    human_description:
      en_US: Text description of the image you want to generate or modify
      zh_Hans: 您想要生成或修改的图像的文本描述
    llm_description: Describe the image you want to generate or how you want to modify the input image
    form: llm
  - name: image_input_s3uri
    type: string
    required: false
    label:
      en_US: Input image s3 uri
      zh_Hans: 输入图片的s3 uri
    human_description:
      en_US: The input image to modify (required for all modes except TEXT_IMAGE)
      zh_Hans: 要修改的输入图像(除文生图外的所有模式都需要)
    llm_description: The input image you want to modify. Required for all modes except TEXT_IMAGE.
    form: llm
  - name: image_output_s3uri
    type: string
    required: true
    label:
      en_US: Output S3 URI
      zh_Hans: 输出S3 URI
    human_description:
      en_US: The S3 URI where the generated image will be saved. If provided, the image will be uploaded with name format canvas-output-{timestamp}.png
      zh_Hans: 生成的图像将保存到的S3 URI。如果提供,图像将以canvas-output-{timestamp}.png的格式上传
    llm_description: Optional S3 URI where the generated image will be uploaded. The image will be saved with a timestamp-based filename.
    form: form
  - name: negative_prompt
    type: string
    required: false
    label:
      en_US: Negative Prompt
      zh_Hans: 负面提示词
    human_description:
      en_US: Things you don't want in the generated image
      zh_Hans: 您不想在生成的图像中出现的内容
    form: llm
  - name: width
    type: number
    required: false
    label:
      en_US: Width
      zh_Hans: 宽度
    human_description:
      en_US: Width of the generated image
      zh_Hans: 生成图像的宽度
    form: form
    default: 1024
  - name: height
    type: number
    required: false
    label:
      en_US: Height
      zh_Hans: 高度
    human_description:
      en_US: Height of the generated image
      zh_Hans: 生成图像的高度
    form: form
    default: 1024
  - name: cfg_scale
    type: number
    required: false
    label:
      en_US: CFG Scale
      zh_Hans: CFG比例
    human_description:
      en_US: How strongly the image should conform to the prompt
      zh_Hans: 图像应该多大程度上符合提示词
    form: form
    default: 8.0
  - name: seed
    type: number
    required: false
    label:
      en_US: Seed
      zh_Hans: 种子值
    human_description:
      en_US: Random seed for image generation
      zh_Hans: 图像生成的随机种子
    form: form
    default: 0
  - name: aws_region
    type: string
    required: false
    default: us-east-1
    label:
      en_US: AWS Region
      zh_Hans: AWS 区域
    human_description:
      en_US: AWS region for Bedrock service
      zh_Hans: Bedrock 服务的 AWS 区域
    form: form
  - name: quality
    type: string
    required: false
    default: standard
    label:
      en_US: Quality
      zh_Hans: 质量
    human_description:
      en_US: Quality of the generated image (standard or premium)
      zh_Hans: 生成图像的质量(标准或高级)
    form: form
  - name: colors
    type: string
    required: false
    label:
      en_US: Colors
      zh_Hans: 颜色
    human_description:
      en_US: List of colors for color-guided generation
      zh_Hans: 颜色引导生成的颜色列表
    form: form
  - name: similarity_strength
    type: number
    required: false
    default: 0.5
    label:
      en_US: Similarity Strength
      zh_Hans: 相似度强度
    human_description:
      en_US: How similar the generated image should be to the input image (0.0 to 1.0)
      zh_Hans: 生成的图像应该与输入图像的相似程度(0.0到1.0)
    form: form
  - name: mask_prompt
    type: string
    required: false
    label:
      en_US: Mask Prompt
      zh_Hans: 蒙版提示词
    human_description:
      en_US: Text description to generate mask for inpainting/outpainting
      zh_Hans: 用于生成内补绘制/外补绘制蒙版的文本描述
    form: llm
  - name: outpainting_mode
    type: string
    required: false
    default: DEFAULT
    label:
      en_US: Outpainting Mode
      zh_Hans: 外补绘制模式
    human_description:
      en_US: Mode for outpainting (DEFAULT or other supported modes)
      zh_Hans: 外补绘制的模式(DEFAULT或其他支持的模式)
    form: form
api/core/tools/provider/builtin/aws/tools/nova_reel.py (new file, 371 lines)
@ -0,0 +1,371 @@
import base64
import logging
import time
from io import BytesIO
from typing import Any, Optional, Union
from urllib.parse import urlparse

import boto3
from botocore.exceptions import ClientError
from PIL import Image

from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool.builtin_tool import BuiltinTool

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

NOVA_REEL_DEFAULT_REGION = "us-east-1"
NOVA_REEL_DEFAULT_DIMENSION = "1280x720"
NOVA_REEL_DEFAULT_FPS = 24
NOVA_REEL_DEFAULT_DURATION = 6
NOVA_REEL_MODEL_ID = "amazon.nova-reel-v1:0"
NOVA_REEL_STATUS_CHECK_INTERVAL = 5

# Image requirements
NOVA_REEL_REQUIRED_IMAGE_WIDTH = 1280
NOVA_REEL_REQUIRED_IMAGE_HEIGHT = 720
NOVA_REEL_REQUIRED_IMAGE_MODE = "RGB"


class NovaReelTool(BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Invoke AWS Bedrock Nova Reel model for video generation.

        Args:
            user_id: The ID of the user making the request
            tool_parameters: Dictionary containing the tool parameters

        Returns:
            ToolInvokeMessage containing either the video content or status information
        """
        try:
            # Validate and extract parameters
            params = self._validate_and_extract_parameters(tool_parameters)
            if isinstance(params, ToolInvokeMessage):
                return params

            # Initialize AWS clients
            bedrock, s3_client = self._initialize_aws_clients(params["aws_region"])

            # Prepare model input
            model_input = self._prepare_model_input(params, s3_client)
            if isinstance(model_input, ToolInvokeMessage):
                return model_input

            # Start video generation
            invocation = self._start_video_generation(bedrock, model_input, params["video_output_s3uri"])
            invocation_arn = invocation["invocationArn"]

            # Handle async/sync mode
            return self._handle_generation_mode(bedrock, s3_client, invocation_arn, params["async_mode"])

        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code", "Unknown")
            error_message = e.response.get("Error", {}).get("Message", str(e))
            logger.exception(f"AWS API error: {error_code} - {error_message}")
            return self.create_text_message(f"AWS service error: {error_code} - {error_message}")
        except Exception as e:
            logger.error(f"Unexpected error in video generation: {str(e)}", exc_info=True)
            return self.create_text_message(f"Failed to generate video: {str(e)}")

    def _validate_and_extract_parameters(
        self, tool_parameters: dict[str, Any]
    ) -> Union[dict[str, Any], ToolInvokeMessage]:
        """Validate and extract parameters from the input dictionary."""
        prompt = tool_parameters.get("prompt", "")
        video_output_s3uri = tool_parameters.get("video_output_s3uri", "").strip()

        # Validate required parameters
        if not prompt:
            return self.create_text_message("Please provide a text prompt for video generation.")
        if not video_output_s3uri:
            return self.create_text_message("Please provide an S3 URI for video output.")

        # Validate S3 URI format
        if not video_output_s3uri.startswith("s3://"):
            return self.create_text_message("Invalid S3 URI format. Must start with 's3://'")

        # Ensure S3 URI ends with '/'
        video_output_s3uri = video_output_s3uri if video_output_s3uri.endswith("/") else video_output_s3uri + "/"

        return {
            "prompt": prompt,
            "video_output_s3uri": video_output_s3uri,
            "image_input_s3uri": tool_parameters.get("image_input_s3uri", "").strip(),
            "aws_region": tool_parameters.get("aws_region", NOVA_REEL_DEFAULT_REGION),
            "dimension": tool_parameters.get("dimension", NOVA_REEL_DEFAULT_DIMENSION),
            "seed": int(tool_parameters.get("seed", 0)),
            "fps": int(tool_parameters.get("fps", NOVA_REEL_DEFAULT_FPS)),
            "duration": int(tool_parameters.get("duration", NOVA_REEL_DEFAULT_DURATION)),
            "async_mode": bool(tool_parameters.get("async", True)),
        }

    def _initialize_aws_clients(self, region: str) -> tuple[Any, Any]:
        """Initialize AWS Bedrock and S3 clients."""
        bedrock = boto3.client(service_name="bedrock-runtime", region_name=region)
        s3_client = boto3.client("s3", region_name=region)
        return bedrock, s3_client

    def _prepare_model_input(self, params: dict[str, Any], s3_client: Any) -> Union[dict[str, Any], ToolInvokeMessage]:
        """Prepare the input for the Nova Reel model."""
        model_input = {
            "taskType": "TEXT_VIDEO",
            "textToVideoParams": {"text": params["prompt"]},
            "videoGenerationConfig": {
                "durationSeconds": params["duration"],
                "fps": params["fps"],
                "dimension": params["dimension"],
                "seed": params["seed"],
            },
        }

        # Add image if provided
        if params["image_input_s3uri"]:
            try:
                image_data = self._get_image_from_s3(s3_client, params["image_input_s3uri"])
                if not image_data:
                    return self.create_text_message("Failed to retrieve image from S3")

                # Process and validate image
                processed_image = self._process_and_validate_image(image_data)
                if isinstance(processed_image, ToolInvokeMessage):
                    return processed_image

                # Convert processed image to base64
                img_buffer = BytesIO()
                processed_image.save(img_buffer, format="PNG")
                img_buffer.seek(0)
                input_image_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")

                model_input["textToVideoParams"]["images"] = [
                    {"format": "png", "source": {"bytes": input_image_base64}}
                ]
            except Exception as e:
                logger.error(f"Error processing input image: {str(e)}", exc_info=True)
                return self.create_text_message(f"Failed to process input image: {str(e)}")

        return model_input

    def _process_and_validate_image(self, image_data: bytes) -> Union[Image.Image, ToolInvokeMessage]:
        """
        Process and validate the input image according to Nova Reel requirements.

        Requirements:
        - Must be 1280x720 pixels
        - Must be RGB format (8 bits per channel)
        - If PNG, alpha channel must not have transparent/translucent pixels
        """
        try:
            # Open image
            img = Image.open(BytesIO(image_data))

            # Convert RGBA to RGB if needed, ensuring no transparency
            if img.mode == "RGBA":
                # Check for transparency
                if img.getchannel("A").getextrema()[0] < 255:
                    return self.create_text_message(
                        "PNG image contains transparent or translucent pixels, which is not supported. "
                        "Please provide an image without transparency."
                    )
                # Convert to RGB
                img = img.convert("RGB")
            elif img.mode != "RGB":
                # Convert any other mode to RGB
                img = img.convert("RGB")

            # Validate/adjust dimensions
            if img.size != (NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT):
                logger.warning(
                    f"Image dimensions {img.size} do not match required dimensions "
                    f"({NOVA_REEL_REQUIRED_IMAGE_WIDTH}x{NOVA_REEL_REQUIRED_IMAGE_HEIGHT}). Resizing..."
                )
                img = img.resize(
                    (NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT), Image.Resampling.LANCZOS
                )

            # Validate bit depth
            if img.mode != NOVA_REEL_REQUIRED_IMAGE_MODE:
                return self.create_text_message(
                    f"Image must be in {NOVA_REEL_REQUIRED_IMAGE_MODE} mode with 8 bits per channel"
                )

            return img

        except Exception as e:
            logger.error(f"Error processing image: {str(e)}", exc_info=True)
            return self.create_text_message(
                "Failed to process image. Please ensure the image is a valid JPEG or PNG file."
            )

    def _get_image_from_s3(self, s3_client: Any, s3_uri: str) -> Optional[bytes]:
        """Download and return image data from S3."""
        parsed_uri = urlparse(s3_uri)
        bucket = parsed_uri.netloc
        key = parsed_uri.path.lstrip("/")

        response = s3_client.get_object(Bucket=bucket, Key=key)
        return response["Body"].read()

    def _start_video_generation(self, bedrock: Any, model_input: dict[str, Any], output_s3uri: str) -> dict[str, Any]:
        """Start the async video generation process."""
        return bedrock.start_async_invoke(
            modelId=NOVA_REEL_MODEL_ID,
            modelInput=model_input,
            outputDataConfig={"s3OutputDataConfig": {"s3Uri": output_s3uri}},
        )

    def _handle_generation_mode(
        self, bedrock: Any, s3_client: Any, invocation_arn: str, async_mode: bool
    ) -> ToolInvokeMessage:
        """Handle async or sync video generation mode."""
        invocation_response = bedrock.get_async_invoke(invocationArn=invocation_arn)
        video_path = invocation_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"]
        video_uri = f"{video_path}/output.mp4"

        if async_mode:
            return self.create_text_message(
                f"Video generation started.\nInvocation ARN: {invocation_arn}\n"
                f"Video will be available at: {video_uri}"
            )

        return self._wait_for_completion(bedrock, s3_client, invocation_arn)

    def _wait_for_completion(self, bedrock: Any, s3_client: Any, invocation_arn: str) -> ToolInvokeMessage:
        """Wait for video generation completion and handle the result."""
        while True:
            status_response = bedrock.get_async_invoke(invocationArn=invocation_arn)
            status = status_response["status"]
            video_path = status_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"]

            if status == "Completed":
                return self._handle_completed_video(s3_client, video_path)
            elif status == "Failed":
                failure_message = status_response.get("failureMessage", "Unknown error")
                return self.create_text_message(f"Video generation failed.\nError: {failure_message}")
            elif status == "InProgress":
                time.sleep(NOVA_REEL_STATUS_CHECK_INTERVAL)
            else:
                return self.create_text_message(f"Unexpected status: {status}")

    def _handle_completed_video(self, s3_client: Any, video_path: str) -> ToolInvokeMessage:
        """Handle completed video generation and return the result."""
        parsed_uri = urlparse(video_path)
        bucket = parsed_uri.netloc
        key = parsed_uri.path.lstrip("/") + "/output.mp4"

        try:
            response = s3_client.get_object(Bucket=bucket, Key=key)
            video_content = response["Body"].read()
            return [
                self.create_text_message(f"Video is available at: {video_path}/output.mp4"),
                self.create_blob_message(blob=video_content, meta={"mime_type": "video/mp4"}, save_as="output.mp4"),
            ]
        except Exception as e:
            logger.error(f"Error downloading video: {str(e)}", exc_info=True)
            return self.create_text_message(
                f"Video generation completed but failed to download video: {str(e)}\n"
                f"Video is available at: s3://{bucket}/{key}"
            )

    def get_runtime_parameters(self) -> list[ToolParameter]:
        """Define the tool's runtime parameters."""
        parameters = [
            ToolParameter(
                name="prompt",
                label=I18nObject(en_US="Prompt", zh_Hans="提示词"),
                type=ToolParameter.ToolParameterType.STRING,
                required=True,
                form=ToolParameter.ToolParameterForm.LLM,
                human_description=I18nObject(
                    en_US="Text description of the video you want to generate", zh_Hans="您想要生成的视频的文本描述"
                ),
                llm_description="Describe the video you want to generate",
            ),
            ToolParameter(
                name="video_output_s3uri",
                label=I18nObject(en_US="Output S3 URI", zh_Hans="输出S3 URI"),
                type=ToolParameter.ToolParameterType.STRING,
                required=True,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(
                    en_US="S3 URI where the generated video will be stored", zh_Hans="生成的视频将存储的S3 URI"
                ),
            ),
            ToolParameter(
                name="dimension",
                label=I18nObject(en_US="Dimension", zh_Hans="尺寸"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                default=NOVA_REEL_DEFAULT_DIMENSION,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(en_US="Video dimensions (width x height)", zh_Hans="视频尺寸(宽 x 高)"),
            ),
            ToolParameter(
                name="duration",
                label=I18nObject(en_US="Duration", zh_Hans="时长"),
                type=ToolParameter.ToolParameterType.NUMBER,
                required=False,
                default=NOVA_REEL_DEFAULT_DURATION,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(en_US="Video duration in seconds", zh_Hans="视频时长(秒)"),
            ),
|
||||
ToolParameter(
|
||||
name="seed",
|
||||
label=I18nObject(en_US="Seed", zh_Hans="种子值"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=0,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(en_US="Random seed for video generation", zh_Hans="视频生成的随机种子"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="fps",
|
||||
label=I18nObject(en_US="FPS", zh_Hans="帧率"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=NOVA_REEL_DEFAULT_FPS,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="Frames per second for the generated video", zh_Hans="生成视频的每秒帧数"
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="async",
|
||||
label=I18nObject(en_US="Async Mode", zh_Hans="异步模式"),
|
||||
type=ToolParameter.ToolParameterType.BOOLEAN,
|
||||
required=False,
|
||||
default=True,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(
|
||||
en_US="Whether to run in async mode (return immediately) or sync mode (wait for completion)",
|
||||
zh_Hans="是否以异步模式运行(立即返回)或同步模式(等待完成)",
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="aws_region",
|
||||
label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
default=NOVA_REEL_DEFAULT_REGION,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="image_input_s3uri",
|
||||
label=I18nObject(en_US="Input Image S3 URI", zh_Hans="输入图像S3 URI"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(
|
||||
en_US="S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame",
|
||||
zh_Hans="用作第一帧的输入图像(1280x720 JPEG/PNG)的S3 URI",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
return parameters
|
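For orientation, a minimal sketch of the async flow the tool drives, written against boto3 directly. The modelInput shape and model id here follow the AWS Nova Reel user guide referenced in the YAML below, but they are assumptions for illustration, not values taken from this diff:

# Sketch only: direct Bedrock async invocation, assuming the documented payload shape.
import time
import boto3

bedrock = boto3.client("bedrock-runtime", region_name="us-east-1")
job = bedrock.start_async_invoke(
    modelId="amazon.nova-reel-v1:0",  # assumed model id
    modelInput={
        "taskType": "TEXT_VIDEO",
        "textToVideoParams": {"text": "A sunrise over mountains"},
        "videoGenerationConfig": {"durationSeconds": 6, "fps": 24, "dimension": "1280x720", "seed": 0},
    },
    outputDataConfig={"s3OutputDataConfig": {"s3Uri": "s3://my-bucket/nova-reel/"}},  # assumed bucket
)
# Poll until the job leaves InProgress, as _wait_for_completion does above.
while bedrock.get_async_invoke(invocationArn=job["invocationArn"])["status"] == "InProgress":
    time.sleep(5)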
124  api/core/tools/provider/builtin/aws/tools/nova_reel.yaml  Normal file
@ -0,0 +1,124 @@
identity:
  name: nova_reel
  author: AWS
  label:
    en_US: AWS Bedrock Nova Reel
    zh_Hans: AWS Bedrock Nova Reel
  icon: icon.svg
description:
  human:
    en_US: A tool for generating videos using AWS Bedrock's Nova Reel model. Supports text-to-video generation and image-to-video generation with customizable parameters like duration, FPS, and dimensions. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html
    zh_Hans: 使用 AWS Bedrock 的 Nova Reel 模型生成视频的工具。支持文本生成视频和图像生成视频功能,可自定义持续时间、帧率和尺寸等参数。输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html
  llm: Generate videos using AWS Bedrock's Nova Reel model with support for both text-to-video and image-to-video generation, allowing customization of video properties like duration, frame rate, and resolution.

parameters:
  - name: prompt
    type: string
    required: true
    label:
      en_US: Prompt
      zh_Hans: 提示词
    human_description:
      en_US: Text description of the video you want to generate
      zh_Hans: 您想要生成的视频的文本描述
    llm_description: Describe the video you want to generate
    form: llm

  - name: video_output_s3uri
    type: string
    required: true
    label:
      en_US: Output S3 URI
      zh_Hans: 输出S3 URI
    human_description:
      en_US: S3 URI where the generated video will be stored
      zh_Hans: 生成的视频将存储的S3 URI
    form: form

  - name: dimension
    type: string
    required: false
    default: 1280x720
    label:
      en_US: Dimension
      zh_Hans: 尺寸
    human_description:
      en_US: Video dimensions (width x height)
      zh_Hans: 视频尺寸(宽 x 高)
    form: form

  - name: duration
    type: number
    required: false
    default: 6
    label:
      en_US: Duration
      zh_Hans: 时长
    human_description:
      en_US: Video duration in seconds
      zh_Hans: 视频时长(秒)
    form: form

  - name: seed
    type: number
    required: false
    default: 0
    label:
      en_US: Seed
      zh_Hans: 种子值
    human_description:
      en_US: Random seed for video generation
      zh_Hans: 视频生成的随机种子
    form: form

  - name: fps
    type: number
    required: false
    default: 24
    label:
      en_US: FPS
      zh_Hans: 帧率
    human_description:
      en_US: Frames per second for the generated video
      zh_Hans: 生成视频的每秒帧数
    form: form

  - name: async
    type: boolean
    required: false
    default: true
    label:
      en_US: Async Mode
      zh_Hans: 异步模式
    human_description:
      en_US: Whether to run in async mode (return immediately) or sync mode (wait for completion)
      zh_Hans: 是否以异步模式运行(立即返回)或同步模式(等待完成)
    form: llm

  - name: aws_region
    type: string
    required: false
    default: us-east-1
    label:
      en_US: AWS Region
      zh_Hans: AWS 区域
    human_description:
      en_US: AWS region for Bedrock service
      zh_Hans: Bedrock 服务的 AWS 区域
    form: form

  - name: image_input_s3uri
    type: string
    required: false
    label:
      en_US: Input Image S3 URI
      zh_Hans: 输入图像S3 URI
    human_description:
      en_US: S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame
      zh_Hans: 用作第一帧的输入图像(1280x720 JPEG/PNG)的S3 URI
    form: llm

development:
  dependencies:
    - boto3
    - pillow
80  api/core/tools/provider/builtin/aws/tools/s3_operator.py  Normal file
@ -0,0 +1,80 @@
from typing import Any, Union
from urllib.parse import urlparse

import boto3

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class S3Operator(BuiltinTool):
    s3_client: Any = None

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke tools
        """
        try:
            # Initialize S3 client if not already done
            if not self.s3_client:
                aws_region = tool_parameters.get("aws_region")
                if aws_region:
                    self.s3_client = boto3.client("s3", region_name=aws_region)
                else:
                    self.s3_client = boto3.client("s3")

            # Parse S3 URI
            s3_uri = tool_parameters.get("s3_uri")
            if not s3_uri:
                return self.create_text_message("s3_uri parameter is required")

            parsed_uri = urlparse(s3_uri)
            if parsed_uri.scheme != "s3":
                return self.create_text_message("Invalid S3 URI format. Must start with 's3://'")

            bucket = parsed_uri.netloc
            # Remove leading slash from key
            key = parsed_uri.path.lstrip("/")

            operation_type = tool_parameters.get("operation_type", "read")
            generate_presign_url = tool_parameters.get("generate_presign_url", False)
            presign_expiry = int(tool_parameters.get("presign_expiry", 3600))  # default 1 hour

            if operation_type == "write":
                text_content = tool_parameters.get("text_content")
                if not text_content:
                    return self.create_text_message("text_content parameter is required for write operation")

                # Write content to S3
                self.s3_client.put_object(Bucket=bucket, Key=key, Body=text_content.encode("utf-8"))
                result = f"s3://{bucket}/{key}"

                # Generate presigned URL for the written object if requested
                if generate_presign_url:
                    result = self.s3_client.generate_presigned_url(
                        "get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry
                    )

            else:  # read operation
                # Get object from S3
                response = self.s3_client.get_object(Bucket=bucket, Key=key)
                result = response["Body"].read().decode("utf-8")

                # Generate presigned URL if requested
                if generate_presign_url:
                    result = self.s3_client.generate_presigned_url(
                        "get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry
                    )

            return self.create_text_message(text=result)

        except self.s3_client.exceptions.NoSuchBucket:
            return self.create_text_message(f"Bucket '{bucket}' does not exist")
        except self.s3_client.exceptions.NoSuchKey:
            return self.create_text_message(f"Object '{key}' does not exist in bucket '{bucket}'")
        except Exception as e:
            return self.create_text_message(f"Exception: {str(e)}")
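A hedged usage sketch of the operator above. In Dify the tool runs through the tool runtime; here it is instantiated directly, with a hypothetical bucket and key, purely for illustration:

# Illustrative only: direct invocation with assumed parameters (requires AWS credentials).
operator = S3Operator()
message = operator._invoke(
    user_id="user-1",
    tool_parameters={
        "s3_uri": "s3://my-bucket/notes/hello.txt",  # assumed bucket/key
        "aws_region": "us-east-1",
        "operation_type": "write",
        "text_content": "hello from dify",
        "generate_presign_url": True,  # returns a time-limited GET URL instead of the s3:// path
        "presign_expiry": 600,
    },
)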
98  api/core/tools/provider/builtin/aws/tools/s3_operator.yaml  Normal file
@ -0,0 +1,98 @@
identity:
  name: s3_operator
  author: AWS
  label:
    en_US: AWS S3 Operator
    zh_Hans: AWS S3 读写器
    pt_BR: AWS S3 Operator
  icon: icon.svg
description:
  human:
    en_US: AWS S3 Writer and Reader
    zh_Hans: 读写S3 bucket中的文件
    pt_BR: AWS S3 Writer and Reader
  llm: AWS S3 Writer and Reader
parameters:
  - name: text_content
    type: string
    required: false
    label:
      en_US: The text to write
      zh_Hans: 待写入的文本
      pt_BR: The text to write
    human_description:
      en_US: The text to write
      zh_Hans: 待写入的文本
      pt_BR: The text to write
    llm_description: The text to write
    form: llm
  - name: s3_uri
    type: string
    required: true
    label:
      en_US: s3 uri
      zh_Hans: s3 uri
      pt_BR: s3 uri
    human_description:
      en_US: s3 uri
      zh_Hans: s3 uri
      pt_BR: s3 uri
    llm_description: s3 uri
    form: llm
  - name: aws_region
    type: string
    required: true
    label:
      en_US: region of bucket
      zh_Hans: bucket 所在的region
      pt_BR: region of bucket
    human_description:
      en_US: region of bucket
      zh_Hans: bucket 所在的region
      pt_BR: region of bucket
    llm_description: region of bucket
    form: form
  - name: operation_type
    type: select
    required: true
    label:
      en_US: operation type
      zh_Hans: 操作类型
      pt_BR: operation type
    human_description:
      en_US: operation type
      zh_Hans: 操作类型
      pt_BR: operation type
    default: read
    options:
      - value: read
        label:
          en_US: read
          zh_Hans: 读
      - value: write
        label:
          en_US: write
          zh_Hans: 写
    form: form
  - name: generate_presign_url
    type: boolean
    required: false
    label:
      en_US: Generate presigned URL
      zh_Hans: 生成预签名URL
    human_description:
      en_US: Whether to generate a presigned URL for the S3 object
      zh_Hans: 是否生成S3对象的预签名URL
    default: false
    form: form
  - name: presign_expiry
    type: number
    required: false
    label:
      en_US: Presigned URL expiration time
      zh_Hans: 预签名URL有效期
    human_description:
      en_US: Expiration time in seconds for the presigned URL
      zh_Hans: 预签名URL的有效期(秒)
    default: 3600
    form: form
@ -1,32 +1,8 @@
from typing import Any

from core.file import FileTransferMethod, FileType
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from factories import file_factory


class VectorizerProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        mapping = {
            "transfer_method": FileTransferMethod.TOOL_FILE,
            "type": FileType.IMAGE,
            "id": "test_id",
            "url": "https://cloud.dify.ai/logo/logo-site.png",
        }
        test_img = file_factory.build_from_mapping(
            mapping=mapping,
            tenant_id="__test_123",
        )
        try:
            VectorizerTool().fork_tool_runtime(
                runtime={
                    "credentials": credentials,
                }
            ).invoke(
                user_id="",
                tool_parameters={"mode": "test", "image": test_img},
            )
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e))
        return
@ -210,7 +210,7 @@ class ApiTool(Tool):
            )
            return response
        else:
            raise ValueError(f"Invalid http method {self.method}")
            raise ValueError(f"Invalid http method {method}")

    def _convert_body_property_any_of(
        self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
@ -8,9 +8,10 @@ from mimetypes import guess_extension, guess_type
from typing import Optional, Union
from uuid import uuid4

from httpx import get
import httpx

from configs import dify_config
from core.helper import ssrf_proxy
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import MessageFile

@ -94,12 +95,11 @@ class ToolFileManager:
    ) -> ToolFile:
        # try to download image
        try:
            response = get(file_url)
            response = ssrf_proxy.get(file_url)
            response.raise_for_status()
            blob = response.content
        except Exception as e:
            logger.exception(f"Failed to download file from {file_url}")
            raise
        except httpx.TimeoutException as e:
            raise ValueError(f"timeout when downloading file from {file_url}")

        mimetype = guess_type(file_url)[0] or "octet/stream"
        extension = guess_extension(mimetype) or ".bin"
@ -21,6 +21,7 @@ from .variables import (
    ArrayNumberVariable,
    ArrayObjectVariable,
    ArrayStringVariable,
    ArrayVariable,
    FileVariable,
    FloatVariable,
    IntegerVariable,
@ -43,6 +44,7 @@ __all__ = [
    "ArraySegment",
    "ArrayStringSegment",
    "ArrayStringVariable",
    "ArrayVariable",
    "FileSegment",
    "FileVariable",
    "FloatSegment",
@ -10,6 +10,7 @@ from .segments import (
    ArrayFileSegment,
    ArrayNumberSegment,
    ArrayObjectSegment,
    ArraySegment,
    ArrayStringSegment,
    FileSegment,
    FloatSegment,
@ -52,19 +53,23 @@ class ObjectVariable(ObjectSegment, Variable):
    pass


class ArrayAnyVariable(ArrayAnySegment, Variable):
class ArrayVariable(ArraySegment, Variable):
    pass


class ArrayStringVariable(ArrayStringSegment, Variable):
class ArrayAnyVariable(ArrayAnySegment, ArrayVariable):
    pass


class ArrayNumberVariable(ArrayNumberSegment, Variable):
class ArrayStringVariable(ArrayStringSegment, ArrayVariable):
    pass


class ArrayObjectVariable(ArrayObjectSegment, Variable):
class ArrayNumberVariable(ArrayNumberSegment, ArrayVariable):
    pass


class ArrayObjectVariable(ArrayObjectSegment, ArrayVariable):
    pass
@ -33,7 +33,7 @@ class GraphRunSucceededEvent(BaseGraphEvent):

class GraphRunFailedEvent(BaseGraphEvent):
    error: str = Field(..., description="failed reason")
    exceptions_count: Optional[int] = Field(description="exception count", default=0)
    exceptions_count: int = Field(description="exception count", default=0)


class GraphRunPartialSucceededEvent(BaseGraphEvent):
@ -97,11 +97,10 @@ class NodeInIterationFailedEvent(BaseNodeEvent):
    error: str = Field(..., description="error")


class NodeRunRetryEvent(BaseNodeEvent):
class NodeRunRetryEvent(NodeRunStartedEvent):
    error: str = Field(..., description="error")
    retry_index: int = Field(..., description="which retry attempt is about to be performed")
    start_at: datetime = Field(..., description="retry start time")
    start_index: int = Field(..., description="retry start index")


###########################################
@ -4,6 +4,7 @@ from typing import Any, Optional, cast

from pydantic import BaseModel, Field

from configs import dify_config
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes import NodeType
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
@ -170,7 +171,9 @@ class Graph(BaseModel):
        for parallel in parallel_mapping.values():
            if parallel.parent_parallel_id:
                cls._check_exceed_parallel_limit(
                    parallel_mapping=parallel_mapping, level_limit=3, parent_parallel_id=parallel.parent_parallel_id
                    parallel_mapping=parallel_mapping,
                    level_limit=dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
                    parent_parallel_id=parallel.parent_parallel_id,
                )

        # init answer stream generate routes
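The change above replaces the hard-coded nesting limit of 3 with the new WORKFLOW_PARALLEL_DEPTH_LIMIT setting. The enforcement walks parent parallel links upward; a minimal sketch of that recursion, with simplified names that are not the exact Dify implementation:

# Sketch: count parallel nesting by following parent_parallel_id links.
def check_parallel_depth(parallel_mapping: dict, parent_parallel_id: str, level_limit: int, depth: int = 1) -> None:
    if depth >= level_limit:
        raise ValueError(f"parallel nesting exceeds the limit of {level_limit}")
    parent = parallel_mapping.get(parent_parallel_id)
    if parent is not None and parent.parent_parallel_id:
        check_parallel_depth(parallel_mapping, parent.parent_parallel_id, level_limit, depth + 1)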
@ -641,7 +641,6 @@ class GraphEngine:
                run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
                if node_instance.should_retry and retries < max_retries:
                    retries += 1
                    self.graph_runtime_state.node_run_steps += 1
                    route_node_state.node_run_result = run_result
                    yield NodeRunRetryEvent(
                        id=node_instance.id,
@ -649,14 +648,14 @@ class GraphEngine:
                        node_type=node_instance.node_type,
                        node_data=node_instance.node_data,
                        route_node_state=route_node_state,
                        error=run_result.error,
                        retry_index=retries,
                        predecessor_node_id=node_instance.previous_node_id,
                        parallel_id=parallel_id,
                        parallel_start_node_id=parallel_start_node_id,
                        parent_parallel_id=parent_parallel_id,
                        parent_parallel_start_node_id=parent_parallel_start_node_id,
                        error=run_result.error,
                        retry_index=retries,
                        start_at=retry_start_at,
                        start_index=self.graph_runtime_state.node_run_steps,
                    )
                    time.sleep(retry_interval)
                    continue
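Stripped of engine bookkeeping, the retry branch above is a standard bounded-retry loop: emit a retry event, sleep for the configured interval, and re-run the node until it succeeds or the retry budget is spent. A generic sketch of the pattern (not engine code):

# Sketch of the bounded retry pattern used above; `run` is any callable
# returning an object with an `ok` flag.
import time

def run_with_retries(run, max_retries: int, retry_interval: float):
    retries = 0
    while True:
        result = run()
        if result.ok or retries >= max_retries:
            return result
        retries += 1  # the real engine yields a NodeRunRetryEvent here
        time.sleep(retry_interval)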
@ -147,6 +147,8 @@ class AnswerStreamGeneratorRouter:
            reverse_edges = reverse_edge_mapping.get(current_node_id, [])
            for edge in reverse_edges:
                source_node_id = edge.source_node_id
                if source_node_id not in node_id_config_mapping:
                    continue
                source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
                source_node_data = node_id_config_mapping[source_node_id].get("data", {})
                if (
@ -60,7 +60,6 @@ class AnswerStreamProcessor(StreamProcessor):

                del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]

            # remove unreachable nodes
            self._remove_unreachable_nodes(event)

            # generate stream outputs
@ -1,3 +1,4 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator

@ -5,6 +6,8 @@ from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph

logger = logging.getLogger(__name__)


class StreamProcessor(ABC):
    def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
@ -31,13 +34,22 @@ class StreamProcessor(ABC):
        if run_result.edge_source_handle:
            reachable_node_ids = []
            unreachable_first_node_ids = []
            if finished_node_id not in self.graph.edge_mapping:
                logger.warning(f"node {finished_node_id} has no edge mapping")
                return
            for edge in self.graph.edge_mapping[finished_node_id]:
                if (
                    edge.run_condition
                    and edge.run_condition.branch_identify
                    and run_result.edge_source_handle == edge.run_condition.branch_identify
                ):
                    reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
                    # FIXME: branches can merge directly, so removing nodes here may
                    # shortcut the answer node. The line below is commented out for now;
                    # this has no effect on the answer node or the workflow. Re-enable it
                    # once there is a better solution. Issues: #11542 #9560 #10638 #10564

                    # reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
                    continue
                else:
                    unreachable_first_node_ids.append(edge.target_node_id)
@ -1,4 +1,4 @@
class BaseNodeError(Exception):
class BaseNodeError(ValueError):
    """Base class for node errors."""

    pass
@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union
from typing import Any, Optional

from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
@ -59,7 +59,7 @@ class CodeNode(BaseNode[CodeNodeData]):
            )

            # Transform result
            result = self._transform_result(result, self.node_data.outputs)
            result = self._transform_result(result=result, output_schema=self.node_data.outputs)
        except (CodeExecutionError, CodeNodeError) as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
@ -67,18 +67,17 @@ class CodeNode(BaseNode[CodeNodeData]):

        return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)

    def _check_string(self, value: str, variable: str) -> str:
    def _check_string(self, value: str | None, variable: str) -> str | None:
        """
        Check string
        :param value: value
        :param variable: variable
        :return:
        """
        if value is None:
            return None
        if not isinstance(value, str):
            if value is None:
                return None
            else:
                raise OutputValidationError(f"Output variable `{variable}` must be a string")
            raise OutputValidationError(f"Output variable `{variable}` must be a string")

        if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
            raise OutputValidationError(
@ -88,18 +87,17 @@ class CodeNode(BaseNode[CodeNodeData]):

        return value.replace("\x00", "")

    def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]:
    def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
        """
        Check number
        :param value: value
        :param variable: variable
        :return:
        """
        if value is None:
            return None
        if not isinstance(value, int | float):
            if value is None:
                return None
            else:
                raise OutputValidationError(f"Output variable `{variable}` must be a number")
            raise OutputValidationError(f"Output variable `{variable}` must be a number")

        if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
            raise OutputValidationError(
@ -118,14 +116,12 @@ class CodeNode(BaseNode[CodeNodeData]):
        return value

    def _transform_result(
        self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1
    ) -> dict:
        """
        Transform result
        :param result: result
        :param output_schema: output schema
        :return:
        """
        self,
        result: Mapping[str, Any],
        output_schema: Optional[dict[str, CodeNodeData.Output]],
        prefix: str = "",
        depth: int = 1,
    ):
        if depth > dify_config.CODE_MAX_DEPTH:
            raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
@ -1,6 +1,7 @@
import csv
import io
import json
import logging
import os
import tempfile

@ -8,12 +9,6 @@ import docx
import pandas as pd
import pypdfium2  # type: ignore
import yaml  # type: ignore
from unstructured.partition.api import partition_via_api
from unstructured.partition.email import partition_email
from unstructured.partition.epub import partition_epub
from unstructured.partition.msg import partition_msg
from unstructured.partition.ppt import partition_ppt
from unstructured.partition.pptx import partition_pptx

from configs import dify_config
from core.file import File, FileTransferMethod, file_manager
@ -28,6 +23,8 @@ from models.workflow import WorkflowNodeExecutionStatus
from .entities import DocumentExtractorNodeData
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError

logger = logging.getLogger(__name__)


class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
    """
@ -183,10 +180,43 @@ def _extract_text_from_pdf(file_content: bytes) -> str:


def _extract_text_from_doc(file_content: bytes) -> str:
    """
    Extract text from a DOC/DOCX file.
    For now, only paragraphs and tables are supported; add more if needed.
    """
    try:
        doc_file = io.BytesIO(file_content)
        doc = docx.Document(doc_file)
        return "\n".join([paragraph.text for paragraph in doc.paragraphs])
        text = []
        # Process paragraphs
        for paragraph in doc.paragraphs:
            if paragraph.text.strip():
                text.append(paragraph.text)

        # Process tables
        for table in doc.tables:
            # Table header
            try:
                # Tables may raise errors, so ignore failures and continue.
                if len(table.rows) > 0 and table.rows[0].cells is not None:
                    # Check if any cell in the table has text
                    has_content = False
                    for row in table.rows:
                        if any(cell.text.strip() for cell in row.cells):
                            has_content = True
                            break

                    if has_content:
                        markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n"
                        markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n"
                        for row in table.rows[1:]:
                            markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n"
                        text.append(markdown_table)
            except Exception as e:
                logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
                continue

        return "\n".join(text)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e

@ -256,6 +286,8 @@ def _extract_text_from_excel(file_content: bytes) -> str:


def _extract_text_from_ppt(file_content: bytes) -> str:
    from unstructured.partition.ppt import partition_ppt

    try:
        with io.BytesIO(file_content) as file:
            elements = partition_ppt(file=file)
@ -265,6 +297,9 @@ def _extract_text_from_ppt(file_content: bytes) -> str:


def _extract_text_from_pptx(file_content: bytes) -> str:
    from unstructured.partition.api import partition_via_api
    from unstructured.partition.pptx import partition_pptx

    try:
        if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
            with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
@ -287,6 +322,8 @@ def _extract_text_from_pptx(file_content: bytes) -> str:


def _extract_text_from_epub(file_content: bytes) -> str:
    from unstructured.partition.epub import partition_epub

    try:
        with io.BytesIO(file_content) as file:
            elements = partition_epub(file=file)
@ -296,6 +333,8 @@ def _extract_text_from_epub(file_content: bytes) -> str:


def _extract_text_from_eml(file_content: bytes) -> str:
    from unstructured.partition.email import partition_email

    try:
        with io.BytesIO(file_content) as file:
            elements = partition_email(file=file)
@ -305,6 +344,8 @@ def _extract_text_from_eml(file_content: bytes) -> str:


def _extract_text_from_msg(file_content: bytes) -> str:
    from unstructured.partition.msg import partition_msg

    try:
        with io.BytesIO(file_content) as file:
            elements = partition_msg(file=file)
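For a two-column DOCX table, the `_extract_text_from_doc` conversion above emits a GitHub-style markdown block along these lines (illustrative cell values, not taken from this diff):

| Name | Role |
| --- | --- |
| Alice | Engineer |
| Bob | Designer |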
@ -135,6 +135,8 @@ class EndStreamGeneratorRouter:
            reverse_edges = reverse_edge_mapping.get(current_node_id, [])
            for edge in reverse_edges:
                source_node_id = edge.source_node_id
                if source_node_id not in node_id_config_mapping:
                    continue
                source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
                if source_node_type in {
                    NodeType.IF_ELSE.value,
@ -35,4 +35,4 @@ class FailBranchSourceHandle(StrEnum):


CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
RETRY_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.TOOL, NodeType.HTTP_REQUEST]
RETRY_ON_ERROR_NODE_TYPE = CONTINUE_ON_ERROR_NODE_TYPE
@ -39,15 +39,9 @@ class RunRetryEvent(BaseModel):
    start_at: datetime = Field(..., description="Retry start time")


class SingleStepRetryEvent(BaseModel):
class SingleStepRetryEvent(NodeRunResult):
    """Single step retry event"""

    status: str = WorkflowNodeExecutionStatus.RETRY.value

    inputs: dict | None = Field(..., description="input")
    error: str = Field(..., description="error")
    outputs: dict = Field(..., description="output")
    retry_index: int = Field(..., description="Retry attempt number")
    error: str = Field(..., description="error")
    elapsed_time: float = Field(..., description="elapsed time")
    execution_metadata: dict | None = Field(..., description="execution metadata")
@ -16,3 +16,7 @@ class InvalidHttpMethodError(HttpRequestNodeError):

class ResponseSizeError(HttpRequestNodeError):
    """Raised when the response size exceeds the allowed threshold."""


class RequestBodyError(HttpRequestNodeError):
    """Raised when the request body is invalid."""
@ -23,6 +23,7 @@ from .exc import (
    FileFetchError,
    HttpRequestNodeError,
    InvalidHttpMethodError,
    RequestBodyError,
    ResponseSizeError,
)

@ -143,13 +144,19 @@ class Executor:
            case "none":
                self.content = ""
            case "raw-text":
                if len(data) != 1:
                    raise RequestBodyError("raw-text body type should have exactly one item")
                self.content = self.variable_pool.convert_template(data[0].value).text
            case "json":
                if len(data) != 1:
                    raise RequestBodyError("json body type should have exactly one item")
                json_string = self.variable_pool.convert_template(data[0].value).text
                json_object = json.loads(json_string, strict=False)
                self.json = json_object
                # self.json = self._parse_object_contains_variables(json_object)
            case "binary":
                if len(data) != 1:
                    raise RequestBodyError("binary body type should have exactly one item")
                file_selector = data[0].file
                file_variable = self.variable_pool.get_file(file_selector)
                if file_variable is None:
@ -249,7 +256,7 @@ class Executor:
        # request_args = {k: v for k, v in request_args.items() if v is not None}
        try:
            response = getattr(ssrf_proxy, self.method)(**request_args)
        except ssrf_proxy.MaxRetriesExceededError as e:
        except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
            raise HttpRequestNodeError(str(e))
        return response

@ -317,6 +324,8 @@ class Executor:
        elif self.json:
            body = json.dumps(self.json)
        elif self.node_data.body.type == "raw-text":
            if len(self.node_data.body.data) != 1:
                raise RequestBodyError("raw-text body type should have exactly one item")
            body = self.node_data.body.data[0].value
        if body:
            raw += f"Content-Length: {len(body)}\r\n"
@ -20,7 +20,7 @@ from .entities import (
    HttpRequestNodeTimeout,
    Response,
)
from .exc import HttpRequestNodeError
from .exc import HttpRequestNodeError, RequestBodyError

HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
    connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
@ -136,9 +136,13 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
        data = node_data.body.data
        match body_type:
            case "binary":
                if len(data) != 1:
                    raise RequestBodyError("invalid body data, should have only one item")
                selector = data[0].file
                selectors.append(VariableSelector(variable="#" + ".".join(selector) + "#", value_selector=selector))
            case "json" | "raw-text":
                if len(data) != 1:
                    raise RequestBodyError("invalid body data, should have only one item")
                selectors += variable_template_parser.extract_selectors_from_template(data[0].key)
                selectors += variable_template_parser.extract_selectors_from_template(data[0].value)
            case "x-www-form-urlencoded":
@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
from flask import Flask, current_app

from configs import dify_config
from core.variables import IntegerVariable
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
from core.workflow.entities.node_entities import (
    NodeRunMetadataKey,
    NodeRunResult,
@ -75,12 +75,15 @@ class IterationNode(BaseNode[IterationNodeData]):
        """
        Run the node.
        """
        iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
        variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)

        if not iterator_list_segment:
            raise IteratorVariableNotFoundError(f"Iterator variable {self.node_data.iterator_selector} not found")
        if not variable:
            raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")

        if len(iterator_list_segment.value) == 0:
        if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
            raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")

        if isinstance(variable, NoneVariable) or len(variable.value) == 0:
            yield RunCompletedEvent(
                run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -89,7 +92,7 @@ class IterationNode(BaseNode[IterationNodeData]):
            )
            return

        iterator_list_value = iterator_list_segment.to_object()
        iterator_list_value = variable.to_object()

        if not isinstance(iterator_list_value, list):
            raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
@ -50,6 +50,7 @@ class PromptConfig(BaseModel):


class LLMNodeChatModelMessage(ChatModelMessage):
    text: str = ""
    jinja2_text: Optional[str] = None
@ -145,8 +145,8 @@ class LLMNode(BaseNode[LLMNodeData]):
            query = query_variable.text

        prompt_messages, stop = self._fetch_prompt_messages(
            user_query=query,
            user_files=files,
            sys_query=query,
            sys_files=files,
            context=context,
            memory=memory,
            model_config=model_config,
@ -545,8 +545,8 @@ class LLMNode(BaseNode[LLMNodeData]):
    def _fetch_prompt_messages(
        self,
        *,
        user_query: str | None = None,
        user_files: Sequence["File"],
        sys_query: str | None = None,
        sys_files: Sequence["File"],
        context: str | None = None,
        memory: TokenBufferMemory | None = None,
        model_config: ModelConfigWithCredentialsEntity,
@ -562,7 +562,7 @@ class LLMNode(BaseNode[LLMNodeData]):
        if isinstance(prompt_template, list):
            # For chat model
            prompt_messages.extend(
                _handle_list_messages(
                self._handle_list_messages(
                    messages=prompt_template,
                    context=context,
                    jinja2_variables=jinja2_variables,
@ -581,14 +581,14 @@ class LLMNode(BaseNode[LLMNodeData]):
            prompt_messages.extend(memory_messages)

            # Add current query to the prompt messages
            if user_query:
            if sys_query:
                message = LLMNodeChatModelMessage(
                    text=user_query,
                    text=sys_query,
                    role=PromptMessageRole.USER,
                    edition_type="basic",
                )
                prompt_messages.extend(
                    _handle_list_messages(
                    self._handle_list_messages(
                        messages=[message],
                        context="",
                        jinja2_variables=[],
@ -635,24 +635,27 @@ class LLMNode(BaseNode[LLMNodeData]):
                    raise ValueError("Invalid prompt content type")

                # Add current query to the prompt message
                if user_query:
                if sys_query:
                    if prompt_content_type == str:
                        prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
                        prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query)
                        prompt_messages[0].content = prompt_content
                    elif prompt_content_type == list:
                        for content_item in prompt_content:
                            if content_item.type == PromptMessageContentType.TEXT:
                                content_item.data = user_query + "\n" + content_item.data
                                content_item.data = sys_query + "\n" + content_item.data
                    else:
                        raise ValueError("Invalid prompt content type")
        else:
            raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))

        if vision_enabled and user_files:
        # The sys_files will be deprecated later
        if vision_enabled and sys_files:
            file_prompts = []
            for file in user_files:
            for file in sys_files:
                file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
                file_prompts.append(file_prompt)
            # If last prompt is a user prompt, add files into its contents,
            # otherwise append a new user prompt
            if (
                len(prompt_messages) > 0
                and isinstance(prompt_messages[-1], UserPromptMessage)
@ -662,7 +665,7 @@ class LLMNode(BaseNode[LLMNodeData]):
            else:
                prompt_messages.append(UserPromptMessage(content=file_prompts))

        # Filter prompt messages
        # Remove empty messages and filter unsupported content
        filtered_prompt_messages = []
        for prompt_message in prompt_messages:
            if isinstance(prompt_message.content, list):
@ -846,6 +849,68 @@ class LLMNode(BaseNode[LLMNodeData]):
            },
        }

    def _handle_list_messages(
        self,
        *,
        messages: Sequence[LLMNodeChatModelMessage],
        context: Optional[str],
        jinja2_variables: Sequence[VariableSelector],
        variable_pool: VariablePool,
        vision_detail_config: ImagePromptMessageContent.DETAIL,
    ) -> Sequence[PromptMessage]:
        prompt_messages: list[PromptMessage] = []
        for message in messages:
            if message.edition_type == "jinja2":
                result_text = _render_jinja2_message(
                    template=message.jinja2_text or "",
                    jinjia2_variables=jinja2_variables,
                    variable_pool=variable_pool,
                )
                prompt_message = _combine_message_content_with_role(
                    contents=[TextPromptMessageContent(data=result_text)], role=message.role
                )
                prompt_messages.append(prompt_message)
            else:
                # Get segment group from basic message
                if context:
                    template = message.text.replace("{#context#}", context)
                else:
                    template = message.text
                segment_group = variable_pool.convert_template(template)

                # Process segments for images
                file_contents = []
                for segment in segment_group.value:
                    if isinstance(segment, ArrayFileSegment):
                        for file in segment.value:
                            if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
                                file_content = file_manager.to_prompt_message_content(
                                    file, image_detail_config=vision_detail_config
                                )
                                file_contents.append(file_content)
                    elif isinstance(segment, FileSegment):
                        file = segment.value
                        if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
                            file_content = file_manager.to_prompt_message_content(
                                file, image_detail_config=vision_detail_config
                            )
                            file_contents.append(file_content)

                # Create message with text from all segments
                plain_text = segment_group.text
                if plain_text:
                    prompt_message = _combine_message_content_with_role(
                        contents=[TextPromptMessageContent(data=plain_text)], role=message.role
                    )
                    prompt_messages.append(prompt_message)

                if file_contents:
                    # Create message with image contents
                    prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
                    prompt_messages.append(prompt_message)

        return prompt_messages


def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
    match role:
@ -880,68 +945,6 @@ def _render_jinja2_message(
    return result_text


def _handle_list_messages(
    *,
    messages: Sequence[LLMNodeChatModelMessage],
    context: Optional[str],
    jinja2_variables: Sequence[VariableSelector],
    variable_pool: VariablePool,
    vision_detail_config: ImagePromptMessageContent.DETAIL,
) -> Sequence[PromptMessage]:
    prompt_messages = []
    for message in messages:
        if message.edition_type == "jinja2":
            result_text = _render_jinja2_message(
                template=message.jinja2_text or "",
                jinjia2_variables=jinja2_variables,
                variable_pool=variable_pool,
            )
            prompt_message = _combine_message_content_with_role(
                contents=[TextPromptMessageContent(data=result_text)], role=message.role
            )
            prompt_messages.append(prompt_message)
        else:
            # Get segment group from basic message
            if context:
                template = message.text.replace("{#context#}", context)
            else:
                template = message.text
            segment_group = variable_pool.convert_template(template)

            # Process segments for images
            file_contents = []
            for segment in segment_group.value:
                if isinstance(segment, ArrayFileSegment):
                    for file in segment.value:
                        if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
                            file_content = file_manager.to_prompt_message_content(
                                file, image_detail_config=vision_detail_config
                            )
                            file_contents.append(file_content)
                if isinstance(segment, FileSegment):
                    file = segment.value
                    if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
                        file_content = file_manager.to_prompt_message_content(
                            file, image_detail_config=vision_detail_config
                        )
                        file_contents.append(file_content)

            # Create message with text from all segments
            plain_text = segment_group.text
            if plain_text:
                prompt_message = _combine_message_content_with_role(
                    contents=[TextPromptMessageContent(data=plain_text)], role=message.role
                )
                prompt_messages.append(prompt_message)

            if file_contents:
                # Create message with image contents
                prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
                prompt_messages.append(prompt_message)

    return prompt_messages


def _calculate_rest_token(
    *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> int:
@ -179,6 +179,15 @@ class ParameterExtractorNode(LLMNode):
                error=str(e),
                metadata={},
            )
        except Exception as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=inputs,
                process_data=process_data,
                outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)},
                error=str(e),
                metadata={},
            )

        error = None
@ -1,10 +1,8 @@
import json
import logging
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.llm_generator.output_parser.errors import OutputParserError
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
@ -86,37 +84,38 @@ class QuestionClassifierNode(LLMNode):
        )
        prompt_messages, stop = self._fetch_prompt_messages(
            prompt_template=prompt_template,
            user_query=query,
            sys_query=query,
            memory=memory,
            model_config=model_config,
            user_files=files,
            sys_files=files,
            vision_enabled=node_data.vision.enabled,
            vision_detail=node_data.vision.configs.detail,
            variable_pool=variable_pool,
            jinja2_variables=[],
        )

        # handle invoke result
        generator = self._invoke_llm(
            node_data_model=node_data.model,
            model_instance=model_instance,
            prompt_messages=prompt_messages,
            stop=stop,
        )

        result_text = ""
        usage = LLMUsage.empty_usage()
        finish_reason = None
        for event in generator:
            if isinstance(event, ModelInvokeCompletedEvent):
                result_text = event.text
                usage = event.usage
                finish_reason = event.finish_reason
                break

        category_name = node_data.classes[0].name
        category_id = node_data.classes[0].id
        try:
            # handle invoke result
            generator = self._invoke_llm(
                node_data_model=node_data.model,
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                stop=stop,
            )

            for event in generator:
                if isinstance(event, ModelInvokeCompletedEvent):
                    result_text = event.text
                    usage = event.usage
                    finish_reason = event.finish_reason
                    break

            category_name = node_data.classes[0].name
            category_id = node_data.classes[0].id
            result_text_json = parse_and_check_json_markdown(result_text, [])
            # result_text_json = json.loads(result_text.strip('```JSON\n'))
            if "category_name" in result_text_json and "category_id" in result_text_json:
@ -127,10 +126,6 @@ class QuestionClassifierNode(LLMNode):
                if category_id_result in category_ids:
                    category_name = classes_map[category_id_result]
                    category_id = category_id_result

        except OutputParserError:
            logging.exception(f"Failed to parse result text: {result_text}")
        try:
            process_data = {
                "model_mode": model_config.mode,
                "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
@ -154,7 +149,6 @@ class QuestionClassifierNode(LLMNode):
                },
                llm_usage=usage,
            )

        except ValueError as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
@ -1,5 +1,6 @@
from collections.abc import Mapping, Sequence
from typing import Any
from uuid import UUID

from sqlalchemy import select
from sqlalchemy.orm import Session
@ -231,6 +232,10 @@ class ToolNode(BaseNode[ToolNodeData]):
                url = str(response.message)
                transfer_method = FileTransferMethod.TOOL_FILE
                tool_file_id = url.split("/")[-1].split(".")[0]
                try:
                    UUID(tool_file_id)
                except ValueError:
                    raise ToolFileError(f"cannot extract tool file id from url {url}")
                with Session(db.engine) as session:
                    stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
                    tool_file = session.scalar(stmt)
@ -1,4 +1,4 @@
class VariableOperatorNodeError(Exception):
class VariableOperatorNodeError(ValueError):
    """Base error type, don't use directly."""

    pass
@ -1,18 +1,5 @@
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import MetaData

from dify_app import DifyApp

POSTGRES_INDEXES_NAMING_CONVENTION = {
    "ix": "%(column_0_label)s_idx",
    "uq": "%(table_name)s_%(column_0_name)s_key",
    "ck": "%(table_name)s_%(constraint_name)s_check",
    "fk": "%(table_name)s_%(column_0_name)s_fkey",
    "pk": "%(table_name)s_pkey",
}

metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION)
db = SQLAlchemy(metadata=metadata)
from models import db


def init_app(app: DifyApp):
@ -3,4 +3,3 @@ from dify_app import DifyApp

def init_app(app: DifyApp):
    from events import event_handlers  # noqa: F401
    from models import account, dataset, model, source, task, tool, tools, web  # noqa: F401
@ -67,7 +67,9 @@ class AwsS3Storage(BaseStorage):
            yield from response["Body"].iter_chunks()
        except ClientError as ex:
            if ex.response["Error"]["Code"] == "NoSuchKey":
                raise FileNotFoundError("File not found")
                raise FileNotFoundError("file not found")
            elif "reached max retries" in str(ex):
                raise ValueError("please do not request the same file too frequently")
            else:
                raise
@ -116,8 +116,11 @@ def _build_from_local_file(
    tenant_id: str,
    transfer_method: FileTransferMethod,
) -> File:
    upload_file_id = mapping.get("upload_file_id")
    if not upload_file_id:
        raise ValueError("Invalid upload file id")
    stmt = select(UploadFile).where(
        UploadFile.id == mapping.get("upload_file_id"),
        UploadFile.id == upload_file_id,
        UploadFile.tenant_id == tenant_id,
    )

@ -139,6 +142,7 @@ def _build_from_local_file(
        remote_url=row.source_url,
        related_id=mapping.get("upload_file_id"),
        size=row.size,
        storage_key=row.key,
    )

@ -168,6 +172,7 @@ def _build_from_remote_url(
        mime_type=mime_type,
        extension=extension,
        size=file_size,
        storage_key="",
    )

@ -220,6 +225,7 @@ def _build_from_tool_file(
        extension=extension,
        mime_type=tool_file.mimetype,
        size=tool_file.size,
        storage_key=tool_file.file_key,
    )
@ -85,7 +85,7 @@ message_detail_fields = {
}

feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer}

status_count_fields = {"success": fields.Integer, "failed": fields.Integer, "partial_success": fields.Integer}
model_config_fields = {
    "opening_statement": fields.String,
    "suggested_questions": fields.Raw,
@ -166,6 +166,7 @@ conversation_with_summary_fields = {
    "message_count": fields.Integer,
    "user_feedback_stats": fields.Nested(feedback_stat_fields),
    "admin_feedback_stats": fields.Nested(feedback_stat_fields),
    "status_count": fields.Nested(status_count_fields),
}

conversation_with_summary_pagination_fields = {
@ -82,13 +82,15 @@ workflow_run_detail_fields = {
}

retry_event_field = {
    "elapsed_time": fields.Float,
    "status": fields.String,
    "inputs": fields.Raw(attribute="inputs"),
    "process_data": fields.Raw(attribute="process_data"),
    "outputs": fields.Raw(attribute="outputs"),
    "metadata": fields.Raw(attribute="metadata"),
    "llm_usage": fields.Raw(attribute="llm_usage"),
    "error": fields.String,
    "retry_index": fields.Integer,
    "inputs": fields.Raw(attribute="inputs"),
    "elapsed_time": fields.Float,
    "execution_metadata": fields.Raw(attribute="execution_metadata_dict"),
    "status": fields.String,
    "outputs": fields.Raw(attribute="outputs"),
}
@ -112,7 +114,6 @@ workflow_run_node_execution_fields = {
    "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
    "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
    "finished_at": TimestampField,
    "retry_events": fields.List(fields.Nested(retry_event_field)),
}

workflow_run_node_execution_list_fields = {
@ -13,7 +13,7 @@ from typing import Any, Optional, Union, cast
from zoneinfo import available_timezones

from flask import Response, stream_with_context
from flask_restful import fields  # type: ignore
from flask_restful import fields

from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
@ -27,7 +27,7 @@ def parse_json_markdown(json_string: str) -> dict:
        extracted_content = json_string[start_index:end_index].strip()
        parsed = json.loads(extracted_content)
    else:
        raise Exception("Could not find JSON block in the output.")
        raise ValueError("could not find json block in the output.")

    return parsed
@ -36,10 +36,10 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict:
    try:
        json_obj = parse_json_markdown(text)
    except json.JSONDecodeError as e:
        raise OutputParserError(f"Got invalid JSON object. Error: {e}")
        raise OutputParserError(f"got invalid json object. error: {e}")
    for key in expected_keys:
        if key not in json_obj:
            raise OutputParserError(
                f"Got invalid return object. Expected key `{key}` to be present, but got {json_obj}"
                f"got invalid return object. expected key `{key}` to be present, but got {json_obj}"
            )
    return json_obj
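A hedged usage sketch of the two parsers above (the markdown string is a made-up example; a missing expected key raises OutputParserError):

    text = '```json\n{"action": "search", "query": "dify"}\n```'
    result = parse_and_check_json_markdown(text, expected_keys=["action", "query"])
    assert result == {"action": "search", "query": "dify"}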
@ -1,33 +0,0 @@
"""add retry_index field to node-execution model

Revision ID: 348cb0a93d53
Revises: cf8f4fc45278
Create Date: 2024-12-16 01:23:13.093432

"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '348cb0a93d53'
down_revision = 'cf8f4fc45278'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
        batch_op.add_column(sa.Column('retry_index', sa.Integer(), server_default=sa.text('0'), nullable=True))

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
        batch_op.drop_column('retry_index')

    # ### end Alembic commands ###
@ -0,0 +1,39 @@
"""remove unused tool_providers

Revision ID: 11b07f66c737
Revises: cf8f4fc45278
Create Date: 2024-12-19 17:46:25.780116

"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '11b07f66c737'
down_revision = 'cf8f4fc45278'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('tool_providers')
    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('tool_providers',
        sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False),
        sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False),
        sa.Column('tool_name', sa.VARCHAR(length=40), autoincrement=False, nullable=False),
        sa.Column('encrypted_credentials', sa.TEXT(), autoincrement=False, nullable=True),
        sa.Column('is_enabled', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False),
        sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False),
        sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False),
        sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
        sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
    )
    # ### end Alembic commands ###
@ -0,0 +1,37 @@
"""add retry_index field to node-execution model
Revision ID: e1944c35e15e
Revises: 11b07f66c737
Create Date: 2024-12-20 06:28:30.287197
"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'e1944c35e15e'
down_revision = '11b07f66c737'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###

    # We no longer need this field, but this file has already been merged into the main branch,
    # so it is kept for the sake of migration history; the change is reverted in the next migration.
    # with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
    #     batch_op.add_column(sa.Column('retry_index', sa.Integer(), server_default=sa.text('0'), nullable=True))

    pass

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    # with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
    #     batch_op.drop_column('retry_index')
    pass

    # ### end Alembic commands ###
@ -0,0 +1,34 @@
"""remove workflow_node_executions.retry_index if exists

Revision ID: d7999dfa4aae
Revises: e1944c35e15e
Create Date: 2024-12-23 11:54:15.344543

"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy import inspect


# revision identifiers, used by Alembic.
revision = 'd7999dfa4aae'
down_revision = 'e1944c35e15e'
branch_labels = None
depends_on = None


def upgrade():
    # Check whether the column exists before attempting to remove it
    conn = op.get_bind()
    inspector = inspect(conn)
    has_column = 'retry_index' in [col['name'] for col in inspector.get_columns('workflow_node_executions')]

    if has_column:
        with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
            batch_op.drop_column('retry_index')


def downgrade():
    # No downgrade needed, as we do not want to restore the column
    pass
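The inspector check is what makes this migration safe to run whether or not the earlier revision added the column. A standalone sketch of the same API (the in-memory SQLite engine is a placeholder so the snippet runs anywhere, not Dify's PostgreSQL):

    from sqlalchemy import Column, Integer, MetaData, Table, create_engine, inspect

    engine = create_engine("sqlite://")  # placeholder database for illustration
    Table("workflow_node_executions", MetaData(), Column("retry_index", Integer)).create(engine)
    columns = {col["name"] for col in inspect(engine).get_columns("workflow_node_executions")}
    print("retry_index" in columns)  # True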
@ -1,53 +1,187 @@
from .account import Account, AccountIntegrate, InvitationCode, Tenant
from .dataset import Dataset, DatasetProcessRule, Document, DocumentSegment
from .account import (
    Account,
    AccountIntegrate,
    AccountStatus,
    InvitationCode,
    Tenant,
    TenantAccountJoin,
    TenantAccountJoinRole,
    TenantAccountRole,
    TenantStatus,
)
from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from .dataset import (
    AppDatasetJoin,
    Dataset,
    DatasetCollectionBinding,
    DatasetKeywordTable,
    DatasetPermission,
    DatasetPermissionEnum,
    DatasetProcessRule,
    DatasetQuery,
    Document,
    DocumentSegment,
    Embedding,
    ExternalKnowledgeApis,
    ExternalKnowledgeBindings,
    TidbAuthBinding,
    Whitelist,
)
from .engine import db
from .enums import CreatedByRole, UserFrom, WorkflowRunTriggeredFrom
from .model import (
    ApiRequest,
    ApiToken,
    App,
    AppAnnotationHitHistory,
    AppAnnotationSetting,
    AppMode,
    AppModelConfig,
    Conversation,
    DatasetRetrieverResource,
    DifySetup,
    EndUser,
    IconType,
    InstalledApp,
    Message,
    MessageAgentThought,
    MessageAnnotation,
    MessageChain,
    MessageFeedback,
    MessageFile,
    OperationLog,
    RecommendedApp,
    Site,
    Tag,
    TagBinding,
    TraceAppConfig,
    UploadFile,
)
from .source import DataSourceOauthBinding
from .tools import ToolFile
from .provider import (
    LoadBalancingModelConfig,
    Provider,
    ProviderModel,
    ProviderModelSetting,
    ProviderOrder,
    ProviderQuotaType,
    ProviderType,
    TenantDefaultModel,
    TenantPreferredModelProvider,
)
from .source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
from .task import CeleryTask, CeleryTaskSet
from .tools import (
    ApiToolProvider,
    BuiltinToolProvider,
    PublishedAppTool,
    ToolConversationVariables,
    ToolFile,
    ToolLabelBinding,
    ToolModelInvoke,
    WorkflowToolProvider,
)
from .web import PinnedConversation, SavedMessage
from .workflow import (
    ConversationVariable,
    Workflow,
    WorkflowAppLog,
    WorkflowAppLogCreatedFrom,
    WorkflowNodeExecution,
    WorkflowNodeExecutionStatus,
    WorkflowNodeExecutionTriggeredFrom,
    WorkflowRun,
    WorkflowRunStatus,
    WorkflowType,
)

__all__ = [
    "APIBasedExtension",
    "APIBasedExtensionPoint",
    "Account",
    "AccountIntegrate",
    "AccountStatus",
    "ApiRequest",
    "ApiToken",
    "ApiToolProvider",  # Added
    "App",
    "AppAnnotationHitHistory",
    "AppAnnotationSetting",
    "AppDatasetJoin",
    "AppMode",
    "AppModelConfig",
    "BuiltinToolProvider",  # Added
    "CeleryTask",
    "CeleryTaskSet",
    "Conversation",
    "ConversationVariable",
    "CreatedByRole",
    "DataSourceApiKeyAuthBinding",
    "DataSourceOauthBinding",
    "Dataset",
    "DatasetCollectionBinding",
    "DatasetKeywordTable",
    "DatasetPermission",
    "DatasetPermissionEnum",
    "DatasetProcessRule",
    "DatasetQuery",
    "DatasetRetrieverResource",
    "DifySetup",
    "Document",
    "DocumentSegment",
    "Embedding",
    "EndUser",
    "ExternalKnowledgeApis",
    "ExternalKnowledgeBindings",
    "IconType",
    "InstalledApp",
    "InvitationCode",
    "LoadBalancingModelConfig",
    "Message",
    "MessageAgentThought",
    "MessageAnnotation",
    "MessageChain",
    "MessageFeedback",
    "MessageFile",
    "OperationLog",
    "PinnedConversation",
    "Provider",
    "ProviderModel",
    "ProviderModelSetting",
    "ProviderOrder",
    "ProviderQuotaType",
    "ProviderType",
    "PublishedAppTool",
    "RecommendedApp",
    "SavedMessage",
    "Site",
    "Tag",
    "TagBinding",
    "Tenant",
    "TenantAccountJoin",
    "TenantAccountJoinRole",
    "TenantAccountRole",
    "TenantDefaultModel",
    "TenantPreferredModelProvider",
    "TenantStatus",
    "TidbAuthBinding",
    "ToolConversationVariables",
    "ToolFile",
    "ToolLabelBinding",
    "ToolModelInvoke",
    "TraceAppConfig",
    "UploadFile",
    "UserFrom",
    "Whitelist",
    "Workflow",
    "WorkflowAppLog",
    "WorkflowAppLogCreatedFrom",
    "WorkflowNodeExecution",
    "WorkflowNodeExecutionStatus",
    "WorkflowNodeExecutionTriggeredFrom",
    "WorkflowRun",
    "WorkflowRunStatus",
    "WorkflowRunTriggeredFrom",
    "WorkflowToolProvider",
    "WorkflowType",
    "db",
]
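With the explicit re-exports and __all__ above, call sites can import from the package root without knowing the submodule layout; __all__ also pins what `from models import *` exposes. A hedged usage sketch:

    from models import Account, WorkflowRun, db  # resolved via models/__init__.py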
@ -2,9 +2,9 @@ import enum
import json

from flask_login import UserMixin
from sqlalchemy import func

from extensions.ext_database import db

from .engine import db
from .types import StringUUID
@ -31,11 +31,11 @@ class Account(UserMixin, db.Model):
    timezone = db.Column(db.String(255))
    last_login_at = db.Column(db.DateTime)
    last_login_ip = db.Column(db.String(255))
    last_active_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    last_active_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    status = db.Column(db.String(16), nullable=False, server_default=db.text("'active'::character varying"))
    initialized_at = db.Column(db.DateTime)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

    @property
    def is_password_set(self):
@ -99,11 +99,6 @@ class Account(UserMixin, db.Model):
            return db.session.query(Account).filter(Account.id == account_integrate.account_id).one_or_none()
        return None

    def get_integrates(self) -> list[db.Model]:
        ai = db.Model
        return db.session.query(ai).filter(ai.account_id == self.id).all()

    # check current_user.current_tenant.current_role in ['admin', 'owner']
    @property
    def is_admin_or_owner(self):
        return TenantAccountRole.is_privileged_role(self._current_tenant.current_role)
@ -188,8 +183,8 @@ class Tenant(db.Model):
    plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying"))
    status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
    custom_config = db.Column(db.Text)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

    def get_accounts(self) -> list[Account]:
        return (
@ -229,8 +224,8 @@ class TenantAccountJoin(db.Model):
    current = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    role = db.Column(db.String(16), nullable=False, server_default="normal")
    invited_by = db.Column(StringUUID, nullable=True)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class AccountIntegrate(db.Model):
@ -246,8 +241,8 @@ class AccountIntegrate(db.Model):
    provider = db.Column(db.String(16), nullable=False)
    open_id = db.Column(db.String(255), nullable=False)
    encrypted_token = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class InvitationCode(db.Model):
@ -266,4 +261,4 @@ class InvitationCode(db.Model):
    used_by_tenant_id = db.Column(StringUUID)
    used_by_account_id = db.Column(StringUUID)
    deprecated_at = db.Column(db.DateTime)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
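The recurring swap in these model hunks replaces a hand-written PostgreSQL default with SQLAlchemy's dialect-aware construct. A hedged sketch of what the new form compiles to (the demo table is illustrative only, not from this diff):

    from sqlalchemy import Column, DateTime, MetaData, Table, func
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.schema import CreateTable

    t = Table("demo", MetaData(), Column("created_at", DateTime, server_default=func.current_timestamp()))
    # Compiles to roughly: created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP
    print(CreateTable(t).compile(dialect=postgresql.dialect()))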
@ -1,7 +1,8 @@
import enum

from extensions.ext_database import db
from sqlalchemy import func

from .engine import db
from .types import StringUUID


@ -24,4 +25,4 @@ class APIBasedExtension(db.Model):
    name = db.Column(db.String(255), nullable=False)
    api_endpoint = db.Column(db.String(255), nullable=False)
    api_key = db.Column(db.Text, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@ -15,10 +15,10 @@ from sqlalchemy.dialects.postgresql import JSONB

from configs import dify_config
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from extensions.ext_storage import storage

from .account import Account
from .engine import db
from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID
@ -50,9 +50,9 @@ class Dataset(db.Model):
    indexing_technique = db.Column(db.String(255), nullable=True)
    index_struct = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    embedding_model = db.Column(db.String(255), nullable=True)
    embedding_model_provider = db.Column(db.String(255), nullable=True)
    collection_binding_id = db.Column(StringUUID, nullable=True)

@ -212,7 +212,7 @@ class DatasetProcessRule(db.Model):
    mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
    rules = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

    MODES = ["automatic", "custom"]
    PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]

@ -264,7 +264,7 @@ class Document(db.Model):
    created_from = db.Column(db.String(255), nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_api_request_id = db.Column(StringUUID, nullable=True)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

    # start processing
    processing_started_at = db.Column(db.DateTime, nullable=True)

@ -303,7 +303,7 @@ class Document(db.Model):
    archived_reason = db.Column(db.String(255), nullable=True)
    archived_by = db.Column(StringUUID, nullable=True)
    archived_at = db.Column(db.DateTime, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    doc_type = db.Column(db.String(40), nullable=True)
    doc_metadata = db.Column(db.JSON, nullable=True)
    doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))

@ -527,9 +527,9 @@ class DocumentSegment(db.Model):
    disabled_by = db.Column(StringUUID, nullable=True)
    status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    indexing_at = db.Column(db.DateTime, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)
    error = db.Column(db.Text, nullable=True)
@ -697,7 +697,7 @@ class Embedding(db.Model):
    )
    hash = db.Column(db.String(64), nullable=False)
    embedding = db.Column(db.LargeBinary, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying"))

    def set_embedding(self, embedding_data: list[float]):

@ -719,7 +719,7 @@ class DatasetCollectionBinding(db.Model):
    model_name = db.Column(db.String(255), nullable=False)
    type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
    collection_name = db.Column(db.String(64), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class TidbAuthBinding(db.Model):

@ -739,7 +739,7 @@ class TidbAuthBinding(db.Model):
    status = db.Column(db.String(255), nullable=False, server_default=db.text("CREATING"))
    account = db.Column(db.String(255), nullable=False)
    password = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class Whitelist(db.Model):

@ -751,7 +751,7 @@ class Whitelist(db.Model):
    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=True)
    category = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class DatasetPermission(db.Model):

@ -768,7 +768,7 @@ class DatasetPermission(db.Model):
    account_id = db.Column(StringUUID, nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class ExternalKnowledgeApis(db.Model):

@ -785,9 +785,9 @@ class ExternalKnowledgeApis(db.Model):
    tenant_id = db.Column(StringUUID, nullable=False)
    settings = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

    def to_dict(self):
        return {

@ -840,6 +840,6 @@ class ExternalKnowledgeBindings(db.Model):
    dataset_id = db.Column(StringUUID, nullable=False)
    external_knowledge_id = db.Column(db.Text, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
Some files were not shown because too many files have changed in this diff.