Mirror of https://github.com/langgenius/dify.git (synced 2026-02-18 00:56:19 +08:00)
Compare commits
91 Commits
| SHA1 |
|---|
| e88ea71aef |
| e0f1410b48 |
| c3c85276d1 |
| af2888d394 |
| d0dd8b7955 |
| 75bce2822e |
| 425cc1ea85 |
| 2bf33c4dd2 |
| dc19cd5d9d |
| dfc25dbdd0 |
| c4091c4c66 |
| ef95b1268e |
| e068bbec73 |
| 9cfd1c67b6 |
| 8978a6a3ff |
| 4e3d732934 |
| 03548cdfbc |
| 70dd69d533 |
| 453f324f54 |
| c1aa55f3ea |
| 4b1e13e982 |
| 4584eb3058 |
| 02a7ae15f9 |
| 74b1b60125 |
| 39df994ff9 |
| d9875fe232 |
| 26c10b9931 |
| 750662eb08 |
| 6b49889041 |
| 03ddee3663 |
| 10caab1729 |
| c6a72def88 |
| 21a31d7f8b |
| 2c4df108e5 |
| 5db8addcc6 |
| dd0e81d094 |
| 90f093eb67 |
| a056a9d601 |
| 2ad2a402fb |
| 3d07a94bd7 |
| 366857cd26 |
| 9578246bbb |
| 9ee9e9c6de |
| e22cc28114 |
| a227af3664 |
| 599d410d99 |
| 5e37ab60d8 |
| 0b06235527 |
| b8d42cdea7 |
| 455791b710 |
| 90323cd355 |
| c07d9e96ce |
| 810adb8a94 |
| 606aadb891 |
| 8f73670925 |
| 8c559d6231 |
| 786cb6859b |
| de8800f41a |
| 7a00798027 |
| 6ded06c6d9 |
| f53741c5b9 |
| 2681bafb76 |
| ac635c70cd |
| ef7e47d162 |
| 4211b9abbd |
| 0c0120ef27 |
| dacd457478 |
| 7b03a0316d |
| 52201d95b1 |
| e2cde628bb |
| 3335fa78fc |
| 7abc7fa573 |
| f6247fe67c |
| 996a9135f6 |
| 3599751f93 |
| 2d186e1e76 |
| bb2f46d7cc |
| 463fbe2680 |
| 95a7e50137 |
| 9d93ad1f16 |
| 44104797d6 |
| 1548501050 |
| de3911e930 |
| 5a8a901560 |
| 12d45e9114 |
| d057067543 |
| 560d375e0f |
| 3388d6636c |
| 2624a6dcd0 |
| b5c2785e10 |
| 493834d45d |
.github/workflows/api-tests.yml (vendored, 3 changed lines)

@@ -50,6 +50,9 @@ jobs:
       - name: Run ModelRuntime
         run: poetry run -C api bash dev/pytest/pytest_model_runtime.sh

+      - name: Run dify config tests
+        run: poetry run -C api python dev/pytest/pytest_config_tests.py
+
       - name: Run Tool
         run: poetry run -C api bash dev/pytest/pytest_tools.sh
@@ -399,6 +399,7 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
 WORKFLOW_MAX_EXECUTION_STEPS=500
 WORKFLOW_MAX_EXECUTION_TIME=1200
 WORKFLOW_CALL_MAX_DEPTH=5
+WORKFLOW_PARALLEL_DEPTH_LIMIT=3
 MAX_VARIABLE_SIZE=204800

 # App configuration
@@ -70,7 +70,6 @@ ignore = [
     "SIM113", # eumerate-for-loop
     "SIM117", # multiple-with-statements
     "SIM210", # if-expr-with-true-false
-    "SIM300", # yoda-conditions,
 ]

 [lint.per-file-ignores]
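Context for the rule being un-ignored here: Ruff's SIM300 ("yoda conditions") flags comparisons written with the constant-like operand on the left, and the `admin_required` hunk later in this set appears to be the matching code fix. A minimal sketch with hypothetical names:

```python
# Ruff's SIM300 flags comparisons with the constant-like operand on the left.
EXPECTED_KEY = "expected"  # hypothetical constant for illustration
supplied = "supplied"

if EXPECTED_KEY != supplied:  # flagged by SIM300 (constant first)
    pass

if supplied != EXPECTED_KEY:  # preferred ordering, as in admin_required below
    pass
```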
@@ -555,7 +555,8 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
     if language not in languages:
         language = "en-US"

-    name = name.strip()
+    # Validates name encoding for non-Latin characters.
+    name = name.strip().encode("utf-8").decode("utf-8") if name else None

     # generate random password
     new_password = secrets.token_urlsafe(16)
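The round trip through UTF-8 is effectively a validation step: encoding a `str` raises `UnicodeEncodeError` for lone surrogates, which can appear in badly decoded input, while any well-formed text passes through unchanged. A small illustration with a hypothetical malformed name:

```python
# Lone surrogates cannot be encoded as UTF-8, so the round trip rejects them.
name = "田中\ud800"  # hypothetical malformed input containing a lone surrogate
try:
    name.strip().encode("utf-8").decode("utf-8")
except UnicodeEncodeError as exc:
    print("rejected:", exc.reason)  # "surrogates not allowed"
```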
@@ -433,6 +433,11 @@ class WorkflowConfig(BaseSettings):
         default=5,
     )

+    WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field(
+        description="Maximum allowed depth for nested parallel executions",
+        default=3,
+    )
+
     MAX_VARIABLE_SIZE: PositiveInt = Field(
         description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
         default=200 * 1024,
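A minimal, self-contained sketch (not Dify's actual config module) of how a pydantic-settings field like this one picks up the `WORKFLOW_PARALLEL_DEPTH_LIMIT` variable added to the env file above:

```python
# Sketch only: pydantic-settings matches env vars to field names, so the
# value set in .env overrides the declared default of 3.
import os

from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings


class WorkflowConfig(BaseSettings):
    WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field(
        description="Maximum allowed depth for nested parallel executions",
        default=3,
    )


os.environ["WORKFLOW_PARALLEL_DEPTH_LIMIT"] = "5"
print(WorkflowConfig().WORKFLOW_PARALLEL_DEPTH_LIMIT)  # -> 5
```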
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):

     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="0.14.1",
+        default="0.14.2",
     )

     COMMIT_SHA: str = Field(
@@ -4,3 +4,8 @@ from werkzeug.exceptions import HTTPException
 class FilenameNotExistsError(HTTPException):
     code = 400
     description = "The specified filename does not exist."
+
+
+class RemoteFileUploadError(HTTPException):
+    code = 400
+    description = "Error uploading remote file."
@@ -31,7 +31,7 @@ def admin_required(view):
         if auth_scheme != "bearer":
             raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")

-        if dify_config.ADMIN_API_KEY != auth_token:
+        if auth_token != dify_config.ADMIN_API_KEY:
             raise Unauthorized("API key is invalid.")

         return view(*args, **kwargs)
@@ -6,6 +6,7 @@ from flask_restful import Resource, marshal_with, reqparse
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

 import services
+from configs import dify_config
 from controllers.console import api
 from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
 from controllers.console.app.wraps import get_app_model
@@ -426,7 +427,21 @@ class ConvertToWorkflowApi(Resource):
         }


+class WorkflowConfigApi(Resource):
+    """Resource for workflow configuration."""
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    def get(self, app_model: App):
+        return {
+            "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
+        }
+
+
 api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
+api.add_resource(WorkflowConfigApi, "/apps/<uuid:app_id>/workflows/draft/config")
 api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
 api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
 api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
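A hypothetical client call against the newly registered draft-config route; the host, port, path prefix, and bearer-style header are illustrative assumptions, since the console API normally authenticates a logged-in session rather than a plain API key:

```python
# Hypothetical usage of the new route registered above; URL and auth header
# are assumptions for illustration only, not taken from the diff.
import httpx

app_id = "00000000-0000-0000-0000-000000000000"  # placeholder app UUID
resp = httpx.get(
    f"http://localhost:5001/console/api/apps/{app_id}/workflows/draft/config",
    headers={"Authorization": "Bearer <console-access-token>"},  # assumed scheme
)
print(resp.json())  # expected shape: {"parallel_depth_limit": 3}
```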
@@ -5,8 +5,7 @@ from typing import Optional, Union
 from controllers.console.app.error import AppNotFoundError
 from extensions.ext_database import db
 from libs.login import current_user
-from models import App
-from models.model import AppMode
+from models import App, AppMode


 def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):
@@ -76,7 +76,7 @@ class OAuthCallback(Resource):
         try:
             token = oauth_provider.get_access_token(code)
             user_info = oauth_provider.get_user_info(token)
-        except requests.exceptions.HTTPError as e:
+        except requests.exceptions.RequestException as e:
             logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}")
             return {"error": "OAuth process failed"}, 400
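`RequestException` is the base class of `HTTPError`, `ConnectionError`, `Timeout`, and friends, so the widened clause also catches transport failures, not just bad HTTP status codes. One observation (about the library, not part of the diff): for non-HTTP failures `e.response` is `None`, as this sketch shows:

```python
# RequestException covers the whole requests error hierarchy; for transport
# errors such as ConnectionError, e.response is None.
import requests

try:
    requests.get("http://localhost:9/unreachable", timeout=0.1)  # typically refused
except requests.exceptions.RequestException as e:
    print(type(e).__name__, e.response)  # e.g. "ConnectionError None"
```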
@@ -1,12 +1,14 @@
 from flask_login import current_user
 from flask_restful import marshal_with, reqparse
 from flask_restful.inputs import int_range
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound

 from controllers.console import api
 from controllers.console.explore.error import NotChatAppError
 from controllers.console.explore.wraps import InstalledAppResource
 from core.app.entities.app_invoke_entities import InvokeFrom
+from extensions.ext_database import db
 from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
 from libs.helper import uuid_value
 from models.model import AppMode
@@ -34,14 +36,16 @@ class ConversationListApi(InstalledAppResource):
         pinned = True if args["pinned"] == "true" else False

         try:
-            return WebConversationService.pagination_by_last_id(
-                app_model=app_model,
-                user=current_user,
-                last_id=args["last_id"],
-                limit=args["limit"],
-                invoke_from=InvokeFrom.EXPLORE,
-                pinned=pinned,
-            )
+            with Session(db.engine) as session:
+                return WebConversationService.pagination_by_last_id(
+                    session=session,
+                    app_model=app_model,
+                    user=current_user,
+                    last_id=args["last_id"],
+                    limit=args["limit"],
+                    invoke_from=InvokeFrom.EXPLORE,
+                    pinned=pinned,
+                )
         except LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")
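The same pattern recurs in several hunks below: open a short-lived SQLAlchemy `Session` per request and pass it explicitly to the service layer instead of relying on a global scoped session. A minimal sketch, with a stub service function standing in for the real one:

```python
# Sketch of the session-per-request pattern; the service function here is a
# stand-in, not Dify's actual implementation.
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

engine = create_engine("sqlite:///:memory:")  # stand-in for db.engine


def pagination_by_last_id(*, session: Session, last_id: str | None, limit: int):
    # The real service would query Conversation rows through `session` here.
    return {"limit": limit, "has_more": False, "data": []}


with Session(engine) as session:
    page = pagination_by_last_id(session=session, last_id=None, limit=20)
# The session (and its connection) is released when the block exits.
```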
@@ -70,7 +70,7 @@ class MessageFeedbackApi(InstalledAppResource):
         args = parser.parse_args()

         try:
-            MessageService.create_feedback(app_model, message_id, current_user, args["rating"])
+            MessageService.create_feedback(app_model, message_id, current_user, args["rating"], args["content"])
         except services.errors.message.MessageNotExistsError:
             raise NotFound("Message Not Exists.")
@@ -13,6 +13,7 @@ app_fields = {
     "name": fields.String,
     "mode": fields.String,
     "icon": fields.String,
     "icon_type": fields.String,
+    "icon_url": AppIconUrlField,
     "icon_background": fields.String,
 }
@@ -7,6 +7,7 @@ from flask_restful import Resource, marshal_with, reqparse

 import services
 from controllers.common import helpers
+from controllers.common.errors import RemoteFileUploadError
 from core.file import helpers as file_helpers
 from core.helper import ssrf_proxy
 from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
@@ -43,10 +44,14 @@ class RemoteFileUploadApi(Resource):

         url = args["url"]

-        resp = ssrf_proxy.head(url=url)
-        if resp.status_code != httpx.codes.OK:
-            resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
-            resp.raise_for_status()
+        try:
+            resp = ssrf_proxy.head(url=url)
+            if resp.status_code != httpx.codes.OK:
+                resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
+            if resp.status_code != httpx.codes.OK:
+                raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
+        except httpx.RequestError as e:
+            raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")

         file_info = helpers.guess_file_info_from_response(resp)
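The error-translation pattern above converts httpx transport failures (`httpx.RequestError`) into a 400 `HTTPException` so the API returns a clean client error instead of an unhandled 500. A self-contained sketch of the same idea, using plain `httpx` in place of Dify's `ssrf_proxy` wrapper:

```python
# Sketch of the error-translation pattern; names mirror the diff, but plain
# httpx stands in for the ssrf_proxy helper.
import httpx
from werkzeug.exceptions import HTTPException


class RemoteFileUploadError(HTTPException):
    code = 400
    description = "Error uploading remote file."


def fetch_head(url: str) -> httpx.Response:
    try:
        resp = httpx.head(url, timeout=3)
        if resp.status_code != httpx.codes.OK:
            raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
        return resp
    except httpx.RequestError as e:
        # Connection errors, timeouts, etc. become a 400 instead of a 500.
        raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
```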
@@ -3,12 +3,14 @@ import io
 from flask import send_file
 from flask_login import current_user
 from flask_restful import Resource, reqparse
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden

 from configs import dify_config
 from controllers.console import api
 from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
 from core.model_runtime.utils.encoders import jsonable_encoder
+from extensions.ext_database import db
 from libs.helper import alphanumeric, uuid_value
 from libs.login import login_required
 from services.tools.api_tools_manage_service import ApiToolManageService
@@ -91,12 +93,16 @@ class ToolBuiltinProviderUpdateApi(Resource):

         args = parser.parse_args()

-        return BuiltinToolManageService.update_builtin_tool_provider(
-            user_id,
-            tenant_id,
-            provider,
-            args["credentials"],
-        )
+        with Session(db.engine) as session:
+            result = BuiltinToolManageService.update_builtin_tool_provider(
+                session=session,
+                user_id=user_id,
+                tenant_id=tenant_id,
+                provider_name=provider,
+                credentials=args["credentials"],
+            )
+            session.commit()
+        return result


 class ToolBuiltinProviderGetCredentialsApi(Resource):
@@ -104,13 +110,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, provider):
-        user_id = current_user.id
         tenant_id = current_user.current_tenant_id

         return BuiltinToolManageService.get_builtin_tool_provider_credentials(
-            user_id,
-            tenant_id,
-            provider,
+            tenant_id=tenant_id,
+            provider_name=provider,
         )
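One detail worth spelling out in the update hunk: `Session(engine)` used as a context manager does not commit on exit, it only closes (and thereby rolls back) the session, so the explicit `session.commit()` before leaving the block is what persists the credential update. A minimal sketch:

```python
# Closing a Session without commit() rolls back; the explicit commit inside
# the with-block is required, as in the hunk above.
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("sqlite:///:memory:")
with Session(engine) as session:
    session.execute(text("CREATE TABLE t (v TEXT)"))
    session.execute(text("INSERT INTO t VALUES ('x')"))
    session.commit()  # without this, the insert would be rolled back on exit
```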
@@ -1,5 +1,6 @@
 from flask_restful import Resource, marshal_with, reqparse
 from flask_restful.inputs import int_range
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound

 import services
@@ -7,6 +8,7 @@ from controllers.service_api import api
 from controllers.service_api.app.error import NotChatAppError
 from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.app.entities.app_invoke_entities import InvokeFrom
+from extensions.ext_database import db
 from fields.conversation_fields import (
     conversation_delete_fields,
     conversation_infinite_scroll_pagination_fields,
@@ -39,14 +41,16 @@ class ConversationApi(Resource):
         args = parser.parse_args()

         try:
-            return ConversationService.pagination_by_last_id(
-                app_model=app_model,
-                user=end_user,
-                last_id=args["last_id"],
-                limit=args["limit"],
-                invoke_from=InvokeFrom.SERVICE_API,
-                sort_by=args["sort_by"],
-            )
+            with Session(db.engine) as session:
+                return ConversationService.pagination_by_last_id(
+                    session=session,
+                    app_model=app_model,
+                    user=end_user,
+                    last_id=args["last_id"],
+                    limit=args["limit"],
+                    invoke_from=InvokeFrom.SERVICE_API,
+                    sort_by=args["sort_by"],
+                )
         except services.errors.conversation.LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")
@@ -104,10 +104,11 @@ class MessageFeedbackApi(Resource):

         parser = reqparse.RequestParser()
         parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
+        parser.add_argument("content", type=str, location="json")
         args = parser.parse_args()

         try:
-            MessageService.create_feedback(app_model, message_id, end_user, args["rating"])
+            MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"])
         except services.errors.message.MessageNotExistsError:
             raise NotFound("Message Not Exists.")
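With the new parser argument, a feedback request can carry free-text alongside the rating. A hedged usage example; the route and auth header are assumptions based on the service API's usual layout, not taken from the diff:

```python
# Hypothetical request against the extended feedback endpoint; "content" is
# the newly accepted optional field.
import httpx

httpx.post(
    "http://localhost:5001/v1/messages/<message-id>/feedbacks",  # assumed route
    headers={"Authorization": "Bearer <service-api-key>"},  # assumed auth
    json={"rating": "like", "content": "Accurate and well formatted."},
)
```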
@@ -1,11 +1,13 @@
 from flask_restful import marshal_with, reqparse
 from flask_restful.inputs import int_range
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound

 from controllers.web import api
 from controllers.web.error import NotChatAppError
 from controllers.web.wraps import WebApiResource
 from core.app.entities.app_invoke_entities import InvokeFrom
+from extensions.ext_database import db
 from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
 from libs.helper import uuid_value
 from models.model import AppMode
@@ -40,15 +42,17 @@ class ConversationListApi(WebApiResource):
         pinned = True if args["pinned"] == "true" else False

         try:
-            return WebConversationService.pagination_by_last_id(
-                app_model=app_model,
-                user=end_user,
-                last_id=args["last_id"],
-                limit=args["limit"],
-                invoke_from=InvokeFrom.WEB_APP,
-                pinned=pinned,
-                sort_by=args["sort_by"],
-            )
+            with Session(db.engine) as session:
+                return WebConversationService.pagination_by_last_id(
+                    session=session,
+                    app_model=app_model,
+                    user=end_user,
+                    last_id=args["last_id"],
+                    limit=args["limit"],
+                    invoke_from=InvokeFrom.WEB_APP,
+                    pinned=pinned,
+                    sort_by=args["sort_by"],
+                )
         except LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")
@@ -108,7 +108,7 @@ class MessageFeedbackApi(WebApiResource):
         args = parser.parse_args()

         try:
-            MessageService.create_feedback(app_model, message_id, end_user, args["rating"])
+            MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"])
         except services.errors.message.MessageNotExistsError:
             raise NotFound("Message Not Exists.")
@@ -5,6 +5,7 @@ from flask_restful import marshal_with, reqparse

 import services
 from controllers.common import helpers
+from controllers.common.errors import RemoteFileUploadError
 from controllers.web.wraps import WebApiResource
 from core.file import helpers as file_helpers
 from core.helper import ssrf_proxy
@@ -38,10 +39,14 @@ class RemoteFileUploadApi(WebApiResource):

         url = args["url"]

-        resp = ssrf_proxy.head(url=url)
-        if resp.status_code != httpx.codes.OK:
-            resp = ssrf_proxy.get(url=url, timeout=3)
-            resp.raise_for_status()
+        try:
+            resp = ssrf_proxy.head(url=url)
+            if resp.status_code != httpx.codes.OK:
+                resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
+            if resp.status_code != httpx.codes.OK:
+                raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
+        except httpx.RequestError as e:
+            raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")

         file_info = helpers.guess_file_info_from_response(resp)
@@ -4,14 +4,17 @@ import logging
 import queue
 import re
 import threading
+from collections.abc import Iterable

 from core.app.entities.queue_entities import (
+    MessageQueueMessage,
     QueueAgentMessageEvent,
     QueueLLMChunkEvent,
     QueueNodeSucceededEvent,
     QueueTextChunkEvent,
+    WorkflowQueueMessage,
 )
-from core.model_manager import ModelManager
+from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.model_entities import ModelType
@@ -21,7 +24,7 @@ class AudioTrunk:
         self.status = status


-def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
+def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str):
     if not text_content or text_content.isspace():
         return
     return model_instance.invoke_tts(
@@ -29,13 +32,19 @@ def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
     )


-def _process_future(future_queue, audio_queue):
+def _process_future(
+    future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None],
+    audio_queue: queue.Queue[AudioTrunk],
+):
     while True:
         try:
             future = future_queue.get()
             if future is None:
                 break
-            for audio in future.result():
+            invoke_result = future.result()
+            if not invoke_result:
+                continue
+            for audio in invoke_result:
                 audio_base64 = base64.b64encode(bytes(audio))
                 audio_queue.put(AudioTrunk("responding", audio=audio_base64))
         except Exception as e:
@@ -49,8 +58,8 @@ class AppGeneratorTTSPublisher:
         self.logger = logging.getLogger(__name__)
         self.tenant_id = tenant_id
         self.msg_text = ""
-        self._audio_queue = queue.Queue()
-        self._msg_queue = queue.Queue()
+        self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue()
+        self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
         self.match = re.compile(r"[。.!?]")
         self.model_manager = ModelManager()
         self.model_instance = self.model_manager.get_default_model_instance(
@@ -66,14 +75,11 @@ class AppGeneratorTTSPublisher:
         self._runtime_thread = threading.Thread(target=self._runtime).start()
         self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)

-    def publish(self, message):
-        try:
-            self._msg_queue.put(message)
-        except Exception as e:
-            self.logger.warning(e)
+    def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):
+        self._msg_queue.put(message)

     def _runtime(self):
-        future_queue = queue.Queue()
+        future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None] = queue.Queue()
         threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
         while True:
             try:
@@ -110,7 +116,7 @@ class AppGeneratorTTSPublisher:
                 break
         future_queue.put(None)

-    def check_and_get_audio(self):
+    def check_and_get_audio(self) -> AudioTrunk | None:
         try:
             if self._last_audio_event and self._last_audio_event.status == "finish":
                 if self.executor:
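A minimal sketch of the queue annotations introduced above: `queue.Queue` and `Future` are subscriptable at runtime since Python 3.9, and the `X | None` union syntax inside subscripts needs 3.10+, so these annotations type-check without changing runtime behavior.

```python
# Sketch (Python 3.10+): the annotations are static only; the objects remain
# ordinary queues, and None doubles as the shutdown sentinel as in
# _process_future above.
import queue
from collections.abc import Iterable
from concurrent.futures import Future

audio_queue: queue.Queue[bytes] = queue.Queue()
future_queue: queue.Queue[Future[Iterable[bytes] | None] | None] = queue.Queue()

future_queue.put(None)  # sentinel telling the consumer thread to stop
assert future_queue.get() is None
```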
@@ -22,6 +22,7 @@ from core.app.entities.queue_entities import (
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
+    QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,
@@ -179,7 +180,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             else:
                 continue

-        raise Exception("Queue listening stopped unexpectedly.")
+        raise ValueError("queue listening stopped unexpectedly.")

     def _to_stream_response(
         self, generator: Generator[StreamResponse, None, None]
@@ -196,11 +197,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             stream_response=stream_response,
         )

-    def _listen_audio_msg(self, publisher, task_id: str):
+    def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
         if not publisher:
             return None
-        audio_msg: AudioTrunk = publisher.check_and_get_audio()
-        if audio_msg and audio_msg.status != "finish":
+        audio_msg = publisher.check_and_get_audio()
+        if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
             return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
         return None

@@ -221,7 +222,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

         for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
             while True:
-                audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
+                audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
                 if audio_response:
                     yield audio_response
                 else:
@@ -290,9 +291,27 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 yield self._workflow_start_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                 )
+            elif isinstance(
+                event,
+                QueueNodeRetryEvent,
+            ):
+                if not workflow_run:
+                    raise ValueError("workflow run not initialized.")
+                workflow_node_execution = self._handle_workflow_node_execution_retried(
+                    workflow_run=workflow_run, event=event
+                )
+
+                response = self._workflow_node_retry_to_stream_response(
+                    event=event,
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_node_execution=workflow_node_execution,
+                )
+
+                if response:
+                    yield response
             elif isinstance(event, QueueNodeStartedEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                 workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)

@@ -330,47 +349,48 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

                 if response:
                     yield response

             elif isinstance(event, QueueParallelBranchRunStartedEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                 yield self._workflow_parallel_branch_start_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                 yield self._workflow_parallel_branch_finished_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueIterationStartEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                 yield self._workflow_iteration_start_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueIterationNextEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                 yield self._workflow_iteration_next_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueIterationCompletedEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                 yield self._workflow_iteration_completed_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueWorkflowSucceededEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                 if not graph_runtime_state:
-                    raise Exception("Graph runtime state not initialized.")
+                    raise ValueError("workflow run not initialized.")

                 workflow_run = self._handle_workflow_run_success(
                     workflow_run=workflow_run,
@@ -389,10 +409,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
             elif isinstance(event, QueueWorkflowPartialSuccessEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                 if not graph_runtime_state:
-                    raise Exception("Graph runtime state not initialized.")
+                    raise ValueError("graph runtime state not initialized.")

                 workflow_run = self._handle_workflow_run_partial_success(
                     workflow_run=workflow_run,
@@ -412,10 +432,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
             elif isinstance(event, QueueWorkflowFailedEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                 if not graph_runtime_state:
-                    raise Exception("Graph runtime state not initialized.")
+                    raise ValueError("graph runtime state not initialized.")

                 workflow_run = self._handle_workflow_run_failed(
                     workflow_run=workflow_run,
@@ -494,7 +514,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

                 # only publish tts message at text chunk streaming
                 if tts_publisher:
-                    tts_publisher.publish(message=queue_message)
+                    tts_publisher.publish(queue_message)

                 self._task_state.answer += delta_text
                 yield self._message_to_stream_response(
@@ -505,7 +525,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 yield self._message_replace_to_stream_response(answer=event.text)
             elif isinstance(event, QueueAdvancedChatMessageEndEvent):
                 if not graph_runtime_state:
-                    raise Exception("Graph runtime state not initialized.")
+                    raise ValueError("graph runtime state not initialized.")

                 output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
                 if output_moderation_answer:
@@ -1,7 +1,6 @@
 import queue
 import time
-from abc import abstractmethod
 from collections.abc import Generator
 from enum import Enum
 from typing import Any

@@ -11,9 +10,11 @@ from configs import dify_config
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import (
     AppQueueEvent,
+    MessageQueueMessage,
     QueueErrorEvent,
     QueuePingEvent,
     QueueStopEvent,
+    WorkflowQueueMessage,
 )
 from extensions.ext_redis import redis_client

@@ -37,11 +38,11 @@ class AppQueueManager:
             AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
         )

-        q = queue.Queue()
+        q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()

         self._q = q

-    def listen(self) -> Generator:
+    def listen(self):
         """
         Listen to queue
         :return:
@@ -18,6 +18,7 @@ from core.app.entities.queue_entities import (
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
+    QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,
@@ -154,7 +155,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
             else:
                 continue

-        raise Exception("Queue listening stopped unexpectedly.")
+        raise ValueError("queue listening stopped unexpectedly.")

     def _to_stream_response(
         self, generator: Generator[StreamResponse, None, None]
@@ -170,11 +171,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa

         yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)

-    def _listen_audio_msg(self, publisher, task_id: str):
+    def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
         if not publisher:
             return None
-        audio_msg: AudioTrunk = publisher.check_and_get_audio()
-        if audio_msg and audio_msg.status != "finish":
+        audio_msg = publisher.check_and_get_audio()
+        if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
             return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
         return None

@@ -195,7 +196,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa

         for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
             while True:
-                audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
+                audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
                 if audio_response:
                     yield audio_response
                 else:
@@ -217,7 +218,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                     break
                 else:
                     yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
-            except Exception as e:
+            except Exception:
                 logger.exception(f"Fails to get audio trunk, task_id: {task_id}")
                 break
         if tts_publisher:
@@ -253,9 +254,27 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 yield self._workflow_start_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                 )
+            elif isinstance(
+                event,
+                QueueNodeRetryEvent,
+            ):
+                if not workflow_run:
+                    raise ValueError("workflow run not initialized.")
+                workflow_node_execution = self._handle_workflow_node_execution_retried(
+                    workflow_run=workflow_run, event=event
+                )
+
+                response = self._workflow_node_retry_to_stream_response(
+                    event=event,
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_node_execution=workflow_node_execution,
+                )
+
+                if response:
+                    yield response
             elif isinstance(event, QueueNodeStartedEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                 workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)

@@ -286,50 +305,50 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution,
                 )

                 if node_failed_response:
                     yield node_failed_response

             elif isinstance(event, QueueParallelBranchRunStartedEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                 yield self._workflow_parallel_branch_start_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                 yield self._workflow_parallel_branch_finished_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueIterationStartEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                 yield self._workflow_iteration_start_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueIterationNextEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                 yield self._workflow_iteration_next_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueIterationCompletedEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                 yield self._workflow_iteration_completed_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                 )
             elif isinstance(event, QueueWorkflowSucceededEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                 if not graph_runtime_state:
-                    raise Exception("Graph runtime state not initialized.")
+                    raise ValueError("graph runtime state not initialized.")

                 workflow_run = self._handle_workflow_run_success(
                     workflow_run=workflow_run,
@@ -349,10 +368,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 )
             elif isinstance(event, QueueWorkflowPartialSuccessEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                 if not graph_runtime_state:
-                    raise Exception("Graph runtime state not initialized.")
+                    raise ValueError("graph runtime state not initialized.")

                 workflow_run = self._handle_workflow_run_partial_success(
                     workflow_run=workflow_run,
@@ -373,10 +392,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 )
             elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
                 if not workflow_run:
-                    raise Exception("Workflow run not initialized.")
+                    raise ValueError("workflow run not initialized.")

                 if not graph_runtime_state:
-                    raise Exception("Graph runtime state not initialized.")
+                    raise ValueError("graph runtime state not initialized.")
                 workflow_run = self._handle_workflow_run_failed(
                     workflow_run=workflow_run,
                     start_at=graph_runtime_state.start_at,
@@ -404,7 +423,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa

                 # only publish tts message at text chunk streaming
                 if tts_publisher:
-                    tts_publisher.publish(message=queue_message)
+                    tts_publisher.publish(queue_message)

                 self._task_state.answer += delta_text
                 yield self._text_chunk_to_stream_response(
@@ -11,6 +11,7 @@ from core.app.entities.queue_entities import (
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
+    QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,
@@ -38,6 +39,7 @@ from core.workflow.graph_engine.entities.event import (
     NodeRunExceptionEvent,
     NodeRunFailedEvent,
     NodeRunRetrieverResourceEvent,
+    NodeRunRetryEvent,
     NodeRunStartedEvent,
     NodeRunStreamChunkEvent,
     NodeRunSucceededEvent,
@@ -186,6 +188,41 @@ class WorkflowBasedAppRunner(AppRunner):
             )
         elif isinstance(event, GraphRunFailedEvent):
             self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
+        elif isinstance(event, NodeRunRetryEvent):
+            node_run_result = event.route_node_state.node_run_result
+            if node_run_result:
+                inputs = node_run_result.inputs
+                process_data = node_run_result.process_data
+                outputs = node_run_result.outputs
+                execution_metadata = node_run_result.metadata
+            else:
+                inputs = {}
+                process_data = {}
+                outputs = {}
+                execution_metadata = {}
+            self._publish_event(
+                QueueNodeRetryEvent(
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_data=event.node_data,
+                    parallel_id=event.parallel_id,
+                    parallel_start_node_id=event.parallel_start_node_id,
+                    parent_parallel_id=event.parent_parallel_id,
+                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    start_at=event.start_at,
+                    node_run_index=event.route_node_state.index,
+                    predecessor_node_id=event.predecessor_node_id,
+                    in_iteration_id=event.in_iteration_id,
+                    parallel_mode_run_id=event.parallel_mode_run_id,
+                    inputs=inputs,
+                    process_data=process_data,
+                    outputs=outputs,
+                    error=event.error,
+                    execution_metadata=execution_metadata,
+                    retry_index=event.retry_index,
+                )
+            )
         elif isinstance(event, NodeRunStartedEvent):
             self._publish_event(
                 QueueNodeStartedEvent(
@@ -205,6 +242,17 @@ class WorkflowBasedAppRunner(AppRunner):
                 )
             )
         elif isinstance(event, NodeRunSucceededEvent):
+            node_run_result = event.route_node_state.node_run_result
+            if node_run_result:
+                inputs = node_run_result.inputs
+                process_data = node_run_result.process_data
+                outputs = node_run_result.outputs
+                execution_metadata = node_run_result.metadata
+            else:
+                inputs = {}
+                process_data = {}
+                outputs = {}
+                execution_metadata = {}
             self._publish_event(
                 QueueNodeSucceededEvent(
                     node_execution_id=event.id,
@@ -216,18 +264,10 @@ class WorkflowBasedAppRunner(AppRunner):
                     parent_parallel_id=event.parent_parallel_id,
                     parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                     start_at=event.route_node_state.start_at,
-                    inputs=event.route_node_state.node_run_result.inputs
-                    if event.route_node_state.node_run_result
-                    else {},
-                    process_data=event.route_node_state.node_run_result.process_data
-                    if event.route_node_state.node_run_result
-                    else {},
-                    outputs=event.route_node_state.node_run_result.outputs
-                    if event.route_node_state.node_run_result
-                    else {},
-                    execution_metadata=event.route_node_state.node_run_result.metadata
-                    if event.route_node_state.node_run_result
-                    else {},
+                    inputs=inputs,
+                    process_data=process_data,
+                    outputs=outputs,
+                    execution_metadata=execution_metadata,
                     in_iteration_id=event.in_iteration_id,
                 )
             )
@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from datetime import datetime
 from enum import Enum, StrEnum
 from typing import Any, Optional
@@ -43,6 +44,7 @@ class QueueEvent(StrEnum):
     ERROR = "error"
     PING = "ping"
     STOP = "stop"
+    RETRY = "retry"


 class AppQueueEvent(BaseModel):
@@ -84,9 +86,9 @@ class QueueIterationStartEvent(AppQueueEvent):
     start_at: datetime

     node_run_index: int
-    inputs: Optional[dict[str, Any]] = None
+    inputs: Optional[Mapping[str, Any]] = None
     predecessor_node_id: Optional[str] = None
-    metadata: Optional[dict[str, Any]] = None
+    metadata: Optional[Mapping[str, Any]] = None


 class QueueIterationNextEvent(AppQueueEvent):
@@ -138,9 +140,9 @@ class QueueIterationCompletedEvent(AppQueueEvent):
     start_at: datetime

     node_run_index: int
-    inputs: Optional[dict[str, Any]] = None
-    outputs: Optional[dict[str, Any]] = None
-    metadata: Optional[dict[str, Any]] = None
+    inputs: Optional[Mapping[str, Any]] = None
+    outputs: Optional[Mapping[str, Any]] = None
+    metadata: Optional[Mapping[str, Any]] = None
     steps: int = 0

     error: Optional[str] = None
@@ -303,9 +305,9 @@ class QueueNodeSucceededEvent(AppQueueEvent):
     """iteration id if node is in iteration"""
     start_at: datetime

-    inputs: Optional[dict[str, Any]] = None
-    process_data: Optional[dict[str, Any]] = None
-    outputs: Optional[dict[str, Any]] = None
+    inputs: Optional[Mapping[str, Any]] = None
+    process_data: Optional[Mapping[str, Any]] = None
+    outputs: Optional[Mapping[str, Any]] = None
     execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None

     error: Optional[str] = None
@@ -313,6 +315,20 @@ class QueueNodeSucceededEvent(AppQueueEvent):
     iteration_duration_map: Optional[dict[str, float]] = None


+class QueueNodeRetryEvent(QueueNodeStartedEvent):
+    """QueueNodeRetryEvent entity"""
+
+    event: QueueEvent = QueueEvent.RETRY
+
+    inputs: Optional[Mapping[str, Any]] = None
+    process_data: Optional[Mapping[str, Any]] = None
+    outputs: Optional[Mapping[str, Any]] = None
+    execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
+
+    error: str
+    retry_index: int  # retry index
+
+
 class QueueNodeInIterationFailedEvent(AppQueueEvent):
     """
     QueueNodeInIterationFailedEvent entity
@@ -336,10 +352,10 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent):
     """iteration id if node is in iteration"""
     start_at: datetime

-    inputs: Optional[dict[str, Any]] = None
-    process_data: Optional[dict[str, Any]] = None
-    outputs: Optional[dict[str, Any]] = None
-    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
+    inputs: Optional[Mapping[str, Any]] = None
+    process_data: Optional[Mapping[str, Any]] = None
+    outputs: Optional[Mapping[str, Any]] = None
+    execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None

     error: str

@@ -367,10 +383,10 @@ class QueueNodeExceptionEvent(AppQueueEvent):
     """iteration id if node is in iteration"""
     start_at: datetime

-    inputs: Optional[dict[str, Any]] = None
-    process_data: Optional[dict[str, Any]] = None
-    outputs: Optional[dict[str, Any]] = None
-    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
+    inputs: Optional[Mapping[str, Any]] = None
+    process_data: Optional[Mapping[str, Any]] = None
+    outputs: Optional[Mapping[str, Any]] = None
+    execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None

     error: str

@@ -398,10 +414,10 @@ class QueueNodeFailedEvent(AppQueueEvent):
     """iteration id if node is in iteration"""
     start_at: datetime

-    inputs: Optional[dict[str, Any]] = None
-    process_data: Optional[dict[str, Any]] = None
-    outputs: Optional[dict[str, Any]] = None
-    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
+    inputs: Optional[Mapping[str, Any]] = None
+    process_data: Optional[Mapping[str, Any]] = None
+    outputs: Optional[Mapping[str, Any]] = None
+    execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None

     error: str
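A sketch of why these models switch `dict` to `Mapping`: `Mapping` is the read-only mapping protocol, so callers may pass immutable views and the field advertises that the event does not mutate the payload. Assuming Pydantic v2 semantics:

```python
# Sketch: a Mapping-typed field accepts any mapping, including an immutable
# MappingProxyType view, while a dict-typed field implies mutability.
from collections.abc import Mapping
from types import MappingProxyType
from typing import Any, Optional

from pydantic import BaseModel


class ExampleEvent(BaseModel):
    inputs: Optional[Mapping[str, Any]] = None


event = ExampleEvent(inputs=MappingProxyType({"query": "hello"}))
print(dict(event.inputs))  # {'query': 'hello'}
```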
@@ -52,6 +52,7 @@ class StreamEvent(Enum):
     WORKFLOW_FINISHED = "workflow_finished"
     NODE_STARTED = "node_started"
     NODE_FINISHED = "node_finished"
+    NODE_RETRY = "node_retry"
     PARALLEL_BRANCH_STARTED = "parallel_branch_started"
     PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
     ITERATION_STARTED = "iteration_started"
@@ -342,6 +343,75 @@ class NodeFinishStreamResponse(StreamResponse):
         }


+class NodeRetryStreamResponse(StreamResponse):
+    """
+    NodeFinishStreamResponse entity
+    """
+
+    class Data(BaseModel):
+        """
+        Data entity
+        """
+
+        id: str
+        node_id: str
+        node_type: str
+        title: str
+        index: int
+        predecessor_node_id: Optional[str] = None
+        inputs: Optional[dict] = None
+        process_data: Optional[dict] = None
+        outputs: Optional[dict] = None
+        status: str
+        error: Optional[str] = None
+        elapsed_time: float
+        execution_metadata: Optional[dict] = None
+        created_at: int
+        finished_at: int
+        files: Optional[Sequence[Mapping[str, Any]]] = []
+        parallel_id: Optional[str] = None
+        parallel_start_node_id: Optional[str] = None
+        parent_parallel_id: Optional[str] = None
+        parent_parallel_start_node_id: Optional[str] = None
+        iteration_id: Optional[str] = None
+        retry_index: int = 0
+
+    event: StreamEvent = StreamEvent.NODE_RETRY
+    workflow_run_id: str
+    data: Data
+
+    def to_ignore_detail_dict(self):
+        return {
+            "event": self.event.value,
+            "task_id": self.task_id,
+            "workflow_run_id": self.workflow_run_id,
+            "data": {
+                "id": self.data.id,
+                "node_id": self.data.node_id,
+                "node_type": self.data.node_type,
+                "title": self.data.title,
+                "index": self.data.index,
+                "predecessor_node_id": self.data.predecessor_node_id,
+                "inputs": None,
+                "process_data": None,
+                "outputs": None,
+                "status": self.data.status,
+                "error": None,
+                "elapsed_time": self.data.elapsed_time,
+                "execution_metadata": None,
+                "created_at": self.data.created_at,
+                "finished_at": self.data.finished_at,
+                "files": [],
+                "parallel_id": self.data.parallel_id,
+                "parallel_start_node_id": self.data.parallel_start_node_id,
+                "parent_parallel_id": self.data.parent_parallel_id,
+                "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
+                "iteration_id": self.data.iteration_id,
+                "retry_index": self.data.retry_index,
+            },
+        }
+
+
 class ParallelBranchStartStreamResponse(StreamResponse):
     """
     ParallelBranchStartStreamResponse entity
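To make the new stream event concrete, here is the rough shape a serialized `node_retry` frame would take; every value below is a placeholder, and the field set simply mirrors `NodeRetryStreamResponse.Data` as added above:

```python
# Illustrative payload only; values are placeholders.
node_retry_frame = {
    "event": "node_retry",
    "task_id": "task-123",
    "workflow_run_id": "run-456",
    "data": {
        "id": "exec-789",
        "node_id": "llm",
        "node_type": "llm",
        "title": "LLM",
        "index": 3,
        "status": "retry",
        "error": "upstream timeout",
        "elapsed_time": 1.2,
        "created_at": 1700000000,
        "finished_at": 1700000001,
        "retry_index": 1,
    },
}
```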
@@ -201,11 +201,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
             stream_response=stream_response,
         )

-    def _listen_audio_msg(self, publisher, task_id: str):
+    def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
         if publisher is None:
             return None
-        audio_msg: AudioTrunk = publisher.check_and_get_audio()
-        if audio_msg and audio_msg.status != "finish":
+        audio_msg = publisher.check_and_get_audio()
+        if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
             # audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
             return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
         return None
@@ -15,6 +15,7 @@ from core.app.entities.queue_entities import (
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
+    QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,
@@ -26,6 +27,7 @@ from core.app.entities.task_entities import (
     IterationNodeNextStreamResponse,
     IterationNodeStartStreamResponse,
     NodeFinishStreamResponse,
+    NodeRetryStreamResponse,
     NodeStartStreamResponse,
     ParallelBranchFinishedStreamResponse,
     ParallelBranchStartStreamResponse,
@@ -423,6 +425,59 @@ class WorkflowCycleManage:

         return workflow_node_execution

+    def _handle_workflow_node_execution_retried(
+        self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
+    ) -> WorkflowNodeExecution:
+        """
+        Workflow node execution failed
+        :param event: queue node failed event
+        :return:
+        """
+        created_at = event.start_at
+        finished_at = datetime.now(UTC).replace(tzinfo=None)
+        elapsed_time = (finished_at - created_at).total_seconds()
+        inputs = WorkflowEntry.handle_special_values(event.inputs)
+        outputs = WorkflowEntry.handle_special_values(event.outputs)
+        origin_metadata = {
+            NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
+            NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
+        }
+        merged_metadata = (
+            {**jsonable_encoder(event.execution_metadata), **origin_metadata}
+            if event.execution_metadata is not None
+            else origin_metadata
+        )
+        execution_metadata = json.dumps(merged_metadata)
+
+        workflow_node_execution = WorkflowNodeExecution()
+        workflow_node_execution.tenant_id = workflow_run.tenant_id
+        workflow_node_execution.app_id = workflow_run.app_id
+        workflow_node_execution.workflow_id = workflow_run.workflow_id
+        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
+        workflow_node_execution.workflow_run_id = workflow_run.id
+        workflow_node_execution.predecessor_node_id = event.predecessor_node_id
+        workflow_node_execution.node_execution_id = event.node_execution_id
+        workflow_node_execution.node_id = event.node_id
+        workflow_node_execution.node_type = event.node_type.value
+        workflow_node_execution.title = event.node_data.title
+        workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
+        workflow_node_execution.created_by_role = workflow_run.created_by_role
+        workflow_node_execution.created_by = workflow_run.created_by
+        workflow_node_execution.created_at = created_at
+        workflow_node_execution.finished_at = finished_at
+        workflow_node_execution.elapsed_time = elapsed_time
+        workflow_node_execution.error = event.error
+        workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
+        workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
+        workflow_node_execution.execution_metadata = execution_metadata
+        workflow_node_execution.index = event.node_run_index
+
+        db.session.add(workflow_node_execution)
+        db.session.commit()
+        db.session.refresh(workflow_node_execution)
+
+        return workflow_node_execution
+
     #################################################
     #             to stream responses               #
     #################################################
@@ -457,6 +512,12 @@ class WorkflowCycleManage:
         :param workflow_run: workflow run
         :return:
         """
+        # Attach WorkflowRun to an active session so "created_by_role" can be accessed.
+        workflow_run = db.session.merge(workflow_run)
+
+        # Refresh to ensure any expired attributes are fully loaded
+        db.session.refresh(workflow_run)
+
         created_by = None
         if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value:
             created_by_account = workflow_run.created_by_account
@@ -587,6 +648,51 @@ class WorkflowCycleManage:
             ),
         )

+    def _workflow_node_retry_to_stream_response(
+        self,
+        event: QueueNodeRetryEvent,
+        task_id: str,
+        workflow_node_execution: WorkflowNodeExecution,
+    ) -> Optional[NodeFinishStreamResponse]:
+        """
+        Workflow node finish to stream response.
+        :param event: queue node succeeded or failed event
+        :param task_id: task id
+        :param workflow_node_execution: workflow node execution
+        :return:
+        """
+        if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
+            return None
+
+        return NodeRetryStreamResponse(
+            task_id=task_id,
+            workflow_run_id=workflow_node_execution.workflow_run_id,
+            data=NodeRetryStreamResponse.Data(
+                id=workflow_node_execution.id,
+                node_id=workflow_node_execution.node_id,
+                node_type=workflow_node_execution.node_type,
+                index=workflow_node_execution.index,
+                title=workflow_node_execution.title,
+                predecessor_node_id=workflow_node_execution.predecessor_node_id,
+                inputs=workflow_node_execution.inputs_dict,
+                process_data=workflow_node_execution.process_data_dict,
+                outputs=workflow_node_execution.outputs_dict,
+                status=workflow_node_execution.status,
+                error=workflow_node_execution.error,
+                elapsed_time=workflow_node_execution.elapsed_time,
+                execution_metadata=workflow_node_execution.execution_metadata_dict,
+                created_at=int(workflow_node_execution.created_at.timestamp()),
+                finished_at=int(workflow_node_execution.finished_at.timestamp()),
+                files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
+                parallel_id=event.parallel_id,
+                parallel_start_node_id=event.parallel_start_node_id,
+                parent_parallel_id=event.parent_parallel_id,
+                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                iteration_id=event.in_iteration_id,
+                retry_index=event.retry_index,
+            ),
+        )
+
     def _workflow_parallel_branch_start_to_stream_response(
         self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
     ) -> ParallelBranchStartStreamResponse:
@@ -1,7 +1,7 @@
 from typing import Optional


-class LLMError(Exception):
+class LLMError(ValueError):
     """Base class for all LLM exceptions."""

     description: Optional[str] = None
@@ -16,7 +16,7 @@ class LLMBadRequestError(LLMError):
     description = "Bad Request"


-class ProviderTokenNotInitError(Exception):
+class ProviderTokenNotInitError(ValueError):
     """
     Custom exception raised when the provider token is not initialized.
     """
@@ -27,7 +27,7 @@ class ProviderTokenNotInitError(Exception):
         self.description = args[0] if args else self.description


-class QuotaExceededError(Exception):
+class QuotaExceededError(ValueError):
     """
     Custom exception raised when the quota for a provider has been exceeded.
     """
@@ -35,7 +35,7 @@ class QuotaExceededError(Exception):
     description = "Quota Exceeded"


-class AppInvokeQuotaExceededError(Exception):
+class AppInvokeQuotaExceededError(ValueError):
     """
     Custom exception raised when the quota for an app has been exceeded.
     """
@@ -43,7 +43,7 @@ class AppInvokeQuotaExceededError(Exception):
     description = "App Invoke Quota Exceeded"


-class ModelCurrentlyNotSupportError(Exception):
+class ModelCurrentlyNotSupportError(ValueError):
     """
     Custom exception raised when the model not support
     """
@@ -51,7 +51,7 @@ class ModelCurrentlyNotSupportError(Exception):
     description = "Model Currently Not Support"


-class InvokeRateLimitError(Exception):
+class InvokeRateLimitError(ValueError):
     """Raised when the Invoke returns rate limit error."""

     description = "Rate Limit Error"
@ -1,15 +1,14 @@
import base64

from configs import dify_config
from core.file import file_repository
from core.helper import ssrf_proxy
from core.model_runtime.entities import (
    AudioPromptMessageContent,
    DocumentPromptMessageContent,
    ImagePromptMessageContent,
    MultiModalPromptMessageContent,
    VideoPromptMessageContent,
)
from extensions.ext_database import db
from extensions.ext_storage import storage

from . import helpers
@ -41,7 +40,7 @@ def to_prompt_message_content(
    /,
    *,
    image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
):
) -> MultiModalPromptMessageContent:
    if f.extension is None:
        raise ValueError("Missing file extension")
    if f.mime_type is None:
@ -70,16 +69,13 @@ def to_prompt_message_content(


def download(f: File, /):
    if f.transfer_method == FileTransferMethod.TOOL_FILE:
        tool_file = file_repository.get_tool_file(session=db.session(), file=f)
        return _download_file_content(tool_file.file_key)
    elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
        upload_file = file_repository.get_upload_file(session=db.session(), file=f)
        return _download_file_content(upload_file.key)
    # remote file
    response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
    response.raise_for_status()
    return response.content
    if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE):
        return _download_file_content(f._storage_key)
    elif f.transfer_method == FileTransferMethod.REMOTE_URL:
        response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
        response.raise_for_status()
        return response.content
    raise ValueError(f"unsupported transfer method: {f.transfer_method}")


def _download_file_content(path: str, /):
@ -110,11 +106,9 @@ def _get_encoded_string(f: File, /):
            response.raise_for_status()
            data = response.content
        case FileTransferMethod.LOCAL_FILE:
            upload_file = file_repository.get_upload_file(session=db.session(), file=f)
            data = _download_file_content(upload_file.key)
            data = _download_file_content(f._storage_key)
        case FileTransferMethod.TOOL_FILE:
            tool_file = file_repository.get_tool_file(session=db.session(), file=f)
            data = _download_file_content(tool_file.file_key)
            data = _download_file_content(f._storage_key)

    encoded_string = base64.b64encode(data).decode("utf-8")
    return encoded_string

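Reviewer note: a minimal standalone sketch (hypothetical stubs, not the real dify model) of the simplified dispatch above — storage-backed files now share one code path keyed by the private _storage_key, and only remote URLs go through the HTTP proxy:

from enum import Enum


class TransferMethodSketch(Enum):
    TOOL_FILE = "tool_file"
    LOCAL_FILE = "local_file"
    REMOTE_URL = "remote_url"


def download_sketch(method: TransferMethodSketch, storage_key: str, remote_url: str = "") -> bytes:
    # tool and local files read straight from storage via _storage_key
    if method in (TransferMethodSketch.TOOL_FILE, TransferMethodSketch.LOCAL_FILE):
        return f"<bytes of {storage_key}>".encode()  # stands in for _download_file_content
    # only remote URLs hit the network (via ssrf_proxy in the real code)
    if method == TransferMethodSketch.REMOTE_URL:
        return f"<bytes fetched from {remote_url}>".encode()
    raise ValueError(f"unsupported transfer method: {method}")


print(download_sketch(TransferMethodSketch.TOOL_FILE, "tools/abc.png"))
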
@ -1,32 +0,0 @@
from sqlalchemy import select
from sqlalchemy.orm import Session

from models import ToolFile, UploadFile

from .models import File


def get_upload_file(*, session: Session, file: File):
    if file.related_id is None:
        raise ValueError("Missing file related_id")
    stmt = select(UploadFile).filter(
        UploadFile.id == file.related_id,
        UploadFile.tenant_id == file.tenant_id,
    )
    record = session.scalar(stmt)
    if not record:
        raise ValueError(f"upload file {file.related_id} not found")
    return record


def get_tool_file(*, session: Session, file: File):
    if file.related_id is None:
        raise ValueError("Missing file related_id")
    stmt = select(ToolFile).filter(
        ToolFile.id == file.related_id,
        ToolFile.tenant_id == file.tenant_id,
    )
    record = session.scalar(stmt)
    if not record:
        raise ValueError(f"tool file {file.related_id} not found")
    return record

@ -47,6 +47,38 @@ class File(BaseModel):
    mime_type: Optional[str] = None
    size: int = -1

    # These properties are private and should not be exposed to the outside.
    _storage_key: str

    def __init__(
        self,
        *,
        id: Optional[str] = None,
        tenant_id: str,
        type: FileType,
        transfer_method: FileTransferMethod,
        remote_url: Optional[str] = None,
        related_id: Optional[str] = None,
        filename: Optional[str] = None,
        extension: Optional[str] = None,
        mime_type: Optional[str] = None,
        size: int = -1,
        storage_key: str,
    ):
        super().__init__(
            id=id,
            tenant_id=tenant_id,
            type=type,
            transfer_method=transfer_method,
            remote_url=remote_url,
            related_id=related_id,
            filename=filename,
            extension=extension,
            mime_type=mime_type,
            size=size,
        )
        self._storage_key = storage_key

    def to_dict(self) -> Mapping[str, str | int | None]:
        data = self.model_dump(mode="json")
        return {

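Reviewer note: a standalone sketch (not the actual dify model) of the pattern used above — in pydantic v2, an underscore-prefixed annotated attribute is a private attribute, so a storage key set in a custom __init__ stays out of validation and of model_dump():

from typing import Optional

from pydantic import BaseModel


class FileSketch(BaseModel):
    tenant_id: str
    filename: Optional[str] = None

    _storage_key: str  # private attr: excluded from validation and dumps

    def __init__(self, *, tenant_id: str, filename: Optional[str] = None, storage_key: str):
        super().__init__(tenant_id=tenant_id, filename=filename)
        self._storage_key = storage_key


f = FileSketch(tenant_id="t1", filename="a.png", storage_key="files/a.png")
print(f.model_dump())  # {'tenant_id': 't1', 'filename': 'a.png'} -- no storage key leaks
print(f._storage_key)  # 'files/a.png'
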
@ -118,7 +118,7 @@ class CodeExecutor:
        return response.data.stdout or ""

    @classmethod
    def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]) -> dict:
    def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]):
        """
        Execute code
        :param language: code language

@ -25,7 +25,7 @@ class TemplateTransformer(ABC):
        return runner_script, preload_script

    @classmethod
    def extract_result_str_from_response(cls, response: str) -> str:
    def extract_result_str_from_response(cls, response: str):
        result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL)
        if not result:
            raise ValueError("Failed to parse result")
@ -33,13 +33,21 @@ class TemplateTransformer(ABC):
        return result

    @classmethod
    def transform_response(cls, response: str) -> dict:
    def transform_response(cls, response: str) -> Mapping[str, Any]:
        """
        Transform response to dict
        :param response: response
        :return:
        """
        return json.loads(cls.extract_result_str_from_response(response))
        try:
            result = json.loads(cls.extract_result_str_from_response(response))
        except json.JSONDecodeError:
            raise ValueError("failed to parse response")
        if not isinstance(result, dict):
            raise ValueError("result must be a dict")
        if not all(isinstance(k, str) for k in result):
            raise ValueError("result keys must be strings")
        return result

    @classmethod
    @abstractmethod

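Reviewer note: a self-contained sketch of the hardened parsing above, assuming a transformer whose result tag is the placeholder <<RESULT>> (the real value comes from the concrete transformer):

import json
import re

RESULT_TAG = "<<RESULT>>"  # assumed tag, for illustration only


def transform_response_sketch(response: str) -> dict:
    match = re.search(rf"{RESULT_TAG}(.*){RESULT_TAG}", response, re.DOTALL)
    if not match:
        raise ValueError("Failed to parse result")
    try:
        result = json.loads(match.group(1))
    except json.JSONDecodeError:
        raise ValueError("failed to parse response")
    if not isinstance(result, dict):
        raise ValueError("result must be a dict")
    return result


print(transform_response_sketch('<<RESULT>>{"answer": 42}<<RESULT>>'))  # {'answer': 42}
# transform_response_sketch("<<RESULT>>[1, 2]<<RESULT>>") would raise "result must be a dict",
# where the old code silently returned a list.
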
@ -1,6 +1,5 @@
import base64

from extensions.ext_database import db
from libs import rsa


@ -14,6 +13,7 @@ def obfuscated_token(token: str):

def encrypt_token(tenant_id: str, token: str):
    from models.account import Tenant
    from models.engine import db

    if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
        raise ValueError(f"Tenant with id {tenant_id} not found")

@ -24,7 +24,7 @@ BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]


class MaxRetriesExceededError(Exception):
class MaxRetriesExceededError(ValueError):
    """Raised when the maximum number of retries is exceeded."""

    pass
@ -65,11 +65,12 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):

        except httpx.RequestError as e:
            logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}")
            if max_retries == 0:
                raise

            retries += 1
            if retries <= max_retries:
                time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))

    raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")

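Reviewer note: for reference, the delay schedule produced by BACKOFF_FACTOR * (2 ** (retries - 1)) with the defaults above (factor 0.5) is the usual exponential backoff:

BACKOFF_FACTOR = 0.5

# delay slept before retry 1, 2, 3, 4 (the first attempt happens immediately)
for retries in range(1, 5):
    print(f"retry {retries}: sleep {BACKOFF_FACTOR * (2 ** (retries - 1))}s")
# retry 1: sleep 0.5s
# retry 2: sleep 1.0s
# retry 3: sleep 2.0s
# retry 4: sleep 4.0s
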
@ -1,2 +1,2 @@
class OutputParserError(Exception):
class OutputParserError(ValueError):
    pass

@ -4,6 +4,7 @@ from .message_entities import (
    AudioPromptMessageContent,
    DocumentPromptMessageContent,
    ImagePromptMessageContent,
    MultiModalPromptMessageContent,
    PromptMessage,
    PromptMessageContent,
    PromptMessageContentType,
@ -27,6 +28,7 @@ __all__ = [
    "LLMResultChunkDelta",
    "LLMUsage",
    "ModelPropertyKey",
    "MultiModalPromptMessageContent",
    "PromptMessage",
    "PromptMessageContent",

@ -84,10 +84,10 @@ class MultiModalPromptMessageContent(PromptMessageContent):
    """

    type: PromptMessageContentType
    format: str = Field(..., description="the format of multi-modal file")
    base64_data: str = Field("", description="the base64 data of multi-modal file")
    url: str = Field("", description="the url of multi-modal file")
    mime_type: str = Field(..., description="the mime type of multi-modal file")
    format: str = Field(default=..., description="the format of multi-modal file")
    base64_data: str = Field(default="", description="the base64 data of multi-modal file")
    url: str = Field(default="", description="the url of multi-modal file")
    mime_type: str = Field(default=..., description="the mime type of multi-modal file")

    @computed_field(return_type=str)
    @property

@ -1,7 +1,7 @@
from typing import Optional


class InvokeError(Exception):
class InvokeError(ValueError):
    """Base class for all LLM exceptions."""

    description: Optional[str] = None

@ -1,4 +1,4 @@
class CredentialsValidateFailedError(Exception):
class CredentialsValidateFailedError(ValueError):
    """
    Credentials validation failed error
    """

@ -531,7 +531,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                        "source": {
                            "type": "base64",
                            "media_type": message_content.mime_type,
                            "data": message_content.data,
                            "data": message_content.base64_data,
                        },
                    }
                    sub_messages.append(sub_message_dict)

@ -819,6 +819,82 @@ LLM_BASE_MODELS = [
            ),
        ),
    ),
    AzureBaseModel(
        base_model_name="gpt-4o-2024-11-20",
        entity=AIModelEntity(
            model="fake-deployment-name",
            label=I18nObject(
                en_US="fake-deployment-name-label",
            ),
            model_type=ModelType.LLM,
            features=[
                ModelFeature.AGENT_THOUGHT,
                ModelFeature.VISION,
                ModelFeature.MULTI_TOOL_CALL,
                ModelFeature.STREAM_TOOL_CALL,
            ],
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
                ModelPropertyKey.MODE: LLMMode.CHAT.value,
                ModelPropertyKey.CONTEXT_SIZE: 128000,
            },
            parameter_rules=[
                ParameterRule(
                    name="temperature",
                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
                ),
                ParameterRule(
                    name="top_p",
                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
                ),
                ParameterRule(
                    name="presence_penalty",
                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
                ),
                ParameterRule(
                    name="frequency_penalty",
                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
                ),
                _get_max_tokens(default=512, min_val=1, max_val=16384),
                ParameterRule(
                    name="seed",
                    label=I18nObject(zh_Hans="种子", en_US="Seed"),
                    type="int",
                    help=AZURE_DEFAULT_PARAM_SEED_HELP,
                    required=False,
                    precision=2,
                    min=0,
                    max=1,
                ),
                ParameterRule(
                    name="response_format",
                    label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
                    type="string",
                    help=I18nObject(
                        zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output"
                    ),
                    required=False,
                    options=["text", "json_object", "json_schema"],
                ),
                ParameterRule(
                    name="json_schema",
                    label=I18nObject(en_US="JSON Schema"),
                    type="text",
                    help=I18nObject(
                        zh_Hans="设置返回的json schema,llm将按照它返回",
                        en_US="Set a response json schema will ensure LLM to adhere it.",
                    ),
                    required=False,
                ),
            ],
            pricing=PriceConfig(
                input=5.00,
                output=15.00,
                unit=0.000001,
                currency="USD",
            ),
        ),
    ),
    AzureBaseModel(
        base_model_name="gpt-4-turbo",
        entity=AIModelEntity(

@ -171,6 +171,12 @@ model_credential_schema:
        show_on:
          - variable: __model_type
            value: llm
      - label:
          en_US: gpt-4o-2024-11-20
        value: gpt-4o-2024-11-20
        show_on:
          - variable: __model_type
            value: llm
      - label:
          en_US: gpt-4-turbo
        value: gpt-4-turbo

@ -92,7 +92,10 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
                average = embeddings_batch[0]
            else:
                average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
            embeddings[i] = (average / np.linalg.norm(average)).tolist()
            embedding = (average / np.linalg.norm(average)).tolist()
            if np.isnan(embedding).any():
                raise ValueError("Normalized embedding is nan please try again")
            embeddings[i] = embedding

        # calc usage
        usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

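Reviewer note: the same NaN guard is added to several embedding providers below. The failure mode it catches is a zero (or otherwise degenerate) vector, whose norm is 0 and whose normalization comes out all NaN:

import numpy as np

vector = np.zeros(4)
normalized = vector / np.linalg.norm(vector)  # 0/0 -> nan (numpy warns and returns nan)
print(normalized)                  # [nan nan nan nan]
print(np.isnan(normalized).any())  # True -> the new code raises ValueError here
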
@ -1,11 +1,19 @@
from collections.abc import Mapping

import boto3
from botocore.config import Config

from core.model_runtime.errors.invoke import InvokeBadRequestError


def get_bedrock_client(service_name: str, credentials: Mapping[str, str]):
    region_name = credentials.get("aws_region")
    if not region_name:
        raise InvokeBadRequestError("aws_region is required")
    client_config = Config(region_name=region_name)
    aws_access_key_id = credentials.get("aws_access_key_id")
    aws_secret_access_key = credentials.get("aws_secret_access_key")

def get_bedrock_client(service_name, credentials=None):
    client_config = Config(region_name=credentials["aws_region"])
    aws_access_key_id = credentials["aws_access_key_id"]
    aws_secret_access_key = credentials["aws_secret_access_key"]
    if aws_access_key_id and aws_secret_access_key:
        # use aksk to call bedrock
        client = boto3.client(

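Reviewer note: an illustrative call to the tightened signature (placeholder region, no real credentials) — lookups use .get so a missing aws_region now fails with a clear InvokeBadRequestError instead of a KeyError:

credentials = {"aws_region": "us-east-1"}  # placeholder; no static keys -> default credential chain
# client = get_bedrock_client("bedrock-runtime", credentials)
# get_bedrock_client("bedrock-runtime", {})  # would now raise InvokeBadRequestError("aws_region is required")
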
@ -62,7 +62,10 @@ class BedrockRerankModel(RerankModel):
            }
        )
        modelId = model
        region = credentials["aws_region"]
        region = credentials.get("aws_region")
        # region is a required field
        if not region:
            raise InvokeBadRequestError("aws_region is required in credentials")
        model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{modelId}"
        rerankingConfiguration = {
            "type": "BEDROCK_RERANKING_MODEL",

@ -88,7 +88,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
                average = embeddings_batch[0]
            else:
                average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
            embeddings[i] = (average / np.linalg.norm(average)).tolist()
            embedding = (average / np.linalg.norm(average)).tolist()
            if np.isnan(embedding).any():
                raise ValueError("Normalized embedding is nan please try again")
            embeddings[i] = embedding

        # calc usage
        usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

@ -1,4 +1,5 @@
- gemini-2.0-flash-exp
- gemini-2.0-flash-thinking-exp-1219
- gemini-1.5-pro
- gemini-1.5-pro-latest
- gemini-1.5-pro-001

@ -0,0 +1,39 @@
model: gemini-2.0-flash-thinking-exp-1219
label:
  en_US: Gemini 2.0 Flash Thinking Exp 1219
model_type: llm
features:
  - agent-thought
  - vision
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 32767
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_output_tokens
    use_template: max_tokens
    default: 8192
    min: 1
    max: 8192
  - name: json_schema
    use_template: json_schema
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD

@ -21,6 +21,7 @@ from core.model_runtime.entities.message_entities import (
    PromptMessageContentType,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
@ -143,7 +144,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
        """

        try:
            ping_message = SystemPromptMessage(content="ping")
            ping_message = UserPromptMessage(content="ping")
            self._generate(model, credentials, [ping_message], {"max_output_tokens": 5})

        except Exception as ex:
@ -187,17 +188,23 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
            config_kwargs["stop_sequences"] = stop

        genai.configure(api_key=credentials["google_api_key"])
        google_model = genai.GenerativeModel(model_name=model)

        history = []
        system_instruction = None

        for msg in prompt_messages:  # makes message roles strictly alternating
            content = self._format_message_to_glm_content(msg)
            if history and history[-1]["role"] == content["role"]:
                history[-1]["parts"].extend(content["parts"])
            elif content["role"] == "system":
                system_instruction = content["parts"][0]
            else:
                history.append(content)

        if not history:
            raise InvokeError("The user prompt message is required. You only add a system prompt message.")

        google_model = genai.GenerativeModel(model_name=model, system_instruction=system_instruction)
        response = google_model.generate_content(
            contents=history,
            generation_config=genai.types.GenerationConfig(**config_kwargs),
@ -404,7 +411,10 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
            )
            return glm_content
        elif isinstance(message, SystemPromptMessage):
            return {"role": "user", "parts": [to_part(message.content)]}
            if isinstance(message.content, list):
                text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), message.content)
                message.content = "".join(c.data for c in text_contents)
            return {"role": "system", "parts": [to_part(message.content)]}
        elif isinstance(message, ToolPromptMessage):
            return {
                "role": "function",

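Reviewer note: a standalone sketch (plain dicts and strings, no SDK) of the new history-building loop — consecutive same-role parts are merged, the system message is lifted out as system_instruction, and everything else is appended in order:

messages = [
    {"role": "system", "parts": ["You are terse."]},
    {"role": "user", "parts": ["Hi"]},
    {"role": "user", "parts": ["Second user message"]},
    {"role": "model", "parts": ["Hello"]},
]

history, system_instruction = [], None
for content in messages:
    if history and history[-1]["role"] == content["role"]:
        history[-1]["parts"].extend(content["parts"])  # merge adjacent same-role turns
    elif content["role"] == "system":
        system_instruction = content["parts"][0]  # system prompt becomes a model-level instruction
    else:
        history.append(content)

print(system_instruction)  # You are terse.
print(history)  # one user turn holding both user parts, then the model turn
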
@ -421,7 +421,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):

        # text completion model
        response = client.completions.create(
            prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs
            prompt=prompt_messages[0].content,
            model=model,
            stream=stream,
            **model_parameters,
            **extra_model_kwargs,
        )

        if stream:
@ -593,6 +597,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
            else:
                model_parameters["response_format"] = {"type": response_format}
        elif "json_schema" in model_parameters:
            del model_parameters["json_schema"]

        extra_model_kwargs = {}

@ -97,7 +97,10 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
                average = embeddings_batch[0]
            else:
                average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
            embeddings[i] = (average / np.linalg.norm(average)).tolist()
            embedding = (average / np.linalg.norm(average)).tolist()
            if np.isnan(embedding).any():
                raise ValueError("Normalized embedding is nan please try again")
            embeddings[i] = embedding

        # calc usage
        usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

@ -119,7 +119,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
                embeddings.append(result[0].get("embedding"))

            return [list(map(float, e)) for e in embeddings]
        elif "texts" == text_input_key:
        elif text_input_key == "texts":
            result = client.run(
                replicate_model_version,
                input={

@ -18,7 +18,7 @@ class SiliconflowProvider(ModelProvider):
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

            model_instance.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials=credentials)
            model_instance.validate_credentials(model="deepseek-ai/DeepSeek-V2.5", credentials=credentials)
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:

@ -100,7 +100,10 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel):
                average = embeddings_batch[0]
            else:
                average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
            embeddings[i] = (average / np.linalg.norm(average)).tolist()
            embedding = (average / np.linalg.norm(average)).tolist()
            if np.isnan(embedding).any():
                raise ValueError("Normalized embedding is nan please try again")
            embeddings[i] = embedding

        usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

@ -4,11 +4,10 @@ import json
import logging
import time
from collections.abc import Generator
from typing import Optional, Union, cast
from typing import TYPE_CHECKING, Optional, Union, cast

import google.auth.transport.requests
import requests
import vertexai.generative_models as glm
from anthropic import AnthropicVertex, Stream
from anthropic.types import (
    ContentBlockDeltaEvent,
@ -19,8 +18,6 @@ from anthropic.types import (
    MessageStreamEvent,
)
from google.api_core import exceptions
from google.cloud import aiplatform
from google.oauth2 import service_account
from PIL import Image

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@ -47,6 +44,9 @@ from core.model_runtime.errors.invoke import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

if TYPE_CHECKING:
    import vertexai.generative_models as glm

logger = logging.getLogger(__name__)


@ -102,6 +102,8 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
        :param stream: is stream response
        :return: full response or stream response chunk generator result
        """
        from google.oauth2 import service_account

        # use Anthropic official SDK references
        # - https://github.com/anthropics/anthropic-sdk-python
        service_account_key = credentials.get("vertex_service_account_key", "")
@ -406,13 +408,15 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):

        return text.rstrip()

    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> "glm.Tool":
        """
        Convert tool messages to glm tools

        :param tools: tool messages
        :return: glm tools
        """
        import vertexai.generative_models as glm

        return glm.Tool(
            function_declarations=[
                glm.FunctionDeclaration(
@ -473,6 +477,10 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        import vertexai.generative_models as glm
        from google.cloud import aiplatform
        from google.oauth2 import service_account

        config_kwargs = model_parameters.copy()
        config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)

@ -522,7 +530,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
        return self._handle_generate_response(model, credentials, response, prompt_messages)

    def _handle_generate_response(
        self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
        self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage]
    ) -> LLMResult:
        """
        Handle llm response
@ -554,7 +562,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
        return result

    def _handle_generate_stream_response(
        self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
        self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage]
    ) -> Generator:
        """
        Handle llm stream response
@ -638,13 +646,15 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):

        return message_text

    def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content:
    def _format_message_to_glm_content(self, message: PromptMessage) -> "glm.Content":
        """
        Format a single message into glm.Content for Google API

        :param message: one PromptMessage
        :return: glm Content representation of message
        """
        import vertexai.generative_models as glm

        if isinstance(message, UserPromptMessage):
            glm_content = glm.Content(role="user", parts=[])

@ -2,12 +2,9 @@ import base64
import json
import time
from decimal import Decimal
from typing import Optional
from typing import TYPE_CHECKING, Optional

import tiktoken
from google.cloud import aiplatform
from google.oauth2 import service_account
from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel

from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
@ -24,6 +21,11 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.vertex_ai._common import _CommonVertexAi

if TYPE_CHECKING:
    from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
else:
    VertexTextEmbeddingModel = None


class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
    """
@ -48,6 +50,10 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
        :param input_type: input type
        :return: embeddings result
        """
        from google.cloud import aiplatform
        from google.oauth2 import service_account
        from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel

        service_account_key = credentials.get("vertex_service_account_key", "")
        project_id = credentials["vertex_project_id"]
        location = credentials["vertex_location"]
@ -100,6 +106,10 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
        :param credentials: model credentials
        :return:
        """
        from google.cloud import aiplatform
        from google.oauth2 import service_account
        from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel

        try:
            service_account_key = credentials.get("vertex_service_account_key", "")
            project_id = credentials["vertex_project_id"]

@ -40,6 +40,10 @@ configs: dict[str, ModelConfig] = {
        properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
        features=[ModelFeature.TOOL_CALL],
    ),
    "Doubao-pro-256k": ModelConfig(
        properties=ModelProperties(context_size=262144, max_tokens=4096, mode=LLMMode.CHAT),
        features=[],
    ),
    "Doubao-pro-128k": ModelConfig(
        properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
        features=[ModelFeature.TOOL_CALL],

@ -12,6 +12,7 @@ class ModelConfig(BaseModel):

ModelConfigs = {
    "Doubao-embedding": ModelConfig(properties=ModelProperties(context_size=4096, max_chunks=32)),
    "Doubao-embedding-large": ModelConfig(properties=ModelProperties(context_size=4096, max_chunks=32)),
}


@ -21,7 +22,7 @@ def get_model_config(credentials: dict) -> ModelConfig:
    if not model_configs:
        return ModelConfig(
            properties=ModelProperties(
                context_size=int(credentials.get("context_size", 0)),
                context_size=int(credentials.get("context_size", 4096)),
                max_chunks=int(credentials.get("max_chunks", 1)),
            )
        )

@ -166,6 +166,12 @@ model_credential_schema:
        show_on:
          - variable: __model_type
            value: llm
      - label:
          en_US: Doubao-pro-256k
        value: Doubao-pro-256k
        show_on:
          - variable: __model_type
            value: llm
      - label:
          en_US: Llama3-8B
        value: Llama3-8B
@ -220,6 +226,12 @@ model_credential_schema:
        show_on:
          - variable: __model_type
            value: text-embedding
      - label:
          en_US: Doubao-embedding-large
        value: Doubao-embedding-large
        show_on:
          - variable: __model_type
            value: text-embedding
      - label:
          en_US: Custom
          zh_Hans: 自定义

@ -355,7 +355,13 @@ class TraceTask:
    def conversation_trace(self, **kwargs):
        return kwargs

    def workflow_trace(self, workflow_run: WorkflowRun, conversation_id, user_id):
    def workflow_trace(self, workflow_run: WorkflowRun | None, conversation_id, user_id):
        if not workflow_run:
            raise ValueError("Workflow run not found")

        db.session.merge(workflow_run)
        db.session.refresh(workflow_run)

        workflow_id = workflow_run.workflow_id
        tenant_id = workflow_run.tenant_id
        workflow_run_id = workflow_run.id

@ -83,11 +83,15 @@ class DataPostProcessor:
        if reranking_model:
            try:
                model_manager = ModelManager()
                reranking_provider_name = reranking_model.get("reranking_provider_name")
                reranking_model_name = reranking_model.get("reranking_model_name")
                if not reranking_provider_name or not reranking_model_name:
                    return None
                rerank_model_instance = model_manager.get_model_instance(
                    tenant_id=tenant_id,
                    provider=reranking_model["reranking_provider_name"],
                    provider=reranking_provider_name,
                    model_type=ModelType.RERANK,
                    model=reranking_model["reranking_model_name"],
                    model=reranking_model_name,
                )
                return rerank_model_instance
            except InvokeAuthorizationError:

@ -1,18 +1,19 @@
import re
from typing import Optional

import jieba
from jieba.analyse import default_tfidf

from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS


class JiebaKeywordTableHandler:
    def __init__(self):
        default_tfidf.stop_words = STOPWORDS
        import jieba.analyse

        from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS

        jieba.analyse.default_tfidf.stop_words = STOPWORDS

    def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
        """Extract keywords with JIEBA tfidf."""
        import jieba

        keywords = jieba.analyse.extract_tags(
            sentence=text,
            topK=max_keywords_per_chunk,
@ -22,6 +23,8 @@ class JiebaKeywordTableHandler:

    def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
        """Get subtokens from a list of tokens, filtering for stopwords."""
        from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS

        results = set()
        for token in tokens:
            results.add(token)

@ -103,7 +103,7 @@ class RetrievalService:

        if exceptions:
            exception_message = ";\n".join(exceptions)
            raise Exception(exception_message)
            raise ValueError(exception_message)

        if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
            data_post_processor = DataPostProcessor(

@ -6,10 +6,8 @@ from contextlib import contextmanager
from typing import Any

import jieba.posseg as pseg
import nltk
import numpy
import oracledb
from nltk.corpus import stopwords
from pydantic import BaseModel, model_validator

from configs import dify_config
@ -202,6 +200,10 @@ class OracleVector(BaseVector):
        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        # lazy import
        import nltk
        from nltk.corpus import stopwords

        top_k = kwargs.get("top_k", 5)
        # fetch by score_threshold is not implemented yet; may be added later
        score_threshold = float(kwargs.get("score_threshold") or 0.0)

@ -65,6 +65,11 @@ class CacheEmbedding(Embeddings):
                for vector in embedding_result.embeddings:
                    try:
                        normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
                        # stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan
                        if np.isnan(normalized_embedding).any():
                            # for issue #11827 float values are not json compliant
                            logger.warning(f"Normalized embedding is nan: {normalized_embedding}")
                            continue
                        embedding_queue_embeddings.append(normalized_embedding)
                    except IntegrityError:
                        db.session.rollback()
@ -111,6 +116,8 @@ class CacheEmbedding(Embeddings):

            embedding_results = embedding_result.embeddings[0]
            embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
            if np.isnan(embedding_results).any():
                raise ValueError("Normalized embedding is nan please try again")
        except Exception as ex:
            if dify_config.DEBUG:
                logging.exception(f"Failed to embed query text '{text[:10]}...({len(text)} chars)'")

@ -62,7 +62,7 @@ class AppToolProviderEntity(ToolProviderController):
        user_input_form_list = app_model_config.user_input_form_list
        for input_form in user_input_form_list:
            # get type
            form_type = input_form.keys()[0]
            form_type = list(input_form.keys())[0]
            default = input_form[form_type]["default"]
            required = input_form[form_type]["required"]
            label = input_form[form_type]["label"]

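Reviewer note: the one-line fix above matters because dict_keys is not subscriptable in Python 3:

input_form = {"text-input": {"label": "Name", "required": True, "default": ""}}

# input_form.keys()[0]                  # TypeError: 'dict_keys' object is not subscriptable
form_type = list(input_form.keys())[0]  # the fix applied above
# next(iter(input_form)) would be an equivalent, allocation-free alternative
print(form_type)  # text-input
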
api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py
@ -0,0 +1,115 @@
import json
import operator
from typing import Any, Optional, Union

import boto3

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class BedrockRetrieveTool(BuiltinTool):
    bedrock_client: Any = None
    knowledge_base_id: Optional[str] = None
    topk: Optional[int] = None

    def _bedrock_retrieve(
        self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None
    ):
        try:
            retrieval_query = {"text": query_input}

            retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}

            # If a metadata filter is provided, add it to the retrieval configuration
            if metadata_filter:
                retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter

            response = self.bedrock_client.retrieve(
                knowledgeBaseId=knowledge_base_id,
                retrievalQuery=retrieval_query,
                retrievalConfiguration=retrieval_configuration,
            )

            results = []
            for result in response.get("retrievalResults", []):
                results.append(
                    {
                        "content": result.get("content", {}).get("text", ""),
                        "score": result.get("score", 0.0),
                        "metadata": result.get("metadata", {}),
                    }
                )

            return results
        except Exception as e:
            raise Exception(f"Error retrieving from knowledge base: {str(e)}")

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke tools
        """
        line = 0
        try:
            if not self.bedrock_client:
                aws_region = tool_parameters.get("aws_region")
                if aws_region:
                    self.bedrock_client = boto3.client("bedrock-agent-runtime", region_name=aws_region)
                else:
                    self.bedrock_client = boto3.client("bedrock-agent-runtime")

            line = 1
            if not self.knowledge_base_id:
                self.knowledge_base_id = tool_parameters.get("knowledge_base_id")
                if not self.knowledge_base_id:
                    return self.create_text_message("Please provide knowledge_base_id")

            line = 2
            if not self.topk:
                self.topk = tool_parameters.get("topk", 5)

            line = 3
            query = tool_parameters.get("query", "")
            if not query:
                return self.create_text_message("Please input query")

            # Get the metadata filter conditions (if present)
            metadata_filter_str = tool_parameters.get("metadata_filter")
            metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None

            line = 4
            retrieved_docs = self._bedrock_retrieve(
                query_input=query,
                knowledge_base_id=self.knowledge_base_id,
                num_results=self.topk,
                metadata_filter=metadata_filter,  # pass the metadata filter to the retrieval method
            )

            line = 5
            # Sort results by score in descending order
            sorted_docs = sorted(retrieved_docs, key=operator.itemgetter("score"), reverse=True)

            line = 6
            return [self.create_json_message(res) for res in sorted_docs]

        except Exception as e:
            return self.create_text_message(f"Exception {str(e)}, line : {line}")

    def validate_parameters(self, parameters: dict[str, Any]) -> None:
        """
        Validate the parameters
        """
        if not parameters.get("knowledge_base_id"):
            raise ValueError("knowledge_base_id is required")

        if not parameters.get("query"):
            raise ValueError("query is required")

        # Optional: validate that the metadata filter, if provided, is a valid JSON object
        metadata_filter_str = parameters.get("metadata_filter")
        if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
            raise ValueError("metadata_filter must be a valid JSON object")

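Reviewer note: for reference, this is the request shape the tool assembles for bedrock-agent-runtime retrieve (all values are placeholders; the live call is commented out):

retrieval_query = {"text": "What is our refund policy?"}
retrieval_configuration = {
    "vectorSearchConfiguration": {
        "numberOfResults": 5,
        # optional metadata filter, parsed from the tool's JSON string parameter
        "filter": {"greaterThan": {"key": "year", "value": 2020}},
    }
}
# response = bedrock_client.retrieve(
#     knowledgeBaseId="KB123456",  # placeholder knowledge base ID
#     retrievalQuery=retrieval_query,
#     retrievalConfiguration=retrieval_configuration,
# )
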
@ -0,0 +1,87 @@
identity:
  name: bedrock_retrieve
  author: AWS
  label:
    en_US: Bedrock Retrieve
    zh_Hans: Bedrock检索
    pt_BR: Bedrock Retrieve
  icon: icon.svg

description:
  human:
    en_US: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool
    zh_Hans: Amazon Bedrock知识库检索工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署说明
    pt_BR: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base.
  llm: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool

parameters:
  - name: knowledge_base_id
    type: string
    required: true
    label:
      en_US: Bedrock Knowledge Base ID
      zh_Hans: Bedrock知识库ID
      pt_BR: Bedrock Knowledge Base ID
    human_description:
      en_US: ID of the Bedrock Knowledge Base to retrieve from
      zh_Hans: 用于检索的Bedrock知识库ID
      pt_BR: ID of the Bedrock Knowledge Base to retrieve from
    llm_description: ID of the Bedrock Knowledge Base to retrieve from
    form: form

  - name: query
    type: string
    required: true
    label:
      en_US: Query string
      zh_Hans: 查询语句
      pt_BR: Query string
    human_description:
      en_US: The search query to retrieve relevant information
      zh_Hans: 用于检索相关信息的查询语句
      pt_BR: The search query to retrieve relevant information
    llm_description: The search query to retrieve relevant information
    form: llm

  - name: topk
    type: number
    required: false
    form: form
    label:
      en_US: Limit for results count
      zh_Hans: 返回结果数量限制
      pt_BR: Limit for results count
    human_description:
      en_US: Maximum number of results to return
      zh_Hans: 最大返回结果数量
      pt_BR: Maximum number of results to return
    min: 1
    max: 10
    default: 5

  - name: aws_region
    type: string
    required: false
    label:
      en_US: AWS Region
      zh_Hans: AWS 区域
      pt_BR: AWS Region
    human_description:
      en_US: AWS region where the Bedrock Knowledge Base is located
      zh_Hans: Bedrock知识库所在的AWS区域
      pt_BR: AWS region where the Bedrock Knowledge Base is located
    llm_description: AWS region where the Bedrock Knowledge Base is located
    form: form

  - name: metadata_filter
    type: string
    required: false
    label:
      en_US: Metadata Filter
      zh_Hans: 元数据过滤器
      pt_BR: Metadata Filter
    human_description:
      en_US: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key": "aaa", "value": 10}})'
      zh_Hans: '元数据的JSON格式过滤条件(例如,{"greaterThan": {"key": "aaa", "value": 10}})'
      pt_BR: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key": "aaa", "value": 10}})'
    form: form

api/core/tools/provider/builtin/aws/tools/nova_canvas.py
@ -0,0 +1,357 @@
import base64
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import boto3
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NovaCanvasTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invoke AWS Bedrock Nova Canvas model for image generation
|
||||
"""
|
||||
# Get common parameters
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
image_output_s3uri = tool_parameters.get("image_output_s3uri", "").strip()
|
||||
if not prompt:
|
||||
return self.create_text_message("Please provide a text prompt for image generation.")
|
||||
if not image_output_s3uri or urlparse(image_output_s3uri).scheme != "s3":
|
||||
return self.create_text_message("Please provide an valid S3 URI for image output.")
|
||||
|
||||
task_type = tool_parameters.get("task_type", "TEXT_IMAGE")
|
||||
aws_region = tool_parameters.get("aws_region", "us-east-1")
|
||||
|
||||
# Get common image generation config parameters
|
||||
width = tool_parameters.get("width", 1024)
|
||||
height = tool_parameters.get("height", 1024)
|
||||
cfg_scale = tool_parameters.get("cfg_scale", 8.0)
|
||||
negative_prompt = tool_parameters.get("negative_prompt", "")
|
||||
seed = tool_parameters.get("seed", 0)
|
||||
quality = tool_parameters.get("quality", "standard")
|
||||
|
||||
# Handle S3 image if provided
|
||||
image_input_s3uri = tool_parameters.get("image_input_s3uri", "")
|
||||
if task_type != "TEXT_IMAGE":
|
||||
if not image_input_s3uri or urlparse(image_input_s3uri).scheme != "s3":
|
||||
return self.create_text_message("Please provide a valid S3 URI for image to image generation.")
|
||||
|
||||
# Parse S3 URI
|
||||
parsed_uri = urlparse(image_input_s3uri)
|
||||
bucket = parsed_uri.netloc
|
||||
key = parsed_uri.path.lstrip("/")
|
||||
|
||||
# Initialize S3 client and download image
|
||||
s3_client = boto3.client("s3")
|
||||
response = s3_client.get_object(Bucket=bucket, Key=key)
|
||||
image_data = response["Body"].read()
|
||||
|
||||
# Base64 encode the image
|
||||
input_image = base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
try:
|
||||
# Initialize Bedrock client
|
||||
bedrock = boto3.client(service_name="bedrock-runtime", region_name=aws_region)
|
||||
|
||||
# Base image generation config
|
||||
image_generation_config = {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"cfgScale": cfg_scale,
|
||||
"seed": seed,
|
||||
"numberOfImages": 1,
|
||||
"quality": quality,
|
||||
}
|
||||
|
||||
# Prepare request body based on task type
|
||||
body = {"imageGenerationConfig": image_generation_config}
|
||||
|
||||
if task_type == "TEXT_IMAGE":
|
||||
body["taskType"] = "TEXT_IMAGE"
|
||||
body["textToImageParams"] = {"text": prompt}
|
||||
if negative_prompt:
|
||||
body["textToImageParams"]["negativeText"] = negative_prompt
|
||||
|
||||
elif task_type == "COLOR_GUIDED_GENERATION":
|
||||
colors = tool_parameters.get("colors", "#ff8080-#ffb280-#ffe680-#ffe680")
|
||||
if not self._validate_color_string(colors):
|
||||
return self.create_text_message("Please provide valid colors in hexadecimal format.")
|
||||
|
||||
body["taskType"] = "COLOR_GUIDED_GENERATION"
|
||||
body["colorGuidedGenerationParams"] = {
|
||||
"colors": colors.split("-"),
|
||||
"referenceImage": input_image,
|
||||
"text": prompt,
|
||||
}
|
||||
if negative_prompt:
|
||||
body["colorGuidedGenerationParams"]["negativeText"] = negative_prompt
|
||||
|
||||
elif task_type == "IMAGE_VARIATION":
|
||||
similarity_strength = tool_parameters.get("similarity_strength", 0.5)
|
||||
|
||||
body["taskType"] = "IMAGE_VARIATION"
|
||||
body["imageVariationParams"] = {
|
||||
"images": [input_image],
|
||||
"similarityStrength": similarity_strength,
|
||||
"text": prompt,
|
||||
}
|
||||
if negative_prompt:
|
||||
body["imageVariationParams"]["negativeText"] = negative_prompt
|
||||
|
||||
elif task_type == "INPAINTING":
|
||||
mask_prompt = tool_parameters.get("mask_prompt")
|
||||
if not mask_prompt:
|
||||
return self.create_text_message("Please provide a mask prompt for image inpainting.")
|
||||
|
||||
body["taskType"] = "INPAINTING"
|
||||
body["inPaintingParams"] = {"image": input_image, "maskPrompt": mask_prompt, "text": prompt}
|
||||
if negative_prompt:
|
||||
body["inPaintingParams"]["negativeText"] = negative_prompt
|
||||
|
||||
elif task_type == "OUTPAINTING":
|
||||
mask_prompt = tool_parameters.get("mask_prompt")
|
||||
if not mask_prompt:
|
||||
return self.create_text_message("Please provide a mask prompt for image outpainting.")
|
||||
outpainting_mode = tool_parameters.get("outpainting_mode", "DEFAULT")
|
||||
|
||||
body["taskType"] = "OUTPAINTING"
|
||||
body["outPaintingParams"] = {
|
||||
"image": input_image,
|
||||
"maskPrompt": mask_prompt,
|
||||
"outPaintingMode": outpainting_mode,
|
||||
"text": prompt,
|
||||
}
|
||||
if negative_prompt:
|
||||
body["outPaintingParams"]["negativeText"] = negative_prompt
|
||||
|
||||
elif task_type == "BACKGROUND_REMOVAL":
|
||||
body["taskType"] = "BACKGROUND_REMOVAL"
|
||||
body["backgroundRemovalParams"] = {"image": input_image}
|
||||
|
||||
else:
|
||||
return self.create_text_message(f"Unsupported task type: {task_type}")
|
||||
|
||||
# Call Nova Canvas model
|
||||
response = bedrock.invoke_model(
|
||||
body=json.dumps(body),
|
||||
modelId="amazon.nova-canvas-v1:0",
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
|
||||
# Process response
|
||||
response_body = json.loads(response.get("body").read())
|
||||
if response_body.get("error"):
|
||||
raise Exception(f"Error in model response: {response_body.get('error')}")
|
||||
base64_image = response_body.get("images")[0]
|
||||
|
||||
# Upload to S3 if image_output_s3uri is provided
|
||||
try:
|
||||
# Parse S3 URI for output
|
||||
parsed_uri = urlparse(image_output_s3uri)
|
||||
output_bucket = parsed_uri.netloc
|
||||
output_base_path = parsed_uri.path.lstrip("/")
|
||||
# Generate filename with timestamp
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_key = f"{output_base_path}/canvas-output-{timestamp}.png"
|
||||
|
||||
# Initialize S3 client if not already done
|
||||
s3_client = boto3.client("s3", region_name=aws_region)
|
||||
|
||||
# Decode base64 image and upload to S3
|
||||
image_data = base64.b64decode(base64_image)
|
||||
s3_client.put_object(Bucket=output_bucket, Key=output_key, Body=image_data, ContentType="image/png")
|
||||
logger.info(f"Image uploaded to s3://{output_bucket}/{output_key}")
|
||||
except Exception as e:
|
||||
logger.exception("Failed to upload image to S3")
|
||||
# Return image
|
||||
return [
|
||||
self.create_text_message(f"Image is available at: s3://{output_bucket}/{output_key}"),
|
||||
self.create_blob_message(
|
||||
blob=base64.b64decode(base64_image),
|
||||
meta={"mime_type": "image/png"},
|
||||
save_as=self.VariableKey.IMAGE.value,
|
||||
),
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Failed to generate image: {str(e)}")
|
||||
|
||||
def _validate_color_string(self, color_string) -> bool:
|
||||
color_pattern = r"^#[0-9a-fA-F]{6}(?:-#[0-9a-fA-F]{6})*$"
|
||||
|
||||
if re.match(color_pattern, color_string):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
parameters = [
|
||||
ToolParameter(
|
||||
name="prompt",
|
||||
label=I18nObject(en_US="Prompt", zh_Hans="提示词"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=True,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(
|
||||
en_US="Text description of the image you want to generate or modify",
|
||||
zh_Hans="您想要生成或修改的图像的文本描述",
|
||||
),
|
||||
llm_description="Describe the image you want to generate or how you want to modify the input image",
|
||||
),
|
||||
ToolParameter(
|
||||
name="image_input_s3uri",
|
||||
label=I18nObject(en_US="Input image s3 uri", zh_Hans="输入图片的s3 uri"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(en_US="Image to be modified", zh_Hans="想要修改的图片"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="image_output_s3uri",
|
||||
label=I18nObject(en_US="Output Image S3 URI", zh_Hans="输出图片的S3 URI目录"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=True,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="S3 URI where the generated image should be uploaded", zh_Hans="生成的图像应该上传到的S3 URI"
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="width",
|
||||
label=I18nObject(en_US="Width", zh_Hans="宽度"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=1024,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(en_US="Width of the generated image", zh_Hans="生成图像的宽度"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="height",
|
||||
label=I18nObject(en_US="Height", zh_Hans="高度"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=1024,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(en_US="Height of the generated image", zh_Hans="生成图像的高度"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="cfg_scale",
|
||||
label=I18nObject(en_US="CFG Scale", zh_Hans="CFG比例"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=8.0,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="How strongly the image should conform to the prompt", zh_Hans="图像应该多大程度上符合提示词"
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="negative_prompt",
|
||||
label=I18nObject(en_US="Negative Prompt", zh_Hans="负面提示词"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
default="",
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(
|
||||
en_US="Things you don't want in the generated image", zh_Hans="您不想在生成的图像中出现的内容"
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="seed",
|
||||
label=I18nObject(en_US="Seed", zh_Hans="种子值"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=0,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(en_US="Random seed for image generation", zh_Hans="图像生成的随机种子"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="aws_region",
|
||||
label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
default="us-east-1",
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="task_type",
|
||||
label=I18nObject(en_US="Task Type", zh_Hans="任务类型"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
default="TEXT_IMAGE",
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(en_US="Type of image generation task", zh_Hans="图像生成任务的类型"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="quality",
|
||||
label=I18nObject(en_US="Quality", zh_Hans="质量"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
default="standard",
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="Quality of the generated image (standard or premium)", zh_Hans="生成图像的质量(标准或高级)"
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="colors",
|
||||
label=I18nObject(en_US="Colors", zh_Hans="颜色"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="List of colors for color-guided generation, example: #ff8080-#ffb280-#ffe680-#ffe680",
|
||||
zh_Hans="颜色引导生成的颜色列表, 例子: #ff8080-#ffb280-#ffe680-#ffe680",
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="similarity_strength",
|
||||
label=I18nObject(en_US="Similarity Strength", zh_Hans="相似度强度"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=0.5,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="How similar the generated image should be to the input image (0.0 to 1.0)",
|
||||
zh_Hans="生成的图像应该与输入图像的相似程度(0.0到1.0)",
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="mask_prompt",
|
||||
label=I18nObject(en_US="Mask Prompt", zh_Hans="蒙版提示词"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(
|
||||
en_US="Text description to generate mask for inpainting/outpainting",
|
||||
zh_Hans="用于生成内补绘制/外补绘制蒙版的文本描述",
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="outpainting_mode",
|
||||
label=I18nObject(en_US="Outpainting Mode", zh_Hans="外补绘制模式"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
default="DEFAULT",
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="Mode for outpainting (DEFAULT or other supported modes)",
|
||||
zh_Hans="外补绘制的模式(DEFAULT或其他支持的模式)",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
return parameters
|
||||
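For reference, a minimal sketch of the TEXT_IMAGE request this tool issues to Bedrock, assuming the request/response structure from the AWS guide referenced in the YAML below (model ID and field names are taken from that guide, not from this tool's source):

import base64
import json

import boto3

# Sketch of a Nova Canvas TEXT_IMAGE call; the request/response shape follows the
# AWS image-gen-req-resp-structure guide and is an assumption here.
client = boto3.client("bedrock-runtime", region_name="us-east-1")
body = {
    "taskType": "TEXT_IMAGE",
    "textToImageParams": {"text": "a lighthouse at dawn", "negativeText": "blurry, low quality"},
    "imageGenerationConfig": {"width": 1024, "height": 1024, "cfgScale": 8.0, "seed": 0, "quality": "standard"},
}
response = client.invoke_model(modelId="amazon.nova-canvas-v1:0", body=json.dumps(body))
image_bytes = base64.b64decode(json.loads(response["body"].read())["images"][0])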
175  api/core/tools/provider/builtin/aws/tools/nova_canvas.yaml  Normal file
@@ -0,0 +1,175 @@
identity:
  name: nova_canvas
  author: AWS
  label:
    en_US: AWS Bedrock Nova Canvas
    zh_Hans: AWS Bedrock Nova Canvas
  icon: icon.svg
description:
  human:
    en_US: A tool for generating and modifying images using AWS Bedrock's Nova Canvas model. Supports text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html
    zh_Hans: 使用 AWS Bedrock 的 Nova Canvas 模型生成和修改图像的工具。支持文生图、颜色引导生成、图像变体、内补绘制、外补绘制和背景移除功能, 输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html。
  llm: Generate or modify images using AWS Bedrock's Nova Canvas model with multiple task types including text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal.
parameters:
  - name: task_type
    type: string
    required: false
    default: TEXT_IMAGE
    label:
      en_US: Task Type
      zh_Hans: 任务类型
    human_description:
      en_US: Type of image generation task (TEXT_IMAGE, COLOR_GUIDED_GENERATION, IMAGE_VARIATION, INPAINTING, OUTPAINTING, BACKGROUND_REMOVAL)
      zh_Hans: 图像生成任务的类型(文生图、颜色引导生成、图像变体、内补绘制、外补绘制、背景移除)
    form: llm
  - name: prompt
    type: string
    required: true
    label:
      en_US: Prompt
      zh_Hans: 提示词
    human_description:
      en_US: Text description of the image you want to generate or modify
      zh_Hans: 您想要生成或修改的图像的文本描述
    llm_description: Describe the image you want to generate or how you want to modify the input image
    form: llm
  - name: image_input_s3uri
    type: string
    required: false
    label:
      en_US: Input image s3 uri
      zh_Hans: 输入图片的s3 uri
    human_description:
      en_US: The input image to modify (required for all modes except TEXT_IMAGE)
      zh_Hans: 要修改的输入图像(除文生图外的所有模式都需要)
    llm_description: The input image you want to modify. Required for all modes except TEXT_IMAGE.
    form: llm
  - name: image_output_s3uri
    type: string
    required: true
    label:
      en_US: Output S3 URI
      zh_Hans: 输出S3 URI
    human_description:
      en_US: The S3 URI where the generated image will be saved. If provided, the image will be uploaded with name format canvas-output-{timestamp}.png
      zh_Hans: 生成的图像将保存到的S3 URI。如果提供,图像将以canvas-output-{timestamp}.png的格式上传
    llm_description: Optional S3 URI where the generated image will be uploaded. The image will be saved with a timestamp-based filename.
    form: form
  - name: negative_prompt
    type: string
    required: false
    label:
      en_US: Negative Prompt
      zh_Hans: 负面提示词
    human_description:
      en_US: Things you don't want in the generated image
      zh_Hans: 您不想在生成的图像中出现的内容
    form: llm
  - name: width
    type: number
    required: false
    label:
      en_US: Width
      zh_Hans: 宽度
    human_description:
      en_US: Width of the generated image
      zh_Hans: 生成图像的宽度
    form: form
    default: 1024
  - name: height
    type: number
    required: false
    label:
      en_US: Height
      zh_Hans: 高度
    human_description:
      en_US: Height of the generated image
      zh_Hans: 生成图像的高度
    form: form
    default: 1024
  - name: cfg_scale
    type: number
    required: false
    label:
      en_US: CFG Scale
      zh_Hans: CFG比例
    human_description:
      en_US: How strongly the image should conform to the prompt
      zh_Hans: 图像应该多大程度上符合提示词
    form: form
    default: 8.0
  - name: seed
    type: number
    required: false
    label:
      en_US: Seed
      zh_Hans: 种子值
    human_description:
      en_US: Random seed for image generation
      zh_Hans: 图像生成的随机种子
    form: form
    default: 0
  - name: aws_region
    type: string
    required: false
    default: us-east-1
    label:
      en_US: AWS Region
      zh_Hans: AWS 区域
    human_description:
      en_US: AWS region for Bedrock service
      zh_Hans: Bedrock 服务的 AWS 区域
    form: form
  - name: quality
    type: string
    required: false
    default: standard
    label:
      en_US: Quality
      zh_Hans: 质量
    human_description:
      en_US: Quality of the generated image (standard or premium)
      zh_Hans: 生成图像的质量(标准或高级)
    form: form
  - name: colors
    type: string
    required: false
    label:
      en_US: Colors
      zh_Hans: 颜色
    human_description:
      en_US: List of colors for color-guided generation
      zh_Hans: 颜色引导生成的颜色列表
    form: form
  - name: similarity_strength
    type: number
    required: false
    default: 0.5
    label:
      en_US: Similarity Strength
      zh_Hans: 相似度强度
    human_description:
      en_US: How similar the generated image should be to the input image (0.0 to 1.0)
      zh_Hans: 生成的图像应该与输入图像的相似程度(0.0到1.0)
    form: form
  - name: mask_prompt
    type: string
    required: false
    label:
      en_US: Mask Prompt
      zh_Hans: 蒙版提示词
    human_description:
      en_US: Text description to generate mask for inpainting/outpainting
      zh_Hans: 用于生成内补绘制/外补绘制蒙版的文本描述
    form: llm
  - name: outpainting_mode
    type: string
    required: false
    default: DEFAULT
    label:
      en_US: Outpainting Mode
      zh_Hans: 外补绘制模式
    human_description:
      en_US: Mode for outpainting (DEFAULT or other supported modes)
      zh_Hans: 外补绘制的模式(DEFAULT或其他支持的模式)
    form: form
371  api/core/tools/provider/builtin/aws/tools/nova_reel.py  Normal file
@@ -0,0 +1,371 @@
import base64
import logging
import time
from io import BytesIO
from typing import Any, Optional, Union
from urllib.parse import urlparse

import boto3
from botocore.exceptions import ClientError
from PIL import Image

from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool.builtin_tool import BuiltinTool

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

NOVA_REEL_DEFAULT_REGION = "us-east-1"
NOVA_REEL_DEFAULT_DIMENSION = "1280x720"
NOVA_REEL_DEFAULT_FPS = 24
NOVA_REEL_DEFAULT_DURATION = 6
NOVA_REEL_MODEL_ID = "amazon.nova-reel-v1:0"
NOVA_REEL_STATUS_CHECK_INTERVAL = 5

# Image requirements
NOVA_REEL_REQUIRED_IMAGE_WIDTH = 1280
NOVA_REEL_REQUIRED_IMAGE_HEIGHT = 720
NOVA_REEL_REQUIRED_IMAGE_MODE = "RGB"


class NovaReelTool(BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Invoke AWS Bedrock Nova Reel model for video generation.

        Args:
            user_id: The ID of the user making the request
            tool_parameters: Dictionary containing the tool parameters

        Returns:
            ToolInvokeMessage containing either the video content or status information
        """
        try:
            # Validate and extract parameters
            params = self._validate_and_extract_parameters(tool_parameters)
            if isinstance(params, ToolInvokeMessage):
                return params

            # Initialize AWS clients
            bedrock, s3_client = self._initialize_aws_clients(params["aws_region"])

            # Prepare model input
            model_input = self._prepare_model_input(params, s3_client)
            if isinstance(model_input, ToolInvokeMessage):
                return model_input

            # Start video generation
            invocation = self._start_video_generation(bedrock, model_input, params["video_output_s3uri"])
            invocation_arn = invocation["invocationArn"]

            # Handle async/sync mode
            return self._handle_generation_mode(bedrock, s3_client, invocation_arn, params["async_mode"])

        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code", "Unknown")
            error_message = e.response.get("Error", {}).get("Message", str(e))
            logger.exception(f"AWS API error: {error_code} - {error_message}")
            return self.create_text_message(f"AWS service error: {error_code} - {error_message}")
        except Exception as e:
            logger.error(f"Unexpected error in video generation: {str(e)}", exc_info=True)
            return self.create_text_message(f"Failed to generate video: {str(e)}")

    def _validate_and_extract_parameters(
        self, tool_parameters: dict[str, Any]
    ) -> Union[dict[str, Any], ToolInvokeMessage]:
        """Validate and extract parameters from the input dictionary."""
        prompt = tool_parameters.get("prompt", "")
        video_output_s3uri = tool_parameters.get("video_output_s3uri", "").strip()

        # Validate required parameters
        if not prompt:
            return self.create_text_message("Please provide a text prompt for video generation.")
        if not video_output_s3uri:
            return self.create_text_message("Please provide an S3 URI for video output.")

        # Validate S3 URI format
        if not video_output_s3uri.startswith("s3://"):
            return self.create_text_message("Invalid S3 URI format. Must start with 's3://'")

        # Ensure S3 URI ends with '/'
        video_output_s3uri = video_output_s3uri if video_output_s3uri.endswith("/") else video_output_s3uri + "/"

        return {
            "prompt": prompt,
            "video_output_s3uri": video_output_s3uri,
            "image_input_s3uri": tool_parameters.get("image_input_s3uri", "").strip(),
            "aws_region": tool_parameters.get("aws_region", NOVA_REEL_DEFAULT_REGION),
            "dimension": tool_parameters.get("dimension", NOVA_REEL_DEFAULT_DIMENSION),
            "seed": int(tool_parameters.get("seed", 0)),
            "fps": int(tool_parameters.get("fps", NOVA_REEL_DEFAULT_FPS)),
            "duration": int(tool_parameters.get("duration", NOVA_REEL_DEFAULT_DURATION)),
            "async_mode": bool(tool_parameters.get("async", True)),
        }

    def _initialize_aws_clients(self, region: str) -> tuple[Any, Any]:
        """Initialize AWS Bedrock and S3 clients."""
        bedrock = boto3.client(service_name="bedrock-runtime", region_name=region)
        s3_client = boto3.client("s3", region_name=region)
        return bedrock, s3_client

    def _prepare_model_input(self, params: dict[str, Any], s3_client: Any) -> Union[dict[str, Any], ToolInvokeMessage]:
        """Prepare the input for the Nova Reel model."""
        model_input = {
            "taskType": "TEXT_VIDEO",
            "textToVideoParams": {"text": params["prompt"]},
            "videoGenerationConfig": {
                "durationSeconds": params["duration"],
                "fps": params["fps"],
                "dimension": params["dimension"],
                "seed": params["seed"],
            },
        }

        # Add image if provided
        if params["image_input_s3uri"]:
            try:
                image_data = self._get_image_from_s3(s3_client, params["image_input_s3uri"])
                if not image_data:
                    return self.create_text_message("Failed to retrieve image from S3")

                # Process and validate image
                processed_image = self._process_and_validate_image(image_data)
                if isinstance(processed_image, ToolInvokeMessage):
                    return processed_image

                # Convert processed image to base64
                img_buffer = BytesIO()
                processed_image.save(img_buffer, format="PNG")
                img_buffer.seek(0)
                input_image_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")

                model_input["textToVideoParams"]["images"] = [
                    {"format": "png", "source": {"bytes": input_image_base64}}
                ]
            except Exception as e:
                logger.error(f"Error processing input image: {str(e)}", exc_info=True)
                return self.create_text_message(f"Failed to process input image: {str(e)}")

        return model_input

    def _process_and_validate_image(self, image_data: bytes) -> Union[Image.Image, ToolInvokeMessage]:
        """
        Process and validate the input image according to Nova Reel requirements.

        Requirements:
        - Must be 1280x720 pixels
        - Must be RGB format (8 bits per channel)
        - If PNG, alpha channel must not have transparent/translucent pixels
        """
        try:
            # Open image
            img = Image.open(BytesIO(image_data))

            # Convert RGBA to RGB if needed, ensuring no transparency
            if img.mode == "RGBA":
                # Check for transparency
                if img.getchannel("A").getextrema()[0] < 255:
                    return self.create_text_message(
                        "PNG image contains transparent or translucent pixels, which is not supported. "
                        "Please provide an image without transparency."
                    )
                # Convert to RGB
                img = img.convert("RGB")
            elif img.mode != "RGB":
                # Convert any other mode to RGB
                img = img.convert("RGB")

            # Validate/adjust dimensions
            if img.size != (NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT):
                logger.warning(
                    f"Image dimensions {img.size} do not match required dimensions "
                    f"({NOVA_REEL_REQUIRED_IMAGE_WIDTH}x{NOVA_REEL_REQUIRED_IMAGE_HEIGHT}). Resizing..."
                )
                img = img.resize(
                    (NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT), Image.Resampling.LANCZOS
                )

            # Validate bit depth
            if img.mode != NOVA_REEL_REQUIRED_IMAGE_MODE:
                return self.create_text_message(
                    f"Image must be in {NOVA_REEL_REQUIRED_IMAGE_MODE} mode with 8 bits per channel"
                )

            return img

        except Exception as e:
            logger.error(f"Error processing image: {str(e)}", exc_info=True)
            return self.create_text_message(
                "Failed to process image. Please ensure the image is a valid JPEG or PNG file."
            )

    def _get_image_from_s3(self, s3_client: Any, s3_uri: str) -> Optional[bytes]:
        """Download and return image data from S3."""
        parsed_uri = urlparse(s3_uri)
        bucket = parsed_uri.netloc
        key = parsed_uri.path.lstrip("/")

        response = s3_client.get_object(Bucket=bucket, Key=key)
        return response["Body"].read()

    def _start_video_generation(self, bedrock: Any, model_input: dict[str, Any], output_s3uri: str) -> dict[str, Any]:
        """Start the async video generation process."""
        return bedrock.start_async_invoke(
            modelId=NOVA_REEL_MODEL_ID,
            modelInput=model_input,
            outputDataConfig={"s3OutputDataConfig": {"s3Uri": output_s3uri}},
        )

    def _handle_generation_mode(
        self, bedrock: Any, s3_client: Any, invocation_arn: str, async_mode: bool
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Handle async or sync video generation mode."""
        invocation_response = bedrock.get_async_invoke(invocationArn=invocation_arn)
        video_path = invocation_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"]
        video_uri = f"{video_path}/output.mp4"

        if async_mode:
            return self.create_text_message(
                f"Video generation started.\nInvocation ARN: {invocation_arn}\n"
                f"Video will be available at: {video_uri}"
            )

        return self._wait_for_completion(bedrock, s3_client, invocation_arn)

    def _wait_for_completion(
        self, bedrock: Any, s3_client: Any, invocation_arn: str
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Wait for video generation completion and handle the result."""
        while True:
            status_response = bedrock.get_async_invoke(invocationArn=invocation_arn)
            status = status_response["status"]
            video_path = status_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"]

            if status == "Completed":
                return self._handle_completed_video(s3_client, video_path)
            elif status == "Failed":
                failure_message = status_response.get("failureMessage", "Unknown error")
                return self.create_text_message(f"Video generation failed.\nError: {failure_message}")
            elif status == "InProgress":
                time.sleep(NOVA_REEL_STATUS_CHECK_INTERVAL)
            else:
                return self.create_text_message(f"Unexpected status: {status}")

    def _handle_completed_video(
        self, s3_client: Any, video_path: str
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Handle completed video generation and return the result."""
        parsed_uri = urlparse(video_path)
        bucket = parsed_uri.netloc
        key = parsed_uri.path.lstrip("/") + "/output.mp4"

        try:
            response = s3_client.get_object(Bucket=bucket, Key=key)
            video_content = response["Body"].read()
            return [
                self.create_text_message(f"Video is available at: {video_path}/output.mp4"),
                self.create_blob_message(blob=video_content, meta={"mime_type": "video/mp4"}, save_as="output.mp4"),
            ]
        except Exception as e:
            logger.error(f"Error downloading video: {str(e)}", exc_info=True)
            return self.create_text_message(
                f"Video generation completed but failed to download video: {str(e)}\n"
                f"Video is available at: s3://{bucket}/{key}"
            )

    def get_runtime_parameters(self) -> list[ToolParameter]:
        """Define the tool's runtime parameters."""
        parameters = [
            ToolParameter(
                name="prompt",
                label=I18nObject(en_US="Prompt", zh_Hans="提示词"),
                type=ToolParameter.ToolParameterType.STRING,
                required=True,
                form=ToolParameter.ToolParameterForm.LLM,
                human_description=I18nObject(
                    en_US="Text description of the video you want to generate", zh_Hans="您想要生成的视频的文本描述"
                ),
                llm_description="Describe the video you want to generate",
            ),
            ToolParameter(
                name="video_output_s3uri",
                label=I18nObject(en_US="Output S3 URI", zh_Hans="输出S3 URI"),
                type=ToolParameter.ToolParameterType.STRING,
                required=True,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(
                    en_US="S3 URI where the generated video will be stored", zh_Hans="生成的视频将存储的S3 URI"
                ),
            ),
            ToolParameter(
                name="dimension",
                label=I18nObject(en_US="Dimension", zh_Hans="尺寸"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                default=NOVA_REEL_DEFAULT_DIMENSION,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(en_US="Video dimensions (width x height)", zh_Hans="视频尺寸(宽 x 高)"),
            ),
            ToolParameter(
                name="duration",
                label=I18nObject(en_US="Duration", zh_Hans="时长"),
                type=ToolParameter.ToolParameterType.NUMBER,
                required=False,
                default=NOVA_REEL_DEFAULT_DURATION,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(en_US="Video duration in seconds", zh_Hans="视频时长(秒)"),
            ),
            ToolParameter(
                name="seed",
                label=I18nObject(en_US="Seed", zh_Hans="种子值"),
                type=ToolParameter.ToolParameterType.NUMBER,
                required=False,
                default=0,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(en_US="Random seed for video generation", zh_Hans="视频生成的随机种子"),
            ),
            ToolParameter(
                name="fps",
                label=I18nObject(en_US="FPS", zh_Hans="帧率"),
                type=ToolParameter.ToolParameterType.NUMBER,
                required=False,
                default=NOVA_REEL_DEFAULT_FPS,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(
                    en_US="Frames per second for the generated video", zh_Hans="生成视频的每秒帧数"
                ),
            ),
            ToolParameter(
                name="async",
                label=I18nObject(en_US="Async Mode", zh_Hans="异步模式"),
                type=ToolParameter.ToolParameterType.BOOLEAN,
                required=False,
                default=True,
                form=ToolParameter.ToolParameterForm.LLM,
                human_description=I18nObject(
                    en_US="Whether to run in async mode (return immediately) or sync mode (wait for completion)",
                    zh_Hans="是否以异步模式运行(立即返回)或同步模式(等待完成)",
                ),
            ),
            ToolParameter(
                name="aws_region",
                label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                default=NOVA_REEL_DEFAULT_REGION,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"),
            ),
            ToolParameter(
                name="image_input_s3uri",
                label=I18nObject(en_US="Input Image S3 URI", zh_Hans="输入图像S3 URI"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                form=ToolParameter.ToolParameterForm.LLM,
                human_description=I18nObject(
                    en_US="S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame",
                    zh_Hans="用作第一帧的输入图像(1280x720 JPEG/PNG)的S3 URI",
                ),
            ),
        ]

        return parameters
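The start-and-poll flow above can also be exercised directly with boto3; a minimal standalone sketch (the bucket name is hypothetical, the API calls mirror the tool's):

import time

import boto3

# Standalone sketch of NovaReelTool's start-and-poll flow; the bucket is hypothetical.
bedrock = boto3.client("bedrock-runtime", region_name="us-east-1")
invocation = bedrock.start_async_invoke(
    modelId="amazon.nova-reel-v1:0",
    modelInput={
        "taskType": "TEXT_VIDEO",
        "textToVideoParams": {"text": "waves rolling onto a beach at sunset"},
        "videoGenerationConfig": {"durationSeconds": 6, "fps": 24, "dimension": "1280x720", "seed": 0},
    },
    outputDataConfig={"s3OutputDataConfig": {"s3Uri": "s3://my-example-bucket/nova-reel/"}},
)
arn = invocation["invocationArn"]
while True:
    status = bedrock.get_async_invoke(invocationArn=arn)["status"]
    if status != "InProgress":
        break
    time.sleep(5)
print(status)  # "Completed" or "Failed"; on success the video lands at <s3Uri>/output.mp4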
124  api/core/tools/provider/builtin/aws/tools/nova_reel.yaml  Normal file
@@ -0,0 +1,124 @@
identity:
  name: nova_reel
  author: AWS
  label:
    en_US: AWS Bedrock Nova Reel
    zh_Hans: AWS Bedrock Nova Reel
  icon: icon.svg
description:
  human:
    en_US: A tool for generating videos using AWS Bedrock's Nova Reel model. Supports text-to-video generation and image-to-video generation with customizable parameters like duration, FPS, and dimensions. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html
    zh_Hans: 使用 AWS Bedrock 的 Nova Reel 模型生成视频的工具。支持文本生成视频和图像生成视频功能,可自定义持续时间、帧率和尺寸等参数。输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html
  llm: Generate videos using AWS Bedrock's Nova Reel model with support for both text-to-video and image-to-video generation, allowing customization of video properties like duration, frame rate, and resolution.

parameters:
  - name: prompt
    type: string
    required: true
    label:
      en_US: Prompt
      zh_Hans: 提示词
    human_description:
      en_US: Text description of the video you want to generate
      zh_Hans: 您想要生成的视频的文本描述
    llm_description: Describe the video you want to generate
    form: llm

  - name: video_output_s3uri
    type: string
    required: true
    label:
      en_US: Output S3 URI
      zh_Hans: 输出S3 URI
    human_description:
      en_US: S3 URI where the generated video will be stored
      zh_Hans: 生成的视频将存储的S3 URI
    form: form

  - name: dimension
    type: string
    required: false
    default: 1280x720
    label:
      en_US: Dimension
      zh_Hans: 尺寸
    human_description:
      en_US: Video dimensions (width x height)
      zh_Hans: 视频尺寸(宽 x 高)
    form: form

  - name: duration
    type: number
    required: false
    default: 6
    label:
      en_US: Duration
      zh_Hans: 时长
    human_description:
      en_US: Video duration in seconds
      zh_Hans: 视频时长(秒)
    form: form

  - name: seed
    type: number
    required: false
    default: 0
    label:
      en_US: Seed
      zh_Hans: 种子值
    human_description:
      en_US: Random seed for video generation
      zh_Hans: 视频生成的随机种子
    form: form

  - name: fps
    type: number
    required: false
    default: 24
    label:
      en_US: FPS
      zh_Hans: 帧率
    human_description:
      en_US: Frames per second for the generated video
      zh_Hans: 生成视频的每秒帧数
    form: form

  - name: async
    type: boolean
    required: false
    default: true
    label:
      en_US: Async Mode
      zh_Hans: 异步模式
    human_description:
      en_US: Whether to run in async mode (return immediately) or sync mode (wait for completion)
      zh_Hans: 是否以异步模式运行(立即返回)或同步模式(等待完成)
    form: llm

  - name: aws_region
    type: string
    required: false
    default: us-east-1
    label:
      en_US: AWS Region
      zh_Hans: AWS 区域
    human_description:
      en_US: AWS region for Bedrock service
      zh_Hans: Bedrock 服务的 AWS 区域
    form: form

  - name: image_input_s3uri
    type: string
    required: false
    label:
      en_US: Input Image S3 URI
      zh_Hans: 输入图像S3 URI
    human_description:
      en_US: S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame
      zh_Hans: 用作第一帧的输入图像(1280x720 JPEG/PNG)的S3 URI
    form: llm

development:
  dependencies:
    - boto3
    - pillow
80  api/core/tools/provider/builtin/aws/tools/s3_operator.py  Normal file
@@ -0,0 +1,80 @@
from typing import Any, Union
from urllib.parse import urlparse

import boto3

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class S3Operator(BuiltinTool):
    s3_client: Any = None

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke tools
        """
        try:
            # Initialize S3 client if not already done
            if not self.s3_client:
                aws_region = tool_parameters.get("aws_region")
                if aws_region:
                    self.s3_client = boto3.client("s3", region_name=aws_region)
                else:
                    self.s3_client = boto3.client("s3")

            # Parse S3 URI
            s3_uri = tool_parameters.get("s3_uri")
            if not s3_uri:
                return self.create_text_message("s3_uri parameter is required")

            parsed_uri = urlparse(s3_uri)
            if parsed_uri.scheme != "s3":
                return self.create_text_message("Invalid S3 URI format. Must start with 's3://'")

            bucket = parsed_uri.netloc
            # Remove leading slash from key
            key = parsed_uri.path.lstrip("/")

            operation_type = tool_parameters.get("operation_type", "read")
            generate_presign_url = tool_parameters.get("generate_presign_url", False)
            presign_expiry = int(tool_parameters.get("presign_expiry", 3600))  # default 1 hour

            if operation_type == "write":
                text_content = tool_parameters.get("text_content")
                if not text_content:
                    return self.create_text_message("text_content parameter is required for write operation")

                # Write content to S3
                self.s3_client.put_object(Bucket=bucket, Key=key, Body=text_content.encode("utf-8"))
                result = f"s3://{bucket}/{key}"

                # Generate presigned URL for the written object if requested
                if generate_presign_url:
                    result = self.s3_client.generate_presigned_url(
                        "get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry
                    )

            else:  # read operation
                # Get object from S3
                response = self.s3_client.get_object(Bucket=bucket, Key=key)
                result = response["Body"].read().decode("utf-8")

                # Generate presigned URL if requested
                if generate_presign_url:
                    result = self.s3_client.generate_presigned_url(
                        "get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry
                    )

            return self.create_text_message(text=result)

        except self.s3_client.exceptions.NoSuchBucket:
            return self.create_text_message(f"Bucket '{bucket}' does not exist")
        except self.s3_client.exceptions.NoSuchKey:
            return self.create_text_message(f"Object '{key}' does not exist in bucket '{bucket}'")
        except Exception as e:
            return self.create_text_message(f"Exception: {str(e)}")
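The write-then-presign path above reduces to two S3 calls; a minimal sketch (bucket and key are hypothetical):

import boto3

# Sketch of the operator's write + presign path; bucket/key are hypothetical.
s3 = boto3.client("s3", region_name="us-east-1")
s3.put_object(Bucket="my-example-bucket", Key="notes/hello.txt", Body="hello".encode("utf-8"))
url = s3.generate_presigned_url(
    "get_object",
    Params={"Bucket": "my-example-bucket", "Key": "notes/hello.txt"},
    ExpiresIn=3600,  # seconds, matching the tool's presign_expiry default
)
print(url)  # shareable link, valid for one hour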
98  api/core/tools/provider/builtin/aws/tools/s3_operator.yaml  Normal file
@@ -0,0 +1,98 @@
identity:
  name: s3_operator
  author: AWS
  label:
    en_US: AWS S3 Operator
    zh_Hans: AWS S3 读写器
    pt_BR: AWS S3 Operator
  icon: icon.svg
description:
  human:
    en_US: AWS S3 Writer and Reader
    zh_Hans: 读写S3 bucket中的文件
    pt_BR: AWS S3 Writer and Reader
  llm: AWS S3 Writer and Reader
parameters:
  - name: text_content
    type: string
    required: false
    label:
      en_US: The text to write
      zh_Hans: 待写入的文本
      pt_BR: The text to write
    human_description:
      en_US: The text to write
      zh_Hans: 待写入的文本
      pt_BR: The text to write
    llm_description: The text to write
    form: llm
  - name: s3_uri
    type: string
    required: true
    label:
      en_US: s3 uri
      zh_Hans: s3 uri
      pt_BR: s3 uri
    human_description:
      en_US: s3 uri
      zh_Hans: s3 uri
      pt_BR: s3 uri
    llm_description: s3 uri
    form: llm
  - name: aws_region
    type: string
    required: true
    label:
      en_US: region of bucket
      zh_Hans: bucket 所在的region
      pt_BR: region of bucket
    human_description:
      en_US: region of bucket
      zh_Hans: bucket 所在的region
      pt_BR: region of bucket
    llm_description: region of bucket
    form: form
  - name: operation_type
    type: select
    required: true
    label:
      en_US: operation type
      zh_Hans: 操作类型
      pt_BR: operation type
    human_description:
      en_US: operation type
      zh_Hans: 操作类型
      pt_BR: operation type
    default: read
    options:
      - value: read
        label:
          en_US: read
          zh_Hans: 读
      - value: write
        label:
          en_US: write
          zh_Hans: 写
    form: form
  - name: generate_presign_url
    type: boolean
    required: false
    label:
      en_US: Generate presigned URL
      zh_Hans: 生成预签名URL
    human_description:
      en_US: Whether to generate a presigned URL for the S3 object
      zh_Hans: 是否生成S3对象的预签名URL
    default: false
    form: form
  - name: presign_expiry
    type: number
    required: false
    label:
      en_US: Presigned URL expiration time
      zh_Hans: 预签名URL有效期
    human_description:
      en_US: Expiration time in seconds for the presigned URL
      zh_Hans: 预签名URL的有效期(秒)
    default: 3600
    form: form
@@ -11,7 +11,10 @@ class ComfyUIProvider(BuiltinToolProviderController):
     def _validate_credentials(self, credentials: dict[str, Any]) -> None:
         ws = websocket.WebSocket()
         base_url = URL(credentials.get("base_url"))
-        ws_address = f"ws://{base_url.authority}/ws?clientId=test123"
+        ws_protocol = "ws"
+        if base_url.scheme == "https":
+            ws_protocol = "wss"
+        ws_address = f"{ws_protocol}://{base_url.authority}/ws?clientId=test123"

         try:
             ws.connect(ws_address)
@@ -40,7 +40,10 @@ class ComfyUiClient:
     def open_websocket_connection(self) -> tuple[WebSocket, str]:
         client_id = str(uuid.uuid4())
         ws = WebSocket()
-        ws_address = f"ws://{self.base_url.authority}/ws?clientId={client_id}"
+        ws_protocol = "ws"
+        if self.base_url.scheme == "https":
+            ws_protocol = "wss"
+        ws_address = f"{ws_protocol}://{self.base_url.authority}/ws?clientId={client_id}"
         ws.connect(ws_address)
         return ws, client_id
@@ -1,32 +1,8 @@
 from typing import Any

-from core.file import FileTransferMethod, FileType
-from core.tools.errors import ToolProviderCredentialValidationError
-from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool
 from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
-from factories import file_factory


 class VectorizerProvider(BuiltinToolProviderController):
     def _validate_credentials(self, credentials: dict[str, Any]) -> None:
-        mapping = {
-            "transfer_method": FileTransferMethod.TOOL_FILE,
-            "type": FileType.IMAGE,
-            "id": "test_id",
-            "url": "https://cloud.dify.ai/logo/logo-site.png",
-        }
-        test_img = file_factory.build_from_mapping(
-            mapping=mapping,
-            tenant_id="__test_123",
-        )
-        try:
-            VectorizerTool().fork_tool_runtime(
-                runtime={
-                    "credentials": credentials,
-                }
-            ).invoke(
-                user_id="",
-                tool_parameters={"mode": "test", "image": test_img},
-            )
-        except Exception as e:
-            raise ToolProviderCredentialValidationError(str(e))
         return
@@ -210,7 +210,7 @@ class ApiTool(Tool):
             )
             return response
         else:
-            raise ValueError(f"Invalid http method {self.method}")
+            raise ValueError(f"Invalid http method {method}")

     def _convert_body_property_any_of(
         self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
@@ -8,9 +8,10 @@ from mimetypes import guess_extension, guess_type
 from typing import Optional, Union
 from uuid import uuid4

-from httpx import get
+import httpx

 from configs import dify_config
+from core.helper import ssrf_proxy
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 from models.model import MessageFile
@@ -94,12 +95,11 @@ class ToolFileManager:
     ) -> ToolFile:
         # try to download image
         try:
-            response = get(file_url)
+            response = ssrf_proxy.get(file_url)
             response.raise_for_status()
             blob = response.content
-        except Exception as e:
-            logger.exception(f"Failed to download file from {file_url}")
-            raise
+        except httpx.TimeoutException as e:
+            raise ValueError(f"timeout when downloading file from {file_url}")

         mimetype = guess_type(file_url)[0] or "octet/stream"
         extension = guess_extension(mimetype) or ".bin"
@@ -21,6 +21,7 @@ from .variables import (
     ArrayNumberVariable,
     ArrayObjectVariable,
     ArrayStringVariable,
+    ArrayVariable,
     FileVariable,
     FloatVariable,
     IntegerVariable,
@@ -43,6 +44,7 @@ __all__ = [
     "ArraySegment",
     "ArrayStringSegment",
     "ArrayStringVariable",
+    "ArrayVariable",
     "FileSegment",
     "FileVariable",
     "FloatSegment",
@@ -10,6 +10,7 @@ from .segments import (
     ArrayFileSegment,
     ArrayNumberSegment,
     ArrayObjectSegment,
+    ArraySegment,
     ArrayStringSegment,
     FileSegment,
     FloatSegment,
@@ -52,19 +53,23 @@ class ObjectVariable(ObjectSegment, Variable):
     pass


-class ArrayAnyVariable(ArrayAnySegment, Variable):
+class ArrayVariable(ArraySegment, Variable):
     pass


-class ArrayStringVariable(ArrayStringSegment, Variable):
+class ArrayAnyVariable(ArrayAnySegment, ArrayVariable):
     pass


-class ArrayNumberVariable(ArrayNumberSegment, Variable):
+class ArrayStringVariable(ArrayStringSegment, ArrayVariable):
     pass


-class ArrayObjectVariable(ArrayObjectSegment, Variable):
+class ArrayNumberVariable(ArrayNumberSegment, ArrayVariable):
+    pass
+
+
+class ArrayObjectVariable(ArrayObjectSegment, ArrayVariable):
     pass
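With the shared ArrayVariable base introduced above, downstream code can match any concrete array variable with a single isinstance check; a small sketch (the import path is assumed from the package layout in this diff):

# Sketch only: ArrayVariable is the new shared base from the diff above;
# the import path is an assumption.
from core.variables import ArrayVariable

def is_array_variable(value) -> bool:
    # One check now covers string/number/object/any array variables.
    return isinstance(value, ArrayVariable)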
@@ -45,3 +45,6 @@ class NodeRunResult(BaseModel):

     error: Optional[str] = None  # error message if status is failed
     error_type: Optional[str] = None  # error type if status is failed
+
+    # single step node run retry
+    retry_index: int = 0
@@ -33,7 +33,7 @@ class GraphRunSucceededEvent(BaseGraphEvent):

 class GraphRunFailedEvent(BaseGraphEvent):
     error: str = Field(..., description="failed reason")
-    exceptions_count: Optional[int] = Field(description="exception count", default=0)
+    exceptions_count: int = Field(description="exception count", default=0)


 class GraphRunPartialSucceededEvent(BaseGraphEvent):
@@ -97,6 +97,12 @@ class NodeInIterationFailedEvent(BaseNodeEvent):
     error: str = Field(..., description="error")


+class NodeRunRetryEvent(NodeRunStartedEvent):
+    error: str = Field(..., description="error")
+    retry_index: int = Field(..., description="which retry attempt is about to be performed")
+    start_at: datetime = Field(..., description="retry start time")
+
+
 ###########################################
 # Parallel Branch Events
 ###########################################
@@ -4,6 +4,7 @@ from typing import Any, Optional, cast

 from pydantic import BaseModel, Field

+from configs import dify_config
 from core.workflow.graph_engine.entities.run_condition import RunCondition
 from core.workflow.nodes import NodeType
 from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
@@ -170,7 +171,9 @@ class Graph(BaseModel):
         for parallel in parallel_mapping.values():
             if parallel.parent_parallel_id:
                 cls._check_exceed_parallel_limit(
-                    parallel_mapping=parallel_mapping, level_limit=3, parent_parallel_id=parallel.parent_parallel_id
+                    parallel_mapping=parallel_mapping,
+                    level_limit=dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
+                    parent_parallel_id=parallel.parent_parallel_id,
                 )

         # init answer stream generate routes
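Conceptually, the depth check walks parent_parallel_id links upward and fails once nesting exceeds the configured limit; a simplified sketch of that recursion (illustrative, not the actual implementation):

# Simplified sketch of the nested-parallel depth check; the parent_parallel_id
# field name follows the diff above, everything else is illustrative.
def check_parallel_depth(parallel_mapping: dict, level_limit: int, parallel_id: str, depth: int = 1) -> None:
    if depth > level_limit:
        raise ValueError(f"Parallel nesting exceeds the depth limit of {level_limit}")
    parent_id = parallel_mapping[parallel_id].parent_parallel_id
    if parent_id:
        check_parallel_depth(parallel_mapping, level_limit, parent_id, depth + 1)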
@@ -5,6 +5,7 @@ import uuid
 from collections.abc import Generator, Mapping
 from concurrent.futures import ThreadPoolExecutor, wait
 from copy import copy, deepcopy
+from datetime import UTC, datetime
 from typing import Any, Optional, cast

 from flask import Flask, current_app
@@ -25,6 +26,7 @@ from core.workflow.graph_engine.entities.event import (
     NodeRunExceptionEvent,
     NodeRunFailedEvent,
     NodeRunRetrieverResourceEvent,
+    NodeRunRetryEvent,
     NodeRunStartedEvent,
     NodeRunStreamChunkEvent,
     NodeRunSucceededEvent,
@@ -581,7 +583,7 @@ class GraphEngine:

     def _run_node(
         self,
-        node_instance: BaseNode,
+        node_instance: BaseNode[BaseNodeData],
         route_node_state: RouteNodeState,
         parallel_id: Optional[str] = None,
         parallel_start_node_id: Optional[str] = None,
@@ -607,36 +609,120 @@
             )

         db.session.close()
+        max_retries = node_instance.node_data.retry_config.max_retries
+        retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
+        retries = 0
+        should_continue_retry = True
+        while should_continue_retry and retries <= max_retries:
+            try:
+                # run node
+                retry_start_at = datetime.now(UTC).replace(tzinfo=None)
+                generator = node_instance.run()
+                for item in generator:
+                    if isinstance(item, GraphEngineEvent):
+                        if isinstance(item, BaseIterationEvent):
+                            # add parallel info to iteration event
+                            item.parallel_id = parallel_id
+                            item.parallel_start_node_id = parallel_start_node_id
+                            item.parent_parallel_id = parent_parallel_id
+                            item.parent_parallel_start_node_id = parent_parallel_start_node_id
-
-        try:
-            # run node
-            generator = node_instance.run()
-            for item in generator:
-                if isinstance(item, GraphEngineEvent):
-                    if isinstance(item, BaseIterationEvent):
-                        # add parallel info to iteration event
-                        item.parallel_id = parallel_id
-                        item.parallel_start_node_id = parallel_start_node_id
-                        item.parent_parallel_id = parent_parallel_id
-                        item.parent_parallel_start_node_id = parent_parallel_start_node_id
+                        yield item
+                    else:
+                        if isinstance(item, RunCompletedEvent):
+                            run_result = item.run_result
+                            if run_result.status == WorkflowNodeExecutionStatus.FAILED:
+                                if (
+                                    retries == max_retries
+                                    and node_instance.node_type == NodeType.HTTP_REQUEST
+                                    and run_result.outputs
+                                    and not node_instance.should_continue_on_error
+                                ):
+                                    run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
+                                if node_instance.should_retry and retries < max_retries:
+                                    retries += 1
+                                    route_node_state.node_run_result = run_result
+                                    yield NodeRunRetryEvent(
+                                        id=node_instance.id,
+                                        node_id=node_instance.node_id,
+                                        node_type=node_instance.node_type,
+                                        node_data=node_instance.node_data,
+                                        route_node_state=route_node_state,
+                                        predecessor_node_id=node_instance.previous_node_id,
+                                        parallel_id=parallel_id,
+                                        parallel_start_node_id=parallel_start_node_id,
+                                        parent_parallel_id=parent_parallel_id,
+                                        parent_parallel_start_node_id=parent_parallel_start_node_id,
+                                        error=run_result.error,
+                                        retry_index=retries,
+                                        start_at=retry_start_at,
+                                    )
+                                    time.sleep(retry_interval)
+                                    continue
+                            route_node_state.set_finished(run_result=run_result)
-
-                    yield item
-                else:
-                    if isinstance(item, RunCompletedEvent):
-                        run_result = item.run_result
-                        route_node_state.set_finished(run_result=run_result)
+                            if run_result.status == WorkflowNodeExecutionStatus.FAILED:
+                                if node_instance.should_continue_on_error:
+                                    # if run failed, handle error
+                                    run_result = self._handle_continue_on_error(
+                                        node_instance,
+                                        item.run_result,
+                                        self.graph_runtime_state.variable_pool,
+                                        handle_exceptions=handle_exceptions,
+                                    )
+                                    route_node_state.node_run_result = run_result
+                                    route_node_state.status = RouteNodeState.Status.EXCEPTION
+                                    if run_result.outputs:
+                                        for variable_key, variable_value in run_result.outputs.items():
+                                            # append variables to variable pool recursively
+                                            self._append_variables_recursively(
+                                                node_id=node_instance.node_id,
+                                                variable_key_list=[variable_key],
+                                                variable_value=variable_value,
+                                            )
+                                    yield NodeRunExceptionEvent(
+                                        error=run_result.error or "System Error",
+                                        id=node_instance.id,
+                                        node_id=node_instance.node_id,
+                                        node_type=node_instance.node_type,
+                                        node_data=node_instance.node_data,
+                                        route_node_state=route_node_state,
+                                        parallel_id=parallel_id,
+                                        parallel_start_node_id=parallel_start_node_id,
+                                        parent_parallel_id=parent_parallel_id,
+                                        parent_parallel_start_node_id=parent_parallel_start_node_id,
+                                    )
+                                    should_continue_retry = False
+                                else:
+                                    yield NodeRunFailedEvent(
+                                        error=route_node_state.failed_reason or "Unknown error.",
+                                        id=node_instance.id,
+                                        node_id=node_instance.node_id,
+                                        node_type=node_instance.node_type,
+                                        node_data=node_instance.node_data,
+                                        route_node_state=route_node_state,
+                                        parallel_id=parallel_id,
+                                        parallel_start_node_id=parallel_start_node_id,
+                                        parent_parallel_id=parent_parallel_id,
+                                        parent_parallel_start_node_id=parent_parallel_start_node_id,
+                                    )
+                                    should_continue_retry = False
+                            elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
+                                if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
+                                    node_instance.node_id
+                                ):
+                                    run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
+                                if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
+                                    # plus state total_tokens
+                                    self.graph_runtime_state.total_tokens += int(
+                                        run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)  # type: ignore[arg-type]
+                                    )
-
-                        if run_result.status == WorkflowNodeExecutionStatus.FAILED:
-                            if node_instance.should_continue_on_error:
-                                # if run failed, handle error
-                                run_result = self._handle_continue_on_error(
-                                    node_instance,
-                                    item.run_result,
-                                    self.graph_runtime_state.variable_pool,
-                                    handle_exceptions=handle_exceptions,
-                                )
-                                route_node_state.node_run_result = run_result
-                                route_node_state.status = RouteNodeState.Status.EXCEPTION
+                                if run_result.llm_usage:
+                                    # use the latest usage
+                                    self.graph_runtime_state.llm_usage += run_result.llm_usage
+
+                                # append node output variables to variable pool
+                                if run_result.outputs:
+                                    for variable_key, variable_value in run_result.outputs.items():
+                                        # append variables to variable pool recursively
@@ -645,21 +731,23 @@
                                         variable_key_list=[variable_key],
                                         variable_value=variable_value,
                                     )
-                                yield NodeRunExceptionEvent(
-                                    error=run_result.error or "System Error",
-                                    id=node_instance.id,
-                                    node_id=node_instance.node_id,
-                                    node_type=node_instance.node_type,
-                                    node_data=node_instance.node_data,
-                                    route_node_state=route_node_state,
-                                    parallel_id=parallel_id,
-                                    parallel_start_node_id=parallel_start_node_id,
-                                    parent_parallel_id=parent_parallel_id,
-                                    parent_parallel_start_node_id=parent_parallel_start_node_id,
-                                )
-                            else:
-                                yield NodeRunFailedEvent(
-                                    error=route_node_state.failed_reason or "Unknown error.",
+
+                                # add parallel info to run result metadata
+                                if parallel_id and parallel_start_node_id:
+                                    if not run_result.metadata:
+                                        run_result.metadata = {}
+
+                                    run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
+                                    run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = (
+                                        parallel_start_node_id
+                                    )
+                                    if parent_parallel_id and parent_parallel_start_node_id:
+                                        run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
+                                        run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
+                                            parent_parallel_start_node_id
+                                        )
+
+                            yield NodeRunSucceededEvent(
                                 id=node_instance.id,
                                 node_id=node_instance.node_id,
                                 node_type=node_instance.node_type,
@@ -670,108 +758,59 @@
                                 parent_parallel_id=parent_parallel_id,
                                 parent_parallel_start_node_id=parent_parallel_start_node_id,
                             )
+                            should_continue_retry = False
-
-                        elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
-                            if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
-                                node_instance.node_id
-                            ):
-                                run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
-                            if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
-                                # plus state total_tokens
-                                self.graph_runtime_state.total_tokens += int(
-                                    run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)  # type: ignore[arg-type]
-                                )
-
-                            if run_result.llm_usage:
-                                # use the latest usage
-                                self.graph_runtime_state.llm_usage += run_result.llm_usage
-
-                            # append node output variables to variable pool
-                            if run_result.outputs:
-                                for variable_key, variable_value in run_result.outputs.items():
-                                    # append variables to variable pool recursively
-                                    self._append_variables_recursively(
-                                        node_id=node_instance.node_id,
-                                        variable_key_list=[variable_key],
-                                        variable_value=variable_value,
-                                    )
-
-                            # add parallel info to run result metadata
-                            if parallel_id and parallel_start_node_id:
-                                if not run_result.metadata:
-                                    run_result.metadata = {}
-
-                                run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
-                                run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
-                                if parent_parallel_id and parent_parallel_start_node_id:
-                                    run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
-                                    run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
-                                        parent_parallel_start_node_id
-                                    )
-
-                            yield NodeRunSucceededEvent(
+                            break
+                    elif isinstance(item, RunStreamChunkEvent):
+                        yield NodeRunStreamChunkEvent(
+                            id=node_instance.id,
+                            node_id=node_instance.node_id,
+                            node_type=node_instance.node_type,
+                            node_data=node_instance.node_data,
+                            chunk_content=item.chunk_content,
+                            from_variable_selector=item.from_variable_selector,
+                            route_node_state=route_node_state,
+                            parallel_id=parallel_id,
+                            parallel_start_node_id=parallel_start_node_id,
+                            parent_parallel_id=parent_parallel_id,
+                            parent_parallel_start_node_id=parent_parallel_start_node_id,
+                        )
-
-                        break
-                elif isinstance(item, RunStreamChunkEvent):
-                    yield NodeRunStreamChunkEvent(
-                        id=node_instance.id,
-                        node_id=node_instance.node_id,
-                        node_type=node_instance.node_type,
-                        node_data=node_instance.node_data,
-                        chunk_content=item.chunk_content,
-                        from_variable_selector=item.from_variable_selector,
-                        route_node_state=route_node_state,
-                        parallel_id=parallel_id,
-                        parallel_start_node_id=parallel_start_node_id,
-                        parent_parallel_id=parent_parallel_id,
-                        parent_parallel_start_node_id=parent_parallel_start_node_id,
-                    )
+                    elif isinstance(item, RunRetrieverResourceEvent):
+                        yield NodeRunRetrieverResourceEvent(
+                            id=node_instance.id,
+                            node_id=node_instance.node_id,
+                            node_type=node_instance.node_type,
+                            node_data=node_instance.node_data,
+                            retriever_resources=item.retriever_resources,
+                            context=item.context,
+                            route_node_state=route_node_state,
+                            parallel_id=parallel_id,
+                            parallel_start_node_id=parallel_start_node_id,
+                            parent_parallel_id=parent_parallel_id,
+                            parent_parallel_start_node_id=parent_parallel_start_node_id,
+                        )
+            except GenerateTaskStoppedError:
+                # trigger node run failed event
+                route_node_state.status = RouteNodeState.Status.FAILED
+                route_node_state.failed_reason = "Workflow stopped."
+                yield NodeRunFailedEvent(
+                    error="Workflow stopped.",
+                    id=node_instance.id,
+                    node_id=node_instance.node_id,
+                    node_type=node_instance.node_type,
+                    node_data=node_instance.node_data,
+                    route_node_state=route_node_state,
+                    parallel_id=parallel_id,
+                    parallel_start_node_id=parallel_start_node_id,
+                    parent_parallel_id=parent_parallel_id,
+                    parent_parallel_start_node_id=parent_parallel_start_node_id,
+                )
+                return
+            except Exception as e:
+                logger.exception(f"Node {node_instance.node_data.title} run failed")
+                raise e
+            finally:
+                db.session.close()
-                elif isinstance(item, RunRetrieverResourceEvent):
-                    yield NodeRunRetrieverResourceEvent(
-                        id=node_instance.id,
-                        node_id=node_instance.node_id,
-                        node_type=node_instance.node_type,
-                        node_data=node_instance.node_data,
-                        retriever_resources=item.retriever_resources,
-                        context=item.context,
-                        route_node_state=route_node_state,
-                        parallel_id=parallel_id,
-                        parallel_start_node_id=parallel_start_node_id,
-                        parent_parallel_id=parent_parallel_id,
-                        parent_parallel_start_node_id=parent_parallel_start_node_id,
-                    )
-        except GenerateTaskStoppedError:
-            # trigger node run failed event
-            route_node_state.status = RouteNodeState.Status.FAILED
-            route_node_state.failed_reason = "Workflow stopped."
-            yield NodeRunFailedEvent(
-                error="Workflow stopped.",
-                id=node_instance.id,
-                node_id=node_instance.node_id,
-                node_type=node_instance.node_type,
-                node_data=node_instance.node_data,
-                route_node_state=route_node_state,
-                parallel_id=parallel_id,
-                parallel_start_node_id=parallel_start_node_id,
-                parent_parallel_id=parent_parallel_id,
-                parent_parallel_start_node_id=parent_parallel_start_node_id,
-            )
-            return
-        except Exception as e:
-            logger.exception(f"Node {node_instance.node_data.title} run failed")
-            raise e
-        finally:
-            db.session.close()

     def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
         """
@@ -147,6 +147,8 @@ class AnswerStreamGeneratorRouter:
             reverse_edges = reverse_edge_mapping.get(current_node_id, [])
             for edge in reverse_edges:
                 source_node_id = edge.source_node_id
+                if source_node_id not in node_id_config_mapping:
+                    continue
                 source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
                 source_node_data = node_id_config_mapping[source_node_id].get("data", {})
                 if (
@@ -60,7 +60,6 @@ class AnswerStreamProcessor(StreamProcessor):

         del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]

-        # remove unreachable nodes
-        self._remove_unreachable_nodes(event)

         # generate stream outputs
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
|
||||
@ -5,6 +6,8 @@ from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamProcessor(ABC):
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||
@ -31,13 +34,22 @@ class StreamProcessor(ABC):
|
||||
if run_result.edge_source_handle:
|
||||
reachable_node_ids = []
|
||||
unreachable_first_node_ids = []
|
||||
if finished_node_id not in self.graph.edge_mapping:
|
||||
logger.warning(f"node {finished_node_id} has no edge mapping")
|
||||
return
|
||||
for edge in self.graph.edge_mapping[finished_node_id]:
|
||||
if (
|
||||
edge.run_condition
|
||||
and edge.run_condition.branch_identify
|
||||
and run_result.edge_source_handle == edge.run_condition.branch_identify
|
||||
):
|
||||
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
|
||||
# remove unreachable nodes
|
||||
# FIXME: because of the code branch can combine directly, so for answer node
|
||||
# we remove the node maybe shortcut the answer node, so comment this code for now
|
||||
# there is not effect on the answer node and the workflow, when we have a better solution
|
||||
# we can open this code. Issues: #11542 #9560 #10638 #10564
|
||||
|
||||
# reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
|
||||
continue
|
||||
else:
|
||||
unreachable_first_node_ids.append(edge.target_node_id)
|
||||
|
||||
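The surrounding method decides which branch survives after a branch node finishes: an edge whose `run_condition.branch_identify` equals the finished node's `edge_source_handle` is the taken branch, and every other edge's target becomes the head of an unreachable branch. A reduced sketch of that split (hypothetical `Edge` shape; dify's real edges carry full run conditions):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Edge:
    target_node_id: str
    branch_identify: Optional[str]  # stands in for edge.run_condition.branch_identify


def split_branches(edges: list[Edge], taken_handle: str) -> tuple[list[str], list[str]]:
    """Return (reachable first nodes, unreachable first nodes) after a branch node."""
    reachable, unreachable_first = [], []
    for edge in edges:
        if edge.branch_identify and taken_handle == edge.branch_identify:
            reachable.append(edge.target_node_id)  # the branch actually taken
        else:
            unreachable_first.append(edge.target_node_id)  # candidates for pruning
    return reachable, unreachable_first


edges = [Edge("then_node", "true"), Edge("else_node", "false")]
print(split_branches(edges, taken_handle="true"))  # (['then_node'], ['else_node'])
```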
@@ -106,12 +106,25 @@ class DefaultValue(BaseModel):
        return self


class RetryConfig(BaseModel):
    """node retry config"""

    max_retries: int = 0  # max retry times
    retry_interval: int = 0  # retry interval in milliseconds
    retry_enabled: bool = False  # whether retry is enabled

    @property
    def retry_interval_seconds(self) -> float:
        return self.retry_interval / 1000


class BaseNodeData(ABC, BaseModel):
    title: str
    desc: Optional[str] = None
    error_strategy: Optional[ErrorStrategy] = None
    default_value: Optional[list[DefaultValue]] = None
    version: str = "1"
    retry_config: RetryConfig = RetryConfig()

    @property
    def default_value_dict(self):
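`RetryConfig` keeps `retry_interval` in milliseconds and converts on read, so callers can pass the property straight to `time.sleep`. A standalone pydantic model mirroring the fields above, just to show the conversion:

```python
import time

from pydantic import BaseModel


class RetryConfig(BaseModel):
    max_retries: int = 0          # max retry times
    retry_interval: int = 0       # retry interval in milliseconds
    retry_enabled: bool = False   # whether retry is enabled

    @property
    def retry_interval_seconds(self) -> float:
        return self.retry_interval / 1000


config = RetryConfig(max_retries=3, retry_interval=500, retry_enabled=True)
assert config.retry_interval_seconds == 0.5
time.sleep(config.retry_interval_seconds)  # wait half a second between attempts
```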
@@ -1,4 +1,4 @@
class BaseNodeError(Exception):
class BaseNodeError(ValueError):
    """Base class for node errors."""

    pass
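Rebasing `BaseNodeError` from `Exception` onto `ValueError` widens what catches it: any existing `except ValueError` handler now also traps node errors, while `except BaseNodeError` behaves as before. For illustration:

```python
class BaseNodeError(ValueError):
    """Base class for node errors."""


try:
    raise BaseNodeError("bad output")
except ValueError as e:  # now catches BaseNodeError too
    print(f"caught: {e}")
```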
@@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast

from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from models.workflow import WorkflowNodeExecutionStatus
@@ -147,3 +147,12 @@ class BaseNode(Generic[GenericNodeData]):
            bool: if should continue on error
        """
        return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE

    @property
    def should_retry(self) -> bool:
        """Judge whether the node should retry.

        Returns:
            bool: whether the node should retry
        """
        return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE
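`should_retry` only gates whether retrying is allowed for a node; the loop that actually re-runs it lives in the graph engine, bounded by `max_retries` and paced by `retry_interval_seconds`. A minimal sketch of such a loop (hypothetical `run_once` callable, not dify's engine):

```python
import time


def run_with_retry(run_once, max_retries: int, retry_interval_seconds: float):
    """Run `run_once`, allowing up to `max_retries` extra attempts on failure."""
    attempt = 0
    while True:
        try:
            return run_once()
        except Exception:
            if attempt >= max_retries:
                raise  # retries exhausted; surface the last error
            attempt += 1
            time.sleep(retry_interval_seconds)  # pace attempts


calls = {"n": 0}


def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient")
    return "ok"


print(run_with_retry(flaky, max_retries=3, retry_interval_seconds=0.01))  # "ok" on attempt 3
```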
@@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union
from typing import Any, Optional

from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
@@ -59,7 +59,7 @@ class CodeNode(BaseNode[CodeNodeData]):
            )

            # Transform result
            result = self._transform_result(result, self.node_data.outputs)
            result = self._transform_result(result=result, output_schema=self.node_data.outputs)
        except (CodeExecutionError, CodeNodeError) as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
@@ -67,18 +67,17 @@ class CodeNode(BaseNode[CodeNodeData]):

        return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)

    def _check_string(self, value: str, variable: str) -> str:
    def _check_string(self, value: str | None, variable: str) -> str | None:
        """
        Check string
        :param value: value
        :param variable: variable
        :return:
        """
        if value is None:
            return None
        if not isinstance(value, str):
            if value is None:
                return None
            else:
                raise OutputValidationError(f"Output variable `{variable}` must be a string")
            raise OutputValidationError(f"Output variable `{variable}` must be a string")

        if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
            raise OutputValidationError(
@@ -88,18 +87,17 @@ class CodeNode(BaseNode[CodeNodeData]):

        return value.replace("\x00", "")

    def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]:
    def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
        """
        Check number
        :param value: value
        :param variable: variable
        :return:
        """
        if value is None:
            return None
        if not isinstance(value, int | float):
            if value is None:
                return None
            else:
                raise OutputValidationError(f"Output variable `{variable}` must be a number")
            raise OutputValidationError(f"Output variable `{variable}` must be a number")

        if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
            raise OutputValidationError(
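Both validators were refactored the same way: an early `None` passthrough, then a flat type check, which also lets the signatures honestly declare `| None`. The pattern in isolation (the length limit is a stand-in for `dify_config.CODE_MAX_STRING_LENGTH`, and `ValueError` stands in for `OutputValidationError`):

```python
CODE_MAX_STRING_LENGTH = 80_000  # stand-in for dify_config.CODE_MAX_STRING_LENGTH


def check_string(value: str | None, variable: str) -> str | None:
    if value is None:
        return None  # absent outputs pass through untouched
    if not isinstance(value, str):
        raise ValueError(f"Output variable `{variable}` must be a string")
    if len(value) > CODE_MAX_STRING_LENGTH:
        raise ValueError(f"Output variable `{variable}` is too long")
    return value.replace("\x00", "")  # strip NUL bytes the executor may emit


print(check_string(None, "x"))   # None
print(check_string("ok", "x"))   # ok
```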
@@ -118,14 +116,12 @@ class CodeNode(BaseNode[CodeNodeData]):
        return value

    def _transform_result(
        self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1
    ) -> dict:
        """
        Transform result
        :param result: result
        :param output_schema: output schema
        :return:
        """
        self,
        result: Mapping[str, Any],
        output_schema: Optional[dict[str, CodeNodeData.Output]],
        prefix: str = "",
        depth: int = 1,
    ):
        if depth > dify_config.CODE_MAX_DEPTH:
            raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
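`_transform_result` recurses into nested output dicts, threading `depth` through each call so that output deeper than `CODE_MAX_DEPTH` aborts instead of recursing without bound. A stripped-down sketch of that guard (hypothetical limit value):

```python
CODE_MAX_DEPTH = 5  # stand-in for dify_config.CODE_MAX_DEPTH


class DepthLimitError(ValueError):
    pass


def transform(result: dict, depth: int = 1) -> dict:
    if depth > CODE_MAX_DEPTH:
        raise DepthLimitError(f"Depth limit {CODE_MAX_DEPTH} reached, object too deep.")
    out = {}
    for key, value in result.items():
        # Recurse into nested dicts, counting one level per call.
        out[key] = transform(value, depth + 1) if isinstance(value, dict) else value
    return out


print(transform({"a": {"b": 1}}))  # fine: only two levels deep
nested = {"x": {"x": {"x": {"x": {"x": {"x": 1}}}}}}
try:
    transform(nested)
except DepthLimitError as e:
    print(e)  # Depth limit 5 reached, object too deep.
```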
@@ -1,6 +1,7 @@
import csv
import io
import json
import logging
import os
import tempfile
@@ -8,12 +9,6 @@ import docx
import pandas as pd
import pypdfium2  # type: ignore
import yaml  # type: ignore
from unstructured.partition.api import partition_via_api
from unstructured.partition.email import partition_email
from unstructured.partition.epub import partition_epub
from unstructured.partition.msg import partition_msg
from unstructured.partition.ppt import partition_ppt
from unstructured.partition.pptx import partition_pptx

from configs import dify_config
from core.file import File, FileTransferMethod, file_manager
@@ -28,6 +23,8 @@ from models.workflow import WorkflowNodeExecutionStatus
from .entities import DocumentExtractorNodeData
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError

logger = logging.getLogger(__name__)


class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
    """
@@ -183,10 +180,43 @@ def _extract_text_from_pdf(file_content: bytes) -> str:


def _extract_text_from_doc(file_content: bytes) -> str:
    """
    Extract text from a DOC/DOCX file.
    For now, only paragraphs and tables are supported; add more if needed.
    """
    try:
        doc_file = io.BytesIO(file_content)
        doc = docx.Document(doc_file)
        return "\n".join([paragraph.text for paragraph in doc.paragraphs])
        text = []
        # Process paragraphs
        for paragraph in doc.paragraphs:
            if paragraph.text.strip():
                text.append(paragraph.text)

        # Process tables
        for table in doc.tables:
            # Table header
            try:
                # Tables may fail to parse, so ignore errors per table.
                if len(table.rows) > 0 and table.rows[0].cells is not None:
                    # Check if any cell in the table has text
                    has_content = False
                    for row in table.rows:
                        if any(cell.text.strip() for cell in row.cells):
                            has_content = True
                            break

                    if has_content:
                        markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n"
                        markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n"
                        for row in table.rows[1:]:
                            markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n"
                        text.append(markdown_table)
            except Exception as e:
                logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
                continue

        return "\n".join(text)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
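The rewritten extractor emits each non-empty table as GitHub-style markdown: first row as the header, a `---` separator row, then one line per remaining row. A quick roundtrip with python-docx shows the shape of the output (assumes `python-docx` is installed; `doc_to_text` is an illustrative re-implementation of the same scheme, not the node itself):

```python
import io

import docx


def doc_to_text(file_content: bytes) -> str:
    """Paragraphs as plain lines, tables as markdown, mirroring the node's scheme."""
    doc = docx.Document(io.BytesIO(file_content))
    text = [p.text for p in doc.paragraphs if p.text.strip()]
    for table in doc.tables:
        md = "| " + " | ".join(c.text for c in table.rows[0].cells) + " |\n"
        md += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n"
        for row in table.rows[1:]:
            md += "| " + " | ".join(c.text for c in row.cells) + " |\n"
        text.append(md)
    return "\n".join(text)


# Build a tiny .docx in memory, then extract it back.
doc = docx.Document()
doc.add_paragraph("Quarterly summary")
table = doc.add_table(rows=2, cols=2)
table.cell(0, 0).text, table.cell(0, 1).text = "Region", "Revenue"
table.cell(1, 0).text, table.cell(1, 1).text = "EMEA", "42"
buf = io.BytesIO()
doc.save(buf)

print(doc_to_text(buf.getvalue()))
# Quarterly summary
# | Region | Revenue |
# | --- | --- |
# | EMEA | 42 |
```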
@@ -256,6 +286,8 @@ def _extract_text_from_excel(file_content: bytes) -> str:


def _extract_text_from_ppt(file_content: bytes) -> str:
    from unstructured.partition.ppt import partition_ppt

    try:
        with io.BytesIO(file_content) as file:
            elements = partition_ppt(file=file)
@@ -265,6 +297,9 @@ def _extract_text_from_ppt(file_content: bytes) -> str:


def _extract_text_from_pptx(file_content: bytes) -> str:
    from unstructured.partition.api import partition_via_api
    from unstructured.partition.pptx import partition_pptx

    try:
        if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
            with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
@@ -287,6 +322,8 @@ def _extract_text_from_pptx(file_content: bytes) -> str:


def _extract_text_from_epub(file_content: bytes) -> str:
    from unstructured.partition.epub import partition_epub

    try:
        with io.BytesIO(file_content) as file:
            elements = partition_epub(file=file)
@@ -296,6 +333,8 @@ def _extract_text_from_epub(file_content: bytes) -> str:


def _extract_text_from_eml(file_content: bytes) -> str:
    from unstructured.partition.email import partition_email

    try:
        with io.BytesIO(file_content) as file:
            elements = partition_email(file=file)
@@ -305,6 +344,8 @@ def _extract_text_from_eml(file_content: bytes) -> str:


def _extract_text_from_msg(file_content: bytes) -> str:
    from unstructured.partition.msg import partition_msg

    try:
        with io.BytesIO(file_content) as file:
            elements = partition_msg(file=file)
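Moving the `unstructured` imports from module level into each extractor defers a heavy dependency until a matching file type is actually processed; the module itself now imports quickly even when `unstructured` is slow to load. The pattern in miniature (standalone sketch, not the node's code):

```python
import io
import sys


def extract_epub(file_content: bytes) -> str:
    # The heavy import happens on first call only; later calls are served
    # from sys.modules, so the deferral costs almost nothing afterwards.
    from unstructured.partition.epub import partition_epub  # deferred heavy dependency

    with io.BytesIO(file_content) as file:
        return "\n\n".join(str(element) for element in partition_epub(file=file))


# At module import time, nothing from `unstructured` has been loaded yet
# (holds as long as no other module imported it first):
assert not any(name.startswith("unstructured") for name in sys.modules)
```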
@@ -135,6 +135,8 @@ class EndStreamGeneratorRouter:
                reverse_edges = reverse_edge_mapping.get(current_node_id, [])
                for edge in reverse_edges:
                    source_node_id = edge.source_node_id
                    if source_node_id not in node_id_config_mapping:
                        continue
                    source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
                    if source_node_type in {
                        NodeType.IF_ELSE.value,
@@ -35,3 +35,4 @@ class FailBranchSourceHandle(StrEnum):


CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
RETRY_ON_ERROR_NODE_TYPE = CONTINUE_ON_ERROR_NODE_TYPE
@@ -1,4 +1,10 @@
from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from .event import (
    ModelInvokeCompletedEvent,
    RunCompletedEvent,
    RunRetrieverResourceEvent,
    RunRetryEvent,
    RunStreamChunkEvent,
)
from .types import NodeEvent

__all__ = [
@@ -6,5 +12,6 @@ __all__ = [
    "NodeEvent",
    "RunCompletedEvent",
    "RunRetrieverResourceEvent",
    "RunRetryEvent",
    "RunStreamChunkEvent",
]
Some files were not shown because too many files have changed in this diff.