Compare commits


76 Commits

Author SHA1 Message Date
e88ea71aef chore/bump version to 0.14.2 (#12017)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-23 19:15:48 +08:00
e0f1410b48 fix: issue Multiple Paths Between IF/ELSE Branches (#11646)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-23 18:56:59 +08:00
c3c85276d1 Fix/refactor invoke result handling in question classifier node (#12015)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-23 17:54:08 +08:00
af2888d394 fix: remove json_schema if response format is disabled. (#12014)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-23 17:53:57 +08:00
d0dd8b7955 fix: add UUID validation for tool file ID extraction (#12011)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-23 17:53:42 +08:00
75bce2822e fix: add logging for missing edge mapping in StreamProcessor (#12008)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-23 17:53:36 +08:00
425cc1ea85 Fix/add retry mechanism to billing service request handling (#12006)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-23 17:53:25 +08:00
2bf33c4dd2 Fix/workflow retry log (#12013) 2024-12-23 17:18:09 +08:00
dc19cd5d9d fix: add retry feature to code node (#12005)
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
2024-12-23 16:42:28 +08:00
dfc25dbdd0 fix: drop useless and wrong code in Account (#11961)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-23 16:30:04 +08:00
c4091c4c66 fix: improve error handling for file retrieval in AwsS3Storage (#12002)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-23 16:28:54 +08:00
ef95b1268e Fix/workflow retry (#11999) 2024-12-23 15:55:50 +08:00
e068bbec73 feat: add RequestBodyError for invalid request body handling (#11994)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-23 15:53:03 +08:00
9cfd1c67b6 fix: Introduce ArrayVariable and update iteration node to handle it (#12001)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-23 15:52:50 +08:00
8978a6a3ff fix: remove unused credential validation logic in VectorizerProvider (#12000)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-23 15:30:58 +08:00
4e3d732934 feat: Warning on invite modal when mail setup is incomplete (#11809) 2024-12-23 15:27:49 +08:00
03548cdfbc fix: handle broader request exceptions in OAuth process (#11997)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-23 15:23:11 +08:00
70dd69d533 fix: Multiple Paths Between IF/ELSE Branches Invalidate Conditions (#11544)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-23 15:09:02 +08:00
453f324f54 fix: retry node in iteration logs wrong (#11995)
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
2024-12-23 15:06:01 +08:00
c1aa55f3ea fix: remove the unused retry index field (#11903)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2024-12-23 14:32:11 +08:00
4b1e13e982 Fix 11979 (#11984) 2024-12-23 14:30:04 +08:00
4584eb3058 Add custom to file types (#11966)
Co-authored-by: yagiyuki <yagiyuki>
2024-12-23 13:53:46 +08:00
02a7ae15f9 fix: fix update external dataset error in dataset list (#11989) 2024-12-23 13:53:05 +08:00
74b1b60125 Feat: account page dark mode (#11977) 2024-12-23 11:17:49 +08:00
39df994ff9 fix: create_feedback args are wrong (#11962)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-23 09:20:47 +08:00
d9875fe232 fix(commands): validate name encoding for non-Latin characters (#11965) 2024-12-23 09:20:30 +08:00
26c10b9931 fix: 'dict_keys' object is not subscriptable error (#11957)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-22 14:58:33 +08:00
750662eb08 fix(workflow): update updated_at default to use UTC timezone (#11960)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-22 14:55:18 +08:00
6b49889041 fix: messagefeedbackapi support content (#11716)
Signed-off-by: weiyang <24080293@smb956101.com>
Co-authored-by: weiyang <24080293@smb956101.com>
2024-12-22 10:45:55 +08:00
03ddee3663 fix(variable_assigner): change VariableOperatorNodeError to inherit from ValueError (#11951)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-22 10:43:40 +08:00
10caab1729 fix: change CredentialsValidateFailedError to inherit from ValueError (#11950)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-22 10:43:31 +08:00
c6a72def88 fix(ops_trace_manager): handle None workflow_run in workflow_trace method and raise ValueError (#11953)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-22 10:42:51 +08:00
21a31d7f8b fix(base_node): change BaseNodeError exception type from Exception to ValueError (#11952)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-22 10:42:30 +08:00
2c4df108e5 fix: raise http request node error on httpx.request error (#11954)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-22 10:41:53 +08:00
5db8addcc6 fix(core/errors): change base class of custom exceptions to ValueError (#11955)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-22 10:41:34 +08:00
dd0e81d094 fix: enhance type hints and improve audio message handling in TTS pub… (#11947)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-22 10:41:06 +08:00
90f093eb67 fix(json_in_md_parser): improve error messages for JSON parsing failures (#11948)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-22 10:40:56 +08:00
a056a9d601 feat(code_node): add more check (#11949)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-22 10:40:43 +08:00
2ad2a402fb fix(app_dsl_service): handle missing app mode with a ValueError (#11945)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-22 10:40:12 +08:00
3d07a94bd7 fix: refactor conversation pagination to use SQLAlchemy session manag… (#11956)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-22 10:39:29 +08:00
366857cd26 fix: gemini system prompt with variable raise error (#11946) 2024-12-21 23:14:05 +08:00
9578246bbb fix: The default updated_at when a workflow is created (#11709)
Co-authored-by: 刘江波 <jiangbo721@163.com>
2024-12-21 23:13:58 +08:00
9ee9e9c6de fix: self.method should method in api_tool.py (#11926)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-21 21:24:59 +08:00
e22cc28114 fix:log error(#11942) (#11943) 2024-12-21 21:24:33 +08:00
a227af3664 fix(code_node): update type hints for string and number checks in Cod… (#11936)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-21 21:24:22 +08:00
599d410d99 fix: validate reranking model attributes before processing (#11930)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-21 21:23:12 +08:00
5e37ab60d8 fix: validate response type in transform_response method (#11931)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-21 21:23:03 +08:00
0b06235527 fix: add RemoteFileUploadError for better error handling in remote fi… (#11933)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-21 21:22:57 +08:00
b8d42cdea7 fix: change MaxRetriesExceededError to inherit from ValueError (#11934)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-21 21:22:47 +08:00
455791b710 fix(model_runtime): make invoke as ValueError (#11929)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-21 21:22:14 +08:00
90323cd355 fix(tool_file_manager): raise ValueError when get timeout (#11928)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-21 21:22:06 +08:00
c07d9e96ce fix(nodes): handle errors in question_classifier and parameter_extractor (#11927)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-21 21:21:57 +08:00
810adb8a94 fix: change OutputParserError to inherit from ValueError (#11935)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-21 21:21:30 +08:00
606aadb891 refactor: update builtin tool provider methods to use session management (#11938)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-21 21:21:09 +08:00
8f73670925 fix(file_factory): validate upload_file_id before querying UploadFile (#11937)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-21 21:21:00 +08:00
8c559d6231 fix(retrieval_service): avoid to use exception (#11925)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-21 21:19:46 +08:00
786cb6859b fix: ensure WorkflowRun attributes are refreshed in WorkflowCycleMana… (#11913) 2024-12-21 15:05:04 +08:00
de8800f41a add three aws tools (#11905)
Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
2024-12-21 13:36:13 +08:00
7a00798027 Update input-var-list.tsx (#9987) 2024-12-21 11:54:12 +08:00
6ded06c6d9 fix: dataset search-input compostion can't work in chrome (#11907)
Co-authored-by: zhaoqingyu.1075 <zhaoqingyu.1075@bytedance.com>
2024-12-21 11:26:17 +08:00
f53741c5b9 revert: these 2 settings (#11906)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2024-12-20 18:24:47 +08:00
2681bafb76 fix: handle document fetching from URL in Anthropic LLM model, solving base64 decoding error (#11858) 2024-12-20 18:23:42 +08:00
ac635c70cd fix: doc can not extract tables (#11879)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Co-authored-by: akinobu-i <akinobu-i@users.noreply.github.com>
2024-12-20 17:19:46 +08:00
ef7e47d162 fix: rerank switch (#11897) 2024-12-20 16:12:34 +08:00
4211b9abbd chore: translate i18n files (#11892)
Co-authored-by: zxhlyh <16177003+zxhlyh@users.noreply.github.com>
2024-12-20 16:12:01 +08:00
0c0120ef27 Feat/workflow retry (#11885) 2024-12-20 15:44:37 +08:00
dacd457478 feat: add workflow parallel depth limit configuration (#11460)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
2024-12-20 14:52:20 +08:00
7b03a0316d fix: better memory usage from 800+ to 500+ (#11796)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-20 14:51:43 +08:00
52201d95b1 chore: add retry index migration (#11887)
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
2024-12-20 14:40:33 +08:00
e2cde628bb chore: translate i18n files (#11855)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-12-20 14:19:47 +08:00
3335fa78fc fix: node 22 build (#11883) 2024-12-20 14:14:27 +08:00
7abc7fa573 Feat: Retry on node execution errors (#11871)
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
2024-12-20 14:14:06 +08:00
f6247fe67c Feat: Add partial success status to the app log (#11869)
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
2024-12-20 14:13:44 +08:00
996a9135f6 feat(llm_node): support order in text and files (#11837)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-20 14:12:50 +08:00
3599751f93 chore(db): use a better way to export models and remove unused table (#11838)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-20 14:12:29 +08:00
2d186e1e76 chore(deps): bump nanoid from 3.3.7 to 3.3.8 in /web (#11876)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-12-20 13:46:54 +08:00
260 changed files with 4254 additions and 1209 deletions

View File

@@ -399,6 +399,7 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
WORKFLOW_MAX_EXECUTION_STEPS=500
WORKFLOW_MAX_EXECUTION_TIME=1200
WORKFLOW_CALL_MAX_DEPTH=5
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
MAX_VARIABLE_SIZE=204800
# App configuration

View File

@@ -555,7 +555,8 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
if language not in languages:
language = "en-US"
name = name.strip()
# Validates name encoding for non-Latin characters.
name = name.strip().encode("utf-8").decode("utf-8") if name else None
# generate random password
new_password = secrets.token_urlsafe(16)
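Note on the change above: the encode/decode round-trip is itself the validation. A minimal sketch (the helper name is hypothetical) of why it works: str.encode("utf-8") raises UnicodeEncodeError for strings that cannot be represented, such as lone surrogates, so malformed names fail fast instead of being persisted.

def _validate_name_encoding(name: str | None) -> str | None:
    # Encoding to UTF-8 raises UnicodeEncodeError for unpaired surrogates,
    # so the round-trip doubles as a cheap validity check.
    if not name:
        return None
    return name.strip().encode("utf-8").decode("utf-8")

_validate_name_encoding("李明")        # returns "李明"
_validate_name_encoding("bad\ud800")  # raises UnicodeEncodeError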

View File

@@ -433,6 +433,11 @@ class WorkflowConfig(BaseSettings):
default=5,
)
WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field(
description="Maximum allowed depth for nested parallel executions",
default=3,
)
MAX_VARIABLE_SIZE: PositiveInt = Field(
description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
default=200 * 1024,
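A sketch of how the new limit is resolved, assuming standard pydantic-settings behavior: an environment variable whose name matches the field overrides the default of 3, which is what the WORKFLOW_PARALLEL_DEPTH_LIMIT entry in the .env diff above relies on.

import os

from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings

os.environ["WORKFLOW_PARALLEL_DEPTH_LIMIT"] = "5"  # e.g. supplied via .env

class WorkflowConfig(BaseSettings):
    WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field(
        description="Maximum allowed depth for nested parallel executions",
        default=3,
    )

print(WorkflowConfig().WORKFLOW_PARALLEL_DEPTH_LIMIT)  # 5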

View File

@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.14.1",
default="0.14.2",
)
COMMIT_SHA: str = Field(

View File

@@ -4,3 +4,8 @@ from werkzeug.exceptions import HTTPException
class FilenameNotExistsError(HTTPException):
code = 400
description = "The specified filename does not exist."
class RemoteFileUploadError(HTTPException):
code = 400
description = "Error uploading remote file."
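Because the new error subclasses werkzeug's HTTPException, raising it from a Flask view yields a 400 response without extra wiring; a minimal self-contained sketch:

from flask import Flask
from werkzeug.exceptions import HTTPException

class RemoteFileUploadError(HTTPException):
    code = 400
    description = "Error uploading remote file."

app = Flask(__name__)

@app.route("/remote-files")
def upload():
    # A positional argument overrides the class-level description.
    raise RemoteFileUploadError("Failed to fetch file from https://example.com")

print(app.test_client().get("/remote-files").status_code)  # 400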

View File

@@ -6,6 +6,7 @@ from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from configs import dify_config
from controllers.console import api
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.wraps import get_app_model
@@ -426,7 +427,21 @@ class ConvertToWorkflowApi(Resource):
}
class WorkflowConfigApi(Resource):
"""Resource for workflow configuration."""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App):
return {
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
}
api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
api.add_resource(WorkflowConfigApi, "/apps/<uuid:app_id>/workflows/draft/config")
api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
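A hypothetical call against the new draft-config route (host, app id, and token are placeholders; the console API expects an authenticated session):

import requests

app_id = "your-app-uuid"     # placeholder
token = "your-access-token"  # placeholder

resp = requests.get(
    f"http://localhost:5001/console/api/apps/{app_id}/workflows/draft/config",
    headers={"Authorization": f"Bearer {token}"},
)
print(resp.json())  # e.g. {"parallel_depth_limit": 3}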

View File

@@ -5,8 +5,7 @@ from typing import Optional, Union
from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from models import App
from models.model import AppMode
from models import App, AppMode
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):

View File

@@ -76,7 +76,7 @@ class OAuthCallback(Resource):
try:
token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token)
except requests.exceptions.HTTPError as e:
except requests.exceptions.RequestException as e:
logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}")
return {"error": "OAuth process failed"}, 400
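The widened except clause works because RequestException is the root of the requests exception hierarchy, so timeouts and connection failures are now handled alongside HTTP errors:

import requests

for exc in (
    requests.exceptions.HTTPError,
    requests.exceptions.ConnectionError,
    requests.exceptions.Timeout,
):
    assert issubclass(exc, requests.exceptions.RequestException)

One caveat: for non-HTTP failures e.response is None, so the e.response.text in the log call above can itself raise.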

View File

@@ -1,12 +1,14 @@
from flask_login import current_user
from flask_restful import marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import AppMode
@@ -34,14 +36,16 @@ class ConversationListApi(InstalledAppResource):
pinned = True if args["pinned"] == "true" else False
try:
return WebConversationService.pagination_by_last_id(
app_model=app_model,
user=current_user,
last_id=args["last_id"],
limit=args["limit"],
invoke_from=InvokeFrom.EXPLORE,
pinned=pinned,
)
with Session(db.engine) as session:
return WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=current_user,
last_id=args["last_id"],
limit=args["limit"],
invoke_from=InvokeFrom.EXPLORE,
pinned=pinned,
)
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
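The recurring pattern in this batch of controller changes: the view opens a short-lived Session bound to the shared engine and passes it down, so the service no longer reaches for a global session. A generic, runnable sketch (the model and engine are stand-ins, not Dify code):

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Conversation(Base):
    __tablename__ = "conversations"
    id: Mapped[int] = mapped_column(primary_key=True)

engine = create_engine("sqlite://")  # stand-in for db.engine
Base.metadata.create_all(engine)

def pagination_by_last_id(*, session: Session, limit: int) -> list[Conversation]:
    # The service uses the caller's session, making transaction scope explicit.
    return list(session.scalars(select(Conversation).limit(limit)))

with Session(engine) as session:  # closed automatically, even on error
    page = pagination_by_last_id(session=session, limit=20)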

View File

@@ -70,7 +70,7 @@ class MessageFeedbackApi(InstalledAppResource):
args = parser.parse_args()
try:
MessageService.create_feedback(app_model, message_id, current_user, args["rating"])
MessageService.create_feedback(app_model, message_id, current_user, args["rating"], args["content"])
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")

View File

@@ -7,6 +7,7 @@ from flask_restful import Resource, marshal_with, reqparse
import services
from controllers.common import helpers
from controllers.common.errors import RemoteFileUploadError
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
@@ -43,10 +44,14 @@ class RemoteFileUploadApi(Resource):
url = args["url"]
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
resp.raise_for_status()
try:
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as e:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
file_info = helpers.guess_file_info_from_response(resp)
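The probe logic restated standalone, using plain httpx in place of Dify's ssrf_proxy wrapper (the function name is illustrative): try a cheap HEAD first, fall back to GET, and convert transport failures into one domain error.

import httpx

def probe_remote_file(url: str) -> httpx.Response:
    # HEAD is cheap, but some servers reject it, hence the GET fallback.
    try:
        resp = httpx.head(url=url)
        if resp.status_code != httpx.codes.OK:
            resp = httpx.get(url=url, timeout=3, follow_redirects=True)
        if resp.status_code != httpx.codes.OK:
            raise RuntimeError(f"Failed to fetch file from {url}: {resp.text}")
        return resp
    except httpx.RequestError as e:
        # DNS failures, timeouts, refused connections, and similar.
        raise RuntimeError(f"Failed to fetch file from {url}: {e}") from e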

View File

@@ -3,12 +3,14 @@ import io
from flask import send_file
from flask_login import current_user
from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from libs.helper import alphanumeric, uuid_value
from libs.login import login_required
from services.tools.api_tools_manage_service import ApiToolManageService
@@ -91,12 +93,16 @@ class ToolBuiltinProviderUpdateApi(Resource):
args = parser.parse_args()
return BuiltinToolManageService.update_builtin_tool_provider(
user_id,
tenant_id,
provider,
args["credentials"],
)
with Session(db.engine) as session:
result = BuiltinToolManageService.update_builtin_tool_provider(
session=session,
user_id=user_id,
tenant_id=tenant_id,
provider_name=provider,
credentials=args["credentials"],
)
session.commit()
return result
class ToolBuiltinProviderGetCredentialsApi(Resource):
@@ -104,13 +110,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return BuiltinToolManageService.get_builtin_tool_provider_credentials(
user_id,
tenant_id,
provider,
tenant_id=tenant_id,
provider_name=provider,
)

View File

@@ -1,5 +1,6 @@
from flask_restful import Resource, marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
import services
@@ -7,6 +8,7 @@ from controllers.service_api import api
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
conversation_delete_fields,
conversation_infinite_scroll_pagination_fields,
@@ -39,14 +41,16 @@ class ConversationApi(Resource):
args = parser.parse_args()
try:
return ConversationService.pagination_by_last_id(
app_model=app_model,
user=end_user,
last_id=args["last_id"],
limit=args["limit"],
invoke_from=InvokeFrom.SERVICE_API,
sort_by=args["sort_by"],
)
with Session(db.engine) as session:
return ConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=end_user,
last_id=args["last_id"],
limit=args["limit"],
invoke_from=InvokeFrom.SERVICE_API,
sort_by=args["sort_by"],
)
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")

View File

@@ -104,10 +104,11 @@
parser = reqparse.RequestParser()
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
parser.add_argument("content", type=str, location="json")
args = parser.parse_args()
try:
MessageService.create_feedback(app_model, message_id, end_user, args["rating"])
MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"])
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")

View File

@@ -1,11 +1,13 @@
from flask_restful import marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.web import api
from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import AppMode
@@ -40,15 +42,17 @@ class ConversationListApi(WebApiResource):
pinned = True if args["pinned"] == "true" else False
try:
return WebConversationService.pagination_by_last_id(
app_model=app_model,
user=end_user,
last_id=args["last_id"],
limit=args["limit"],
invoke_from=InvokeFrom.WEB_APP,
pinned=pinned,
sort_by=args["sort_by"],
)
with Session(db.engine) as session:
return WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=end_user,
last_id=args["last_id"],
limit=args["limit"],
invoke_from=InvokeFrom.WEB_APP,
pinned=pinned,
sort_by=args["sort_by"],
)
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")

View File

@@ -108,7 +108,7 @@ class MessageFeedbackApi(WebApiResource):
args = parser.parse_args()
try:
MessageService.create_feedback(app_model, message_id, end_user, args["rating"])
MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"])
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")

View File

@@ -5,6 +5,7 @@ from flask_restful import marshal_with, reqparse
import services
from controllers.common import helpers
from controllers.common.errors import RemoteFileUploadError
from controllers.web.wraps import WebApiResource
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
@@ -38,10 +39,14 @@ class RemoteFileUploadApi(WebApiResource):
url = args["url"]
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3)
resp.raise_for_status()
try:
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as e:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
file_info = helpers.guess_file_info_from_response(resp)

View File

@@ -4,14 +4,17 @@ import logging
import queue
import re
import threading
from collections.abc import Iterable
from core.app.entities.queue_entities import (
MessageQueueMessage,
QueueAgentMessageEvent,
QueueLLMChunkEvent,
QueueNodeSucceededEvent,
QueueTextChunkEvent,
WorkflowQueueMessage,
)
from core.model_manager import ModelManager
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
@@ -21,7 +24,7 @@ class AudioTrunk:
self.status = status
def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str):
if not text_content or text_content.isspace():
return
return model_instance.invoke_tts(
@@ -29,13 +32,19 @@ def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str):
)
def _process_future(future_queue, audio_queue):
def _process_future(
future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None],
audio_queue: queue.Queue[AudioTrunk],
):
while True:
try:
future = future_queue.get()
if future is None:
break
for audio in future.result():
invoke_result = future.result()
if not invoke_result:
continue
for audio in invoke_result:
audio_base64 = base64.b64encode(bytes(audio))
audio_queue.put(AudioTrunk("responding", audio=audio_base64))
except Exception as e:
@@ -49,8 +58,8 @@ class AppGeneratorTTSPublisher:
self.logger = logging.getLogger(__name__)
self.tenant_id = tenant_id
self.msg_text = ""
self._audio_queue = queue.Queue()
self._msg_queue = queue.Queue()
self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue()
self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
self.match = re.compile(r"[。.!?]")
self.model_manager = ModelManager()
self.model_instance = self.model_manager.get_default_model_instance(
@@ -66,14 +75,11 @@ class AppGeneratorTTSPublisher:
self._runtime_thread = threading.Thread(target=self._runtime).start()
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
def publish(self, message):
try:
self._msg_queue.put(message)
except Exception as e:
self.logger.warning(e)
def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):
self._msg_queue.put(message)
def _runtime(self):
future_queue = queue.Queue()
future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None] = queue.Queue()
threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
while True:
try:
@@ -110,7 +116,7 @@ class AppGeneratorTTSPublisher:
break
future_queue.put(None)
def check_and_get_audio(self) -> AudioTrunk | None:
def check_and_get_audio(self):
try:
if self._last_audio_event and self._last_audio_event.status == "finish":
if self.executor:
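Both queues above follow the same shutdown convention: the producer pushes None, and the consumer thread treats it as a stop signal. A minimal sketch of that sentinel pattern:

import queue
import threading

q: queue.Queue[bytes | None] = queue.Queue()

def consumer() -> None:
    while True:
        item = q.get()
        if item is None:  # sentinel: producer is finished
            break
        print(f"received {len(item)} bytes")

t = threading.Thread(target=consumer)
t.start()
q.put(b"audio-chunk")
q.put(None)  # signal shutdown
t.join()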

View File

@@ -180,7 +180,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
else:
continue
raise Exception("Queue listening stopped unexpectedly.")
raise ValueError("queue listening stopped unexpectedly.")
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
@@ -197,11 +197,11 @@
stream_response=stream_response,
)
def _listen_audio_msg(self, publisher, task_id: str):
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
if not publisher:
return None
audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish":
audio_msg = publisher.check_and_get_audio()
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
@@ -222,7 +222,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
@@ -291,9 +291,27 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
yield self._workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(
event,
QueueNodeRetryEvent,
):
if not workflow_run:
raise ValueError("workflow run not initialized.")
workflow_node_execution = self._handle_workflow_node_execution_retried(
workflow_run=workflow_run, event=event
)
response = self._workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(event, QueueNodeStartedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise ValueError("workflow run not initialized.")
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
@@ -331,63 +349,48 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if response:
yield response
elif isinstance(
event,
QueueNodeRetryEvent,
):
workflow_node_execution = self._handle_workflow_node_execution_retried(
workflow_run=workflow_run, event=event
)
response = self._workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise ValueError("workflow run not initialized.")
yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise ValueError("workflow run not initialized.")
yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise ValueError("workflow run not initialized.")
yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationNextEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise ValueError("workflow run not initialized.")
yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise ValueError("workflow run not initialized.")
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
raise ValueError("workflow run not initialized.")
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
@@ -406,10 +409,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
raise ValueError("graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_partial_success(
workflow_run=workflow_run,
@@ -429,10 +432,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
raise ValueError("graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
@@ -511,7 +514,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# only publish tts message at text chunk streaming
if tts_publisher:
tts_publisher.publish(message=queue_message)
tts_publisher.publish(queue_message)
self._task_state.answer += delta_text
yield self._message_to_stream_response(
@@ -522,7 +525,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
raise ValueError("graph runtime state not initialized.")
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:

View File

@@ -1,7 +1,6 @@
import queue
import time
from abc import abstractmethod
from collections.abc import Generator
from enum import Enum
from typing import Any
@@ -11,9 +10,11 @@ from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
MessageQueueMessage,
QueueErrorEvent,
QueuePingEvent,
QueueStopEvent,
WorkflowQueueMessage,
)
from extensions.ext_redis import redis_client
@@ -37,11 +38,11 @@ class AppQueueManager:
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
)
q = queue.Queue()
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
self._q = q
def listen(self) -> Generator:
def listen(self):
"""
Listen to queue
:return:

View File

@@ -155,7 +155,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
else:
continue
raise Exception("Queue listening stopped unexpectedly.")
raise ValueError("queue listening stopped unexpectedly.")
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
@@ -171,11 +171,11 @@
yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
def _listen_audio_msg(self, publisher, task_id: str):
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
if not publisher:
return None
audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish":
audio_msg = publisher.check_and_get_audio()
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
@@ -196,7 +196,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
@@ -218,7 +218,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
break
else:
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception as e:
except Exception:
logger.exception(f"Fails to get audio trunk, task_id: {task_id}")
break
if tts_publisher:
@@ -254,9 +254,27 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
yield self._workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(
event,
QueueNodeRetryEvent,
):
if not workflow_run:
raise ValueError("workflow run not initialized.")
workflow_node_execution = self._handle_workflow_node_execution_retried(
workflow_run=workflow_run, event=event
)
response = self._workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(event, QueueNodeStartedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise ValueError("workflow run not initialized.")
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
@@ -289,64 +307,48 @@
)
if node_failed_response:
yield node_failed_response
elif isinstance(
event,
QueueNodeRetryEvent,
):
workflow_node_execution = self._handle_workflow_node_execution_retried(
workflow_run=workflow_run, event=event
)
response = self._workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise ValueError("workflow run not initialized.")
yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise ValueError("workflow run not initialized.")
yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise ValueError("workflow run not initialized.")
yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationNextEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise ValueError("workflow run not initialized.")
yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise ValueError("workflow run not initialized.")
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
raise ValueError("graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
@@ -366,10 +368,10 @@
)
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
raise ValueError("graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_partial_success(
workflow_run=workflow_run,
@@ -390,10 +392,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
)
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
raise ValueError("graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
@@ -421,7 +423,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# only publish tts message at text chunk streaming
if tts_publisher:
tts_publisher.publish(message=queue_message)
tts_publisher.publish(queue_message)
self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(

View File

@@ -188,6 +188,41 @@ class WorkflowBasedAppRunner(AppRunner):
)
elif isinstance(event, GraphRunFailedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, NodeRunRetryEvent):
node_run_result = event.route_node_state.node_run_result
if node_run_result:
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
else:
inputs = {}
process_data = {}
outputs = {}
execution_metadata = {}
self._publish_event(
QueueNodeRetryEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
parallel_mode_run_id=event.parallel_mode_run_id,
inputs=inputs,
process_data=process_data,
outputs=outputs,
error=event.error,
execution_metadata=execution_metadata,
retry_index=event.retry_index,
)
)
elif isinstance(event, NodeRunStartedEvent):
self._publish_event(
QueueNodeStartedEvent(
@@ -207,6 +242,17 @@ class WorkflowBasedAppRunner(AppRunner):
)
)
elif isinstance(event, NodeRunSucceededEvent):
node_run_result = event.route_node_state.node_run_result
if node_run_result:
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
else:
inputs = {}
process_data = {}
outputs = {}
execution_metadata = {}
self._publish_event(
QueueNodeSucceededEvent(
node_execution_id=event.id,
@@ -218,18 +264,10 @@
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
inputs=inputs,
process_data=process_data,
outputs=outputs,
execution_metadata=execution_metadata,
in_iteration_id=event.in_iteration_id,
)
)
@@ -422,36 +460,6 @@
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
)
)
elif isinstance(event, NodeRunRetryEvent):
self._publish_event(
QueueNodeRetryEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.error,
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
retry_index=event.retry_index,
start_index=event.start_index,
)
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""

View File

@@ -1,3 +1,4 @@
from collections.abc import Mapping
from datetime import datetime
from enum import Enum, StrEnum
from typing import Any, Optional
@@ -85,9 +86,9 @@ class QueueIterationStartEvent(AppQueueEvent):
start_at: datetime
node_run_index: int
inputs: Optional[dict[str, Any]] = None
inputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
metadata: Optional[dict[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
class QueueIterationNextEvent(AppQueueEvent):
@@ -139,9 +140,9 @@ class QueueIterationCompletedEvent(AppQueueEvent):
start_at: datetime
node_run_index: int
inputs: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
error: Optional[str] = None
@@ -304,9 +305,9 @@ class QueueNodeSucceededEvent(AppQueueEvent):
"""iteration id if node is in iteration"""
start_at: datetime
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
error: Optional[str] = None
@@ -314,35 +315,18 @@
iteration_duration_map: Optional[dict[str, float]] = None
class QueueNodeRetryEvent(AppQueueEvent):
class QueueNodeRetryEvent(QueueNodeStartedEvent):
"""QueueNodeRetryEvent entity"""
event: QueueEvent = QueueEvent.RETRY
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
error: str
retry_index: int # retry index
start_index: int # start index
class QueueNodeInIterationFailedEvent(AppQueueEvent):
@@ -368,10 +352,10 @@
"""iteration id if node is in iteration"""
start_at: datetime
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
error: str
@@ -399,10 +383,10 @@ class QueueNodeExceptionEvent(AppQueueEvent):
"""iteration id if node is in iteration"""
start_at: datetime
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
error: str
@@ -430,10 +414,10 @@ class QueueNodeFailedEvent(AppQueueEvent):
"""iteration id if node is in iteration"""
start_at: datetime
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
error: str
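The dict-to-Mapping switch across these event models signals intent: a Mapping field accepts read-only views, and type checkers reject mutation through it. A small sketch:

from collections.abc import Mapping
from typing import Any, Optional

from pydantic import BaseModel

class NodeEvent(BaseModel):
    inputs: Optional[Mapping[str, Any]] = None

evt = NodeEvent(inputs={"query": "hi"})
# evt.inputs["query"] = "bye"  # flagged by type checkers: Mapping is read-only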

View File

@@ -201,11 +201,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
stream_response=stream_response,
)
def _listen_audio_msg(self, publisher, task_id: str):
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
if publisher is None:
return None
audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish":
audio_msg = publisher.check_and_get_audio()
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
# audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None

View File

@@ -438,6 +438,16 @@ class WorkflowCycleManage:
elapsed_time = (finished_at - created_at).total_seconds()
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = WorkflowEntry.handle_special_values(event.outputs)
origin_metadata = {
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
}
merged_metadata = (
{**jsonable_encoder(event.execution_metadata), **origin_metadata}
if event.execution_metadata is not None
else origin_metadata
)
execution_metadata = json.dumps(merged_metadata)
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.tenant_id = workflow_run.tenant_id
@@ -445,6 +455,7 @@ class WorkflowCycleManage:
workflow_node_execution.workflow_id = workflow_run.workflow_id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
workflow_node_execution.workflow_run_id = workflow_run.id
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
workflow_node_execution.node_execution_id = event.node_execution_id
workflow_node_execution.node_id = event.node_id
workflow_node_execution.node_type = event.node_type.value
@@ -458,12 +469,8 @@ class WorkflowCycleManage:
workflow_node_execution.error = event.error
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = json.dumps(
{
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
}
)
workflow_node_execution.index = event.start_index
workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution.index = event.node_run_index
db.session.add(workflow_node_execution)
db.session.commit()
@@ -505,6 +512,12 @@ class WorkflowCycleManage:
:param workflow_run: workflow run
:return:
"""
# Attach WorkflowRun to an active session so "created_by_role" can be accessed.
workflow_run = db.session.merge(workflow_run)
# Refresh to ensure any expired attributes are fully loaded
db.session.refresh(workflow_run)
created_by = None
if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value:
created_by_account = workflow_run.created_by_account
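session.merge re-attaches a detached instance and refresh reloads expired attributes, which is what lets created_by_role be read safely here. A generic SQLAlchemy sketch of the same sequence (the model is illustrative, not Dify code):

from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Run(Base):
    __tablename__ = "runs"
    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as s:
    s.add(Run(id=1, status="running"))
    s.commit()
    detached = s.get(Run, 1)
# `detached` now lives outside any session; expired attributes would raise on access.

with Session(engine) as s:
    run = s.merge(detached)  # attach the object to the active session
    s.refresh(run)           # force-load current values from the database
    print(run.status)        # "running"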

View File

@@ -1,7 +1,7 @@
from typing import Optional
class LLMError(Exception):
class LLMError(ValueError):
"""Base class for all LLM exceptions."""
description: Optional[str] = None
@@ -16,7 +16,7 @@ class LLMBadRequestError(LLMError):
description = "Bad Request"
class ProviderTokenNotInitError(Exception):
class ProviderTokenNotInitError(ValueError):
"""
Custom exception raised when the provider token is not initialized.
"""
@@ -27,7 +27,7 @@ class ProviderTokenNotInitError(Exception):
self.description = args[0] if args else self.description
class QuotaExceededError(Exception):
class QuotaExceededError(ValueError):
"""
Custom exception raised when the quota for a provider has been exceeded.
"""
@@ -35,7 +35,7 @@ class QuotaExceededError(Exception):
description = "Quota Exceeded"
class AppInvokeQuotaExceededError(Exception):
class AppInvokeQuotaExceededError(ValueError):
"""
Custom exception raised when the quota for an app has been exceeded.
"""
@@ -43,7 +43,7 @@ class AppInvokeQuotaExceededError(Exception):
description = "App Invoke Quota Exceeded"
class ModelCurrentlyNotSupportError(Exception):
class ModelCurrentlyNotSupportError(ValueError):
"""
Custom exception raised when the model not support
"""
@@ -51,7 +51,7 @@ class ModelCurrentlyNotSupportError(Exception):
description = "Model Currently Not Support"
class InvokeRateLimitError(Exception):
class InvokeRateLimitError(ValueError):
"""Raised when the Invoke returns rate limit error."""
description = "Rate Limit Error"
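Rebasing these exceptions onto ValueError means call sites that already guard with except ValueError now catch provider failures too, without importing each concrete class; a minimal sketch:

class QuotaExceededError(ValueError):
    description = "Quota Exceeded"

try:
    raise QuotaExceededError("monthly quota used up")
except ValueError as e:  # one handler covers validation and provider errors
    print(f"handled: {e}")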

View File

@@ -1,15 +1,14 @@
import base64
from configs import dify_config
from core.file import file_repository
from core.helper import ssrf_proxy
from core.model_runtime.entities import (
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
MultiModalPromptMessageContent,
VideoPromptMessageContent,
)
from extensions.ext_database import db
from extensions.ext_storage import storage
from . import helpers
@@ -41,7 +40,7 @@
/,
*,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
):
) -> MultiModalPromptMessageContent:
if f.extension is None:
raise ValueError("Missing file extension")
if f.mime_type is None:
@@ -70,16 +69,13 @@ def download(f: File, /):
def download(f: File, /):
if f.transfer_method == FileTransferMethod.TOOL_FILE:
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
return _download_file_content(tool_file.file_key)
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
return _download_file_content(upload_file.key)
# remote file
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status()
return response.content
if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE):
return _download_file_content(f._storage_key)
elif f.transfer_method == FileTransferMethod.REMOTE_URL:
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status()
return response.content
raise ValueError(f"unsupported transfer method: {f.transfer_method}")
def _download_file_content(path: str, /):
@@ -110,11 +106,9 @@ def _get_encoded_string(f: File, /):
response.raise_for_status()
data = response.content
case FileTransferMethod.LOCAL_FILE:
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
data = _download_file_content(upload_file.key)
data = _download_file_content(f._storage_key)
case FileTransferMethod.TOOL_FILE:
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
data = _download_file_content(tool_file.file_key)
data = _download_file_content(f._storage_key)
encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string

View File

@@ -1,32 +0,0 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from models import ToolFile, UploadFile
from .models import File
def get_upload_file(*, session: Session, file: File):
if file.related_id is None:
raise ValueError("Missing file related_id")
stmt = select(UploadFile).filter(
UploadFile.id == file.related_id,
UploadFile.tenant_id == file.tenant_id,
)
record = session.scalar(stmt)
if not record:
raise ValueError(f"upload file {file.related_id} not found")
return record
def get_tool_file(*, session: Session, file: File):
if file.related_id is None:
raise ValueError("Missing file related_id")
stmt = select(ToolFile).filter(
ToolFile.id == file.related_id,
ToolFile.tenant_id == file.tenant_id,
)
record = session.scalar(stmt)
if not record:
raise ValueError(f"tool file {file.related_id} not found")
return record

View File

@@ -47,6 +47,38 @@ class File(BaseModel):
mime_type: Optional[str] = None
size: int = -1
# Those properties are private, should not be exposed to the outside.
_storage_key: str
def __init__(
self,
*,
id: Optional[str] = None,
tenant_id: str,
type: FileType,
transfer_method: FileTransferMethod,
remote_url: Optional[str] = None,
related_id: Optional[str] = None,
filename: Optional[str] = None,
extension: Optional[str] = None,
mime_type: Optional[str] = None,
size: int = -1,
storage_key: str,
):
super().__init__(
id=id,
tenant_id=tenant_id,
type=type,
transfer_method=transfer_method,
remote_url=remote_url,
related_id=related_id,
filename=filename,
extension=extension,
mime_type=mime_type,
size=size,
)
self._storage_key = storage_key
def to_dict(self) -> Mapping[str, str | int | None]:
data = self.model_dump(mode="json")
return {
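A hypothetical construction showing the new keyword-only storage_key (all values illustrative; FileType and FileTransferMethod as used elsewhere in core.file):

f = File(
    tenant_id="tenant-1",
    type=FileType.IMAGE,
    transfer_method=FileTransferMethod.LOCAL_FILE,
    related_id="upload-1",
    filename="cat.png",
    extension=".png",
    mime_type="image/png",
    storage_key="upload_files/tenant-1/cat.png",
)
# f._storage_key -> "upload_files/tenant-1/cat.png", which file_manager now
# reads directly instead of looking the key up in the database per call.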

View File

@@ -118,7 +118,7 @@ class CodeExecutor:
return response.data.stdout or ""
@classmethod
def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]) -> dict:
def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]):
"""
Execute code
:param language: code language

View File

@ -25,7 +25,7 @@ class TemplateTransformer(ABC):
return runner_script, preload_script
@classmethod
def extract_result_str_from_response(cls, response: str) -> str:
def extract_result_str_from_response(cls, response: str):
result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL)
if not result:
raise ValueError("Failed to parse result")
@ -33,13 +33,21 @@ class TemplateTransformer(ABC):
return result
@classmethod
def transform_response(cls, response: str) -> dict:
def transform_response(cls, response: str) -> Mapping[str, Any]:
"""
Transform response to dict
:param response: response
:return:
"""
return json.loads(cls.extract_result_str_from_response(response))
try:
result = json.loads(cls.extract_result_str_from_response(response))
except json.JSONDecodeError:
raise ValueError("failed to parse response")
if not isinstance(result, dict):
raise ValueError("result must be a dict")
if not all(isinstance(k, str) for k in result):
raise ValueError("result keys must be strings")
return result
@classmethod
@abstractmethod

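The tightened `transform_response` now guarantees a str-keyed dict or raises `ValueError`. A standalone sketch of the new contract; note that `json.loads` already yields `str` keys for JSON objects, so the key check in the hunk above is a guard against non-conforming producers:

import json

def parse_result(payload: str) -> dict:
    try:
        result = json.loads(payload)
    except json.JSONDecodeError:
        raise ValueError("failed to parse response")
    if not isinstance(result, dict):
        raise ValueError("result must be a dict")
    return result

parse_result('{"answer": 42}')  # {'answer': 42}
# parse_result('[1, 2, 3]')     # ValueError: result must be a dict
# parse_result('not json')      # ValueError: failed to parse response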
View File

@ -1,6 +1,5 @@
import base64
from extensions.ext_database import db
from libs import rsa
@ -14,6 +13,7 @@ def obfuscated_token(token: str):
def encrypt_token(tenant_id: str, token: str):
from models.account import Tenant
from models.engine import db
if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
raise ValueError(f"Tenant with id {tenant_id} not found")

View File

@ -24,7 +24,7 @@ BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]
class MaxRetriesExceededError(Exception):
class MaxRetriesExceededError(ValueError):
"""Raised when the maximum number of retries is exceeded."""
pass
@ -45,6 +45,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
)
retries = 0
stream = kwargs.pop("stream", False)
while retries <= max_retries:
try:
if dify_config.SSRF_PROXY_ALL_URL:
@ -64,11 +65,12 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
except httpx.RequestError as e:
logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}")
if max_retries == 0:
raise
retries += 1
if retries <= max_retries:
time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))
raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")
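With `BACKOFF_FACTOR = 0.5` from the top of this file, the sleep before retry k is `0.5 * 2**(k - 1)`, while `max_retries == 0` now re-raises the original error immediately instead of wrapping it. The implied wait schedule:

BACKOFF_FACTOR = 0.5

def backoff_schedule(max_retries: int) -> list[float]:
    # seconds slept before retry attempts 1..max_retries
    return [BACKOFF_FACTOR * (2 ** (k - 1)) for k in range(1, max_retries + 1)]

print(backoff_schedule(3))  # [0.5, 1.0, 2.0]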

View File

@ -1,2 +1,2 @@
class OutputParserError(Exception):
class OutputParserError(ValueError):
pass

View File

@ -4,6 +4,7 @@ from .message_entities import (
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
MultiModalPromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
@ -27,6 +28,7 @@ __all__ = [
"LLMResultChunkDelta",
"LLMUsage",
"ModelPropertyKey",
"MultiModalPromptMessageContent",
"PromptMessage",
"PromptMessage",
"PromptMessageContent",

View File

@ -84,10 +84,10 @@ class MultiModalPromptMessageContent(PromptMessageContent):
"""
type: PromptMessageContentType
format: str = Field(..., description="the format of multi-modal file")
base64_data: str = Field("", description="the base64 data of multi-modal file")
url: str = Field("", description="the url of multi-modal file")
mime_type: str = Field(..., description="the mime type of multi-modal file")
format: str = Field(default=..., description="the format of multi-modal file")
base64_data: str = Field(default="", description="the base64 data of multi-modal file")
url: str = Field(default="", description="the url of multi-modal file")
mime_type: str = Field(default=..., description="the mime type of multi-modal file")
@computed_field(return_type=str)
@property

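`Field(default=...)` with a literal `Ellipsis` is equivalent to the positional `Field(...)`: the field remains required, and only the spelling changes. A minimal check, assuming pydantic v2:

from pydantic import BaseModel, Field, ValidationError

class Sketch(BaseModel):
    format: str = Field(default=..., description="still required")
    url: str = Field(default="", description="genuinely optional")

try:
    Sketch()
except ValidationError as e:
    print(e.error_count())  # 1 -- only 'format' is missing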
View File

@ -1,7 +1,7 @@
from typing import Optional
class InvokeError(Exception):
class InvokeError(ValueError):
"""Base class for all LLM exceptions."""
description: Optional[str] = None

View File

@ -1,4 +1,4 @@
class CredentialsValidateFailedError(Exception):
class CredentialsValidateFailedError(ValueError):
"""
Credentials validate failed error
"""

View File

@ -531,7 +531,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
"source": {
"type": "base64",
"media_type": message_content.mime_type,
"data": message_content.data,
"data": message_content.base64_data,
},
}
sub_messages.append(sub_message_dict)

View File

@ -21,6 +21,7 @@ from core.model_runtime.entities.message_entities import (
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
@ -143,7 +144,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
"""
try:
ping_message = SystemPromptMessage(content="ping")
ping_message = UserPromptMessage(content="ping")
self._generate(model, credentials, [ping_message], {"max_output_tokens": 5})
except Exception as ex:
@ -187,17 +188,23 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
config_kwargs["stop_sequences"] = stop
genai.configure(api_key=credentials["google_api_key"])
google_model = genai.GenerativeModel(model_name=model)
history = []
system_instruction = None
for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]:
history[-1]["parts"].extend(content["parts"])
elif content["role"] == "system":
system_instruction = content["parts"][0]
else:
history.append(content)
if not history:
raise InvokeError("The user prompt message is required. You only add a system prompt message.")
google_model = genai.GenerativeModel(model_name=model, system_instruction=system_instruction)
response = google_model.generate_content(
contents=history,
generation_config=genai.types.GenerationConfig(**config_kwargs),
@ -404,7 +411,10 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
)
return glm_content
elif isinstance(message, SystemPromptMessage):
return {"role": "user", "parts": [to_part(message.content)]}
if isinstance(message.content, list):
text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), message.content)
message.content = "".join(c.data for c in text_contents)
return {"role": "system", "parts": [to_part(message.content)]}
elif isinstance(message, ToolPromptMessage):
return {
"role": "function",

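The new loop does two things at once: it merges consecutive same-role messages, since Gemini requires strictly alternating roles, and it lifts the system prompt out of the history into `system_instruction`. A standalone sketch with plain dicts in place of the glm content objects:

def merge_history(contents: list[dict]) -> tuple[list[dict], str | None]:
    history: list[dict] = []
    system_instruction = None
    for content in contents:
        if history and history[-1]["role"] == content["role"]:
            history[-1]["parts"].extend(content["parts"])  # keep roles alternating
        elif content["role"] == "system":
            system_instruction = content["parts"][0]       # moved out of the chat history
        else:
            history.append(content)
    return history, system_instruction

history, system = merge_history([
    {"role": "system", "parts": ["be terse"]},
    {"role": "user", "parts": ["hi"]},
    {"role": "user", "parts": ["there"]},
])
# history == [{'role': 'user', 'parts': ['hi', 'there']}], system == 'be terse'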
View File

@ -421,7 +421,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
# text completion model
response = client.completions.create(
prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs
prompt=prompt_messages[0].content,
model=model,
stream=stream,
**model_parameters,
**extra_model_kwargs,
)
if stream:
@ -593,6 +597,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
else:
model_parameters["response_format"] = {"type": response_format}
elif "json_schema" in model_parameters:
del model_parameters["json_schema"]
extra_model_kwargs = {}

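The added `elif` keeps a stale `json_schema`, left over from a previously enabled structured-output setting, from leaking into the request once the response format is disabled. The cleanup in isolation:

model_parameters = {"temperature": 0.2, "json_schema": {"name": "stale"}}
response_format = None  # structured output switched off for this request

if response_format:
    pass  # build {"type": ...} as in the hunk above
elif "json_schema" in model_parameters:
    del model_parameters["json_schema"]

assert "json_schema" not in model_parameters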
View File

@ -4,11 +4,10 @@ import json
import logging
import time
from collections.abc import Generator
from typing import Optional, Union, cast
from typing import TYPE_CHECKING, Optional, Union, cast
import google.auth.transport.requests
import requests
import vertexai.generative_models as glm
from anthropic import AnthropicVertex, Stream
from anthropic.types import (
ContentBlockDeltaEvent,
@ -19,8 +18,6 @@ from anthropic.types import (
MessageStreamEvent,
)
from google.api_core import exceptions
from google.cloud import aiplatform
from google.oauth2 import service_account
from PIL import Image
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@ -47,6 +44,9 @@ from core.model_runtime.errors.invoke import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
if TYPE_CHECKING:
import vertexai.generative_models as glm
logger = logging.getLogger(__name__)
@ -102,6 +102,8 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
:param stream: is stream response
:return: full response or stream response chunk generator result
"""
from google.oauth2 import service_account
# use Anthropic official SDK references
# - https://github.com/anthropics/anthropic-sdk-python
service_account_key = credentials.get("vertex_service_account_key", "")
@ -406,13 +408,15 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
return text.rstrip()
def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> "glm.Tool":
"""
Convert tool messages to glm tools
:param tools: tool messages
:return: glm tools
"""
import vertexai.generative_models as glm
return glm.Tool(
function_declarations=[
glm.FunctionDeclaration(
@ -473,6 +477,10 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
:param user: unique user id
:return: full response or stream response chunk generator result
"""
import vertexai.generative_models as glm
from google.cloud import aiplatform
from google.oauth2 import service_account
config_kwargs = model_parameters.copy()
config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
@ -522,7 +530,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(
self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
Handle llm response
@ -554,7 +562,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
return result
def _handle_generate_stream_response(
self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage]
) -> Generator:
"""
Handle llm stream response
@ -638,13 +646,15 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
return message_text
def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content:
def _format_message_to_glm_content(self, message: PromptMessage) -> "glm.Content":
"""
Format a single message into glm.Content for Google API
:param message: one PromptMessage
:return: glm Content representation of message
"""
import vertexai.generative_models as glm
if isinstance(message, UserPromptMessage):
glm_content = glm.Content(role="user", parts=[])

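The pattern running through this file: the heavy SDK imports (`vertexai`, `google.cloud.aiplatform`, `google.oauth2`) move into the functions that use them, while `TYPE_CHECKING` plus string annotations keeps signatures checkable without paying the import cost at module load. The general shape:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import vertexai.generative_models as glm  # type checkers only; no runtime cost

def build_tool() -> "glm.Tool":
    import vertexai.generative_models as glm  # paid on first call, not at module import
    return glm.Tool(function_declarations=[])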
View File

@ -2,12 +2,9 @@ import base64
import json
import time
from decimal import Decimal
from typing import Optional
from typing import TYPE_CHECKING, Optional
import tiktoken
from google.cloud import aiplatform
from google.oauth2 import service_account
from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
@ -24,6 +21,11 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.vertex_ai._common import _CommonVertexAi
if TYPE_CHECKING:
from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
else:
VertexTextEmbeddingModel = None
class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
"""
@ -48,6 +50,10 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
:param input_type: input type
:return: embeddings result
"""
from google.cloud import aiplatform
from google.oauth2 import service_account
from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
service_account_key = credentials.get("vertex_service_account_key", "")
project_id = credentials["vertex_project_id"]
location = credentials["vertex_location"]
@ -100,6 +106,10 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
:param credentials: model credentials
:return:
"""
from google.cloud import aiplatform
from google.oauth2 import service_account
from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
try:
service_account_key = credentials.get("vertex_service_account_key", "")
project_id = credentials["vertex_project_id"]

View File

@ -355,7 +355,13 @@ class TraceTask:
def conversation_trace(self, **kwargs):
return kwargs
def workflow_trace(self, workflow_run: WorkflowRun, conversation_id, user_id):
def workflow_trace(self, workflow_run: WorkflowRun | None, conversation_id, user_id):
if not workflow_run:
raise ValueError("Workflow run not found")
db.session.merge(workflow_run)
db.session.refresh(workflow_run)
workflow_id = workflow_run.workflow_id
tenant_id = workflow_run.tenant_id
workflow_run_id = workflow_run.id

View File

@ -83,11 +83,15 @@ class DataPostProcessor:
if reranking_model:
try:
model_manager = ModelManager()
reranking_provider_name = reranking_model.get("reranking_provider_name")
reranking_model_name = reranking_model.get("reranking_model_name")
if not reranking_provider_name or not reranking_model_name:
return None
rerank_model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
provider=reranking_model["reranking_provider_name"],
provider=reranking_provider_name,
model_type=ModelType.RERANK,
model=reranking_model["reranking_model_name"],
model=reranking_model_name,
)
return rerank_model_instance
except InvokeAuthorizationError:

View File

@ -1,18 +1,19 @@
import re
from typing import Optional
import jieba
from jieba.analyse import default_tfidf
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
class JiebaKeywordTableHandler:
def __init__(self):
default_tfidf.stop_words = STOPWORDS
import jieba.analyse
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
jieba.analyse.default_tfidf.stop_words = STOPWORDS
def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
"""Extract keywords with JIEBA tfidf."""
import jieba
keywords = jieba.analyse.extract_tags(
sentence=text,
topK=max_keywords_per_chunk,
@ -22,6 +23,8 @@ class JiebaKeywordTableHandler:
def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
"""Get subtokens from a list of tokens., filtering for stopwords."""
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
results = set()
for token in tokens:
results.add(token)

View File

@ -103,7 +103,7 @@ class RetrievalService:
if exceptions:
exception_message = ";\n".join(exceptions)
raise Exception(exception_message)
raise ValueError(exception_message)
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
data_post_processor = DataPostProcessor(

View File

@ -6,10 +6,8 @@ from contextlib import contextmanager
from typing import Any
import jieba.posseg as pseg
import nltk
import numpy
import oracledb
from nltk.corpus import stopwords
from pydantic import BaseModel, model_validator
from configs import dify_config
@ -202,6 +200,10 @@ class OracleVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# lazy import
import nltk
from nltk.corpus import stopwords
top_k = kwargs.get("top_k", 5)
# just not implement fetch by score_threshold now, may be later
score_threshold = float(kwargs.get("score_threshold") or 0.0)

View File

@ -62,7 +62,7 @@ class AppToolProviderEntity(ToolProviderController):
user_input_form_list = app_model_config.user_input_form_list
for input_form in user_input_form_list:
# get type
form_type = input_form.keys()[0]
form_type = list(input_form.keys())[0]
default = input_form[form_type]["default"]
required = input_form[form_type]["required"]
label = input_form[form_type]["label"]

View File

@ -0,0 +1,115 @@
import json
import operator
from typing import Any, Optional, Union
import boto3
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class BedrockRetrieveTool(BuiltinTool):
bedrock_client: Any = None
knowledge_base_id: Optional[str] = None
topk: Optional[int] = None
def _bedrock_retrieve(
self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None
):
try:
retrieval_query = {"text": query_input}
retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}
# If a metadata filter is provided, add it to the retrieval configuration
if metadata_filter:
retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter
response = self.bedrock_client.retrieve(
knowledgeBaseId=knowledge_base_id,
retrievalQuery=retrieval_query,
retrievalConfiguration=retrieval_configuration,
)
results = []
for result in response.get("retrievalResults", []):
results.append(
{
"content": result.get("content", {}).get("text", ""),
"score": result.get("score", 0.0),
"metadata": result.get("metadata", {}),
}
)
return results
except Exception as e:
raise Exception(f"Error retrieving from knowledge base: {str(e)}")
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
"""
line = 0
try:
if not self.bedrock_client:
aws_region = tool_parameters.get("aws_region")
if aws_region:
self.bedrock_client = boto3.client("bedrock-agent-runtime", region_name=aws_region)
else:
self.bedrock_client = boto3.client("bedrock-agent-runtime")
line = 1
if not self.knowledge_base_id:
self.knowledge_base_id = tool_parameters.get("knowledge_base_id")
if not self.knowledge_base_id:
return self.create_text_message("Please provide knowledge_base_id")
line = 2
if not self.topk:
self.topk = tool_parameters.get("topk", 5)
line = 3
query = tool_parameters.get("query", "")
if not query:
return self.create_text_message("Please input query")
# Get the metadata filter (if provided)
metadata_filter_str = tool_parameters.get("metadata_filter")
metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None
line = 4
retrieved_docs = self._bedrock_retrieve(
query_input=query,
knowledge_base_id=self.knowledge_base_id,
num_results=self.topk,
metadata_filter=metadata_filter,  # pass the metadata filter through to the retrieve call
)
line = 5
# Sort results by score in descending order
sorted_docs = sorted(retrieved_docs, key=operator.itemgetter("score"), reverse=True)
line = 6
return [self.create_json_message(res) for res in sorted_docs]
except Exception as e:
return self.create_text_message(f"Exception {str(e)}, line : {line}")
def validate_parameters(self, parameters: dict[str, Any]) -> None:
"""
Validate the parameters
"""
if not parameters.get("knowledge_base_id"):
raise ValueError("knowledge_base_id is required")
if not parameters.get("query"):
raise ValueError("query is required")
# Optional: validate that the metadata filter, if provided, is a valid JSON object
metadata_filter_str = parameters.get("metadata_filter")
if metadata_filter_str:
    try:
        if not isinstance(json.loads(metadata_filter_str), dict):
            raise ValueError("metadata_filter must be a JSON object")
    except json.JSONDecodeError:
        raise ValueError("metadata_filter must be valid JSON")

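For reference, a value the `metadata_filter` parameter would accept; the tool only checks that it parses to a JSON object, while the operator grammar (`equals`, `greaterThan`, and so on) is defined by the Bedrock Retrieve API rather than by this tool:

# Illustrative filter string: json.loads()-ed above, then passed straight through to
# retrievalConfiguration["vectorSearchConfiguration"]["filter"].
metadata_filter_str = '{"greaterThan": {"key": "year", "value": 2020}}'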
View File

@ -0,0 +1,87 @@
identity:
name: bedrock_retrieve
author: AWS
label:
en_US: Bedrock Retrieve
zh_Hans: Bedrock检索
pt_BR: Bedrock Retrieve
icon: icon.svg
description:
human:
en_US: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool
zh_Hans: Amazon Bedrock知识库检索工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署说明
pt_BR: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base.
llm: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool
parameters:
- name: knowledge_base_id
type: string
required: true
label:
en_US: Bedrock Knowledge Base ID
zh_Hans: Bedrock知识库ID
pt_BR: Bedrock Knowledge Base ID
human_description:
en_US: ID of the Bedrock Knowledge Base to retrieve from
zh_Hans: 用于检索的Bedrock知识库ID
pt_BR: ID of the Bedrock Knowledge Base to retrieve from
llm_description: ID of the Bedrock Knowledge Base to retrieve from
form: form
- name: query
type: string
required: true
label:
en_US: Query string
zh_Hans: 查询语句
pt_BR: Query string
human_description:
en_US: The search query to retrieve relevant information
zh_Hans: 用于检索相关信息的查询语句
pt_BR: The search query to retrieve relevant information
llm_description: The search query to retrieve relevant information
form: llm
- name: topk
type: number
required: false
form: form
label:
en_US: Limit for results count
zh_Hans: 返回结果数量限制
pt_BR: Limit for results count
human_description:
en_US: Maximum number of results to return
zh_Hans: 最大返回结果数量
pt_BR: Maximum number of results to return
min: 1
max: 10
default: 5
- name: aws_region
type: string
required: false
label:
en_US: AWS Region
zh_Hans: AWS 区域
pt_BR: AWS Region
human_description:
en_US: AWS region where the Bedrock Knowledge Base is located
zh_Hans: Bedrock知识库所在的AWS区域
pt_BR: AWS region where the Bedrock Knowledge Base is located
llm_description: AWS region where the Bedrock Knowledge Base is located
form: form
- name: metadata_filter
type: string
required: false
label:
en_US: Metadata Filter
zh_Hans: 元数据过滤器
pt_BR: Metadata Filter
human_description:
en_US: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key": "aaa", "value": 10}})'
zh_Hans: '元数据的JSON格式过滤条件,例如:{"greaterThan": {"key": "aaa", "value": 10}}'
pt_BR: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key": "aaa", "value": 10}})'
form: form

View File

@ -0,0 +1,357 @@
import base64
import json
import logging
import re
from datetime import datetime
from typing import Any, Union
from urllib.parse import urlparse
import boto3
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool.builtin_tool import BuiltinTool
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class NovaCanvasTool(BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
Invoke AWS Bedrock Nova Canvas model for image generation
"""
# Get common parameters
prompt = tool_parameters.get("prompt", "")
image_output_s3uri = tool_parameters.get("image_output_s3uri", "").strip()
if not prompt:
return self.create_text_message("Please provide a text prompt for image generation.")
if not image_output_s3uri or urlparse(image_output_s3uri).scheme != "s3":
return self.create_text_message("Please provide an valid S3 URI for image output.")
task_type = tool_parameters.get("task_type", "TEXT_IMAGE")
aws_region = tool_parameters.get("aws_region", "us-east-1")
# Get common image generation config parameters
width = tool_parameters.get("width", 1024)
height = tool_parameters.get("height", 1024)
cfg_scale = tool_parameters.get("cfg_scale", 8.0)
negative_prompt = tool_parameters.get("negative_prompt", "")
seed = tool_parameters.get("seed", 0)
quality = tool_parameters.get("quality", "standard")
# Handle S3 image if provided
image_input_s3uri = tool_parameters.get("image_input_s3uri", "")
if task_type != "TEXT_IMAGE":
if not image_input_s3uri or urlparse(image_input_s3uri).scheme != "s3":
return self.create_text_message("Please provide a valid S3 URI for image to image generation.")
# Parse S3 URI
parsed_uri = urlparse(image_input_s3uri)
bucket = parsed_uri.netloc
key = parsed_uri.path.lstrip("/")
# Initialize S3 client and download image
s3_client = boto3.client("s3")
response = s3_client.get_object(Bucket=bucket, Key=key)
image_data = response["Body"].read()
# Base64 encode the image
input_image = base64.b64encode(image_data).decode("utf-8")
try:
# Initialize Bedrock client
bedrock = boto3.client(service_name="bedrock-runtime", region_name=aws_region)
# Base image generation config
image_generation_config = {
"width": width,
"height": height,
"cfgScale": cfg_scale,
"seed": seed,
"numberOfImages": 1,
"quality": quality,
}
# Prepare request body based on task type
body = {"imageGenerationConfig": image_generation_config}
if task_type == "TEXT_IMAGE":
body["taskType"] = "TEXT_IMAGE"
body["textToImageParams"] = {"text": prompt}
if negative_prompt:
body["textToImageParams"]["negativeText"] = negative_prompt
elif task_type == "COLOR_GUIDED_GENERATION":
colors = tool_parameters.get("colors", "#ff8080-#ffb280-#ffe680-#ffe680")
if not self._validate_color_string(colors):
return self.create_text_message("Please provide valid colors in hexadecimal format.")
body["taskType"] = "COLOR_GUIDED_GENERATION"
body["colorGuidedGenerationParams"] = {
"colors": colors.split("-"),
"referenceImage": input_image,
"text": prompt,
}
if negative_prompt:
body["colorGuidedGenerationParams"]["negativeText"] = negative_prompt
elif task_type == "IMAGE_VARIATION":
similarity_strength = tool_parameters.get("similarity_strength", 0.5)
body["taskType"] = "IMAGE_VARIATION"
body["imageVariationParams"] = {
"images": [input_image],
"similarityStrength": similarity_strength,
"text": prompt,
}
if negative_prompt:
body["imageVariationParams"]["negativeText"] = negative_prompt
elif task_type == "INPAINTING":
mask_prompt = tool_parameters.get("mask_prompt")
if not mask_prompt:
return self.create_text_message("Please provide a mask prompt for image inpainting.")
body["taskType"] = "INPAINTING"
body["inPaintingParams"] = {"image": input_image, "maskPrompt": mask_prompt, "text": prompt}
if negative_prompt:
body["inPaintingParams"]["negativeText"] = negative_prompt
elif task_type == "OUTPAINTING":
mask_prompt = tool_parameters.get("mask_prompt")
if not mask_prompt:
return self.create_text_message("Please provide a mask prompt for image outpainting.")
outpainting_mode = tool_parameters.get("outpainting_mode", "DEFAULT")
body["taskType"] = "OUTPAINTING"
body["outPaintingParams"] = {
"image": input_image,
"maskPrompt": mask_prompt,
"outPaintingMode": outpainting_mode,
"text": prompt,
}
if negative_prompt:
body["outPaintingParams"]["negativeText"] = negative_prompt
elif task_type == "BACKGROUND_REMOVAL":
body["taskType"] = "BACKGROUND_REMOVAL"
body["backgroundRemovalParams"] = {"image": input_image}
else:
return self.create_text_message(f"Unsupported task type: {task_type}")
# Call Nova Canvas model
response = bedrock.invoke_model(
body=json.dumps(body),
modelId="amazon.nova-canvas-v1:0",
accept="application/json",
contentType="application/json",
)
# Process response
response_body = json.loads(response.get("body").read())
if response_body.get("error"):
raise Exception(f"Error in model response: {response_body.get('error')}")
base64_image = response_body.get("images")[0]
# Upload to S3 if image_output_s3uri is provided
try:
# Parse S3 URI for output
parsed_uri = urlparse(image_output_s3uri)
output_bucket = parsed_uri.netloc
output_base_path = parsed_uri.path.lstrip("/")
# Generate filename with timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_key = f"{output_base_path}/canvas-output-{timestamp}.png"
# Initialize S3 client if not already done
s3_client = boto3.client("s3", region_name=aws_region)
# Decode base64 image and upload to S3
image_data = base64.b64decode(base64_image)
s3_client.put_object(Bucket=output_bucket, Key=output_key, Body=image_data, ContentType="image/png")
logger.info(f"Image uploaded to s3://{output_bucket}/{output_key}")
except Exception as e:
logger.exception("Failed to upload image to S3")
# Return image
return [
self.create_text_message(f"Image is available at: s3://{output_bucket}/{output_key}"),
self.create_blob_message(
blob=base64.b64decode(base64_image),
meta={"mime_type": "image/png"},
save_as=self.VariableKey.IMAGE.value,
),
]
except Exception as e:
return self.create_text_message(f"Failed to generate image: {str(e)}")
def _validate_color_string(self, color_string) -> bool:
color_pattern = r"^#[0-9a-fA-F]{6}(?:-#[0-9a-fA-F]{6})*$"
if re.match(color_pattern, color_string):
return True
return False
def get_runtime_parameters(self) -> list[ToolParameter]:
parameters = [
ToolParameter(
name="prompt",
label=I18nObject(en_US="Prompt", zh_Hans="提示词"),
type=ToolParameter.ToolParameterType.STRING,
required=True,
form=ToolParameter.ToolParameterForm.LLM,
human_description=I18nObject(
en_US="Text description of the image you want to generate or modify",
zh_Hans="您想要生成或修改的图像的文本描述",
),
llm_description="Describe the image you want to generate or how you want to modify the input image",
),
ToolParameter(
name="image_input_s3uri",
label=I18nObject(en_US="Input image s3 uri", zh_Hans="输入图片的s3 uri"),
type=ToolParameter.ToolParameterType.STRING,
required=False,
form=ToolParameter.ToolParameterForm.LLM,
human_description=I18nObject(en_US="Image to be modified", zh_Hans="想要修改的图片"),
),
ToolParameter(
name="image_output_s3uri",
label=I18nObject(en_US="Output Image S3 URI", zh_Hans="输出图片的S3 URI目录"),
type=ToolParameter.ToolParameterType.STRING,
required=True,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(
en_US="S3 URI where the generated image should be uploaded", zh_Hans="生成的图像应该上传到的S3 URI"
),
),
ToolParameter(
name="width",
label=I18nObject(en_US="Width", zh_Hans="宽度"),
type=ToolParameter.ToolParameterType.NUMBER,
required=False,
default=1024,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(en_US="Width of the generated image", zh_Hans="生成图像的宽度"),
),
ToolParameter(
name="height",
label=I18nObject(en_US="Height", zh_Hans="高度"),
type=ToolParameter.ToolParameterType.NUMBER,
required=False,
default=1024,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(en_US="Height of the generated image", zh_Hans="生成图像的高度"),
),
ToolParameter(
name="cfg_scale",
label=I18nObject(en_US="CFG Scale", zh_Hans="CFG比例"),
type=ToolParameter.ToolParameterType.NUMBER,
required=False,
default=8.0,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(
en_US="How strongly the image should conform to the prompt", zh_Hans="图像应该多大程度上符合提示词"
),
),
ToolParameter(
name="negative_prompt",
label=I18nObject(en_US="Negative Prompt", zh_Hans="负面提示词"),
type=ToolParameter.ToolParameterType.STRING,
required=False,
default="",
form=ToolParameter.ToolParameterForm.LLM,
human_description=I18nObject(
en_US="Things you don't want in the generated image", zh_Hans="您不想在生成的图像中出现的内容"
),
),
ToolParameter(
name="seed",
label=I18nObject(en_US="Seed", zh_Hans="种子值"),
type=ToolParameter.ToolParameterType.NUMBER,
required=False,
default=0,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(en_US="Random seed for image generation", zh_Hans="图像生成的随机种子"),
),
ToolParameter(
name="aws_region",
label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"),
type=ToolParameter.ToolParameterType.STRING,
required=False,
default="us-east-1",
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"),
),
ToolParameter(
name="task_type",
label=I18nObject(en_US="Task Type", zh_Hans="任务类型"),
type=ToolParameter.ToolParameterType.STRING,
required=False,
default="TEXT_IMAGE",
form=ToolParameter.ToolParameterForm.LLM,
human_description=I18nObject(en_US="Type of image generation task", zh_Hans="图像生成任务的类型"),
),
ToolParameter(
name="quality",
label=I18nObject(en_US="Quality", zh_Hans="质量"),
type=ToolParameter.ToolParameterType.STRING,
required=False,
default="standard",
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(
en_US="Quality of the generated image (standard or premium)", zh_Hans="生成图像的质量(标准或高级)"
),
),
ToolParameter(
name="colors",
label=I18nObject(en_US="Colors", zh_Hans="颜色"),
type=ToolParameter.ToolParameterType.STRING,
required=False,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(
en_US="List of colors for color-guided generation, example: #ff8080-#ffb280-#ffe680-#ffe680",
zh_Hans="颜色引导生成的颜色列表, 例子: #ff8080-#ffb280-#ffe680-#ffe680",
),
),
ToolParameter(
name="similarity_strength",
label=I18nObject(en_US="Similarity Strength", zh_Hans="相似度强度"),
type=ToolParameter.ToolParameterType.NUMBER,
required=False,
default=0.5,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(
en_US="How similar the generated image should be to the input image (0.0 to 1.0)",
zh_Hans="生成的图像应该与输入图像的相似程度0.0到1.0",
),
),
ToolParameter(
name="mask_prompt",
label=I18nObject(en_US="Mask Prompt", zh_Hans="蒙版提示词"),
type=ToolParameter.ToolParameterType.STRING,
required=False,
form=ToolParameter.ToolParameterForm.LLM,
human_description=I18nObject(
en_US="Text description to generate mask for inpainting/outpainting",
zh_Hans="用于生成内补绘制/外补绘制蒙版的文本描述",
),
),
ToolParameter(
name="outpainting_mode",
label=I18nObject(en_US="Outpainting Mode", zh_Hans="外补绘制模式"),
type=ToolParameter.ToolParameterType.STRING,
required=False,
default="DEFAULT",
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(
en_US="Mode for outpainting (DEFAULT or other supported modes)",
zh_Hans="外补绘制的模式DEFAULT或其他支持的模式",
),
),
]
return parameters

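For the default TEXT_IMAGE path, the request body assembled above reduces to the following shape (illustrative values; the full parameter reference is the AWS page linked in the YAML below):

body = {
    "taskType": "TEXT_IMAGE",
    "textToImageParams": {"text": "a lighthouse at dawn", "negativeText": "blurry"},
    "imageGenerationConfig": {
        "width": 1024,
        "height": 1024,
        "cfgScale": 8.0,
        "seed": 0,
        "numberOfImages": 1,
        "quality": "standard",
    },
}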
View File

@ -0,0 +1,175 @@
identity:
name: nova_canvas
author: AWS
label:
en_US: AWS Bedrock Nova Canvas
zh_Hans: AWS Bedrock Nova Canvas
icon: icon.svg
description:
human:
en_US: A tool for generating and modifying images using AWS Bedrock's Nova Canvas model. Supports text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html
zh_Hans: 使用 AWS Bedrock 的 Nova Canvas 模型生成和修改图像的工具。支持文生图、颜色引导生成、图像变体、内补绘制、外补绘制和背景移除功能, 输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html。
llm: Generate or modify images using AWS Bedrock's Nova Canvas model with multiple task types including text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal.
parameters:
- name: task_type
type: string
required: false
default: TEXT_IMAGE
label:
en_US: Task Type
zh_Hans: 任务类型
human_description:
en_US: Type of image generation task (TEXT_IMAGE, COLOR_GUIDED_GENERATION, IMAGE_VARIATION, INPAINTING, OUTPAINTING, BACKGROUND_REMOVAL)
zh_Hans: 图像生成任务的类型(文生图、颜色引导生成、图像变体、内补绘制、外补绘制、背景移除)
form: llm
- name: prompt
type: string
required: true
label:
en_US: Prompt
zh_Hans: 提示词
human_description:
en_US: Text description of the image you want to generate or modify
zh_Hans: 您想要生成或修改的图像的文本描述
llm_description: Describe the image you want to generate or how you want to modify the input image
form: llm
- name: image_input_s3uri
type: string
required: false
label:
en_US: Input image s3 uri
zh_Hans: 输入图片的s3 uri
human_description:
en_US: The input image to modify (required for all modes except TEXT_IMAGE)
zh_Hans: 要修改的输入图像(除文生图外的所有模式都需要)
llm_description: The input image you want to modify. Required for all modes except TEXT_IMAGE.
form: llm
- name: image_output_s3uri
type: string
required: true
label:
en_US: Output S3 URI
zh_Hans: 输出S3 URI
human_description:
en_US: The S3 URI where the generated image will be saved, using the name format canvas-output-{timestamp}.png
zh_Hans: 生成的图像将保存到的S3 URI,文件名格式为canvas-output-{timestamp}.png
llm_description: S3 URI where the generated image will be uploaded. The image will be saved with a timestamp-based filename.
form: form
- name: negative_prompt
type: string
required: false
label:
en_US: Negative Prompt
zh_Hans: 负面提示词
human_description:
en_US: Things you don't want in the generated image
zh_Hans: 您不想在生成的图像中出现的内容
form: llm
- name: width
type: number
required: false
label:
en_US: Width
zh_Hans: 宽度
human_description:
en_US: Width of the generated image
zh_Hans: 生成图像的宽度
form: form
default: 1024
- name: height
type: number
required: false
label:
en_US: Height
zh_Hans: 高度
human_description:
en_US: Height of the generated image
zh_Hans: 生成图像的高度
form: form
default: 1024
- name: cfg_scale
type: number
required: false
label:
en_US: CFG Scale
zh_Hans: CFG比例
human_description:
en_US: How strongly the image should conform to the prompt
zh_Hans: 图像应该多大程度上符合提示词
form: form
default: 8.0
- name: seed
type: number
required: false
label:
en_US: Seed
zh_Hans: 种子值
human_description:
en_US: Random seed for image generation
zh_Hans: 图像生成的随机种子
form: form
default: 0
- name: aws_region
type: string
required: false
default: us-east-1
label:
en_US: AWS Region
zh_Hans: AWS 区域
human_description:
en_US: AWS region for Bedrock service
zh_Hans: Bedrock 服务的 AWS 区域
form: form
- name: quality
type: string
required: false
default: standard
label:
en_US: Quality
zh_Hans: 质量
human_description:
en_US: Quality of the generated image (standard or premium)
zh_Hans: 生成图像的质量(标准或高级)
form: form
- name: colors
type: string
required: false
label:
en_US: Colors
zh_Hans: 颜色
human_description:
en_US: List of colors for color-guided generation
zh_Hans: 颜色引导生成的颜色列表
form: form
- name: similarity_strength
type: number
required: false
default: 0.5
label:
en_US: Similarity Strength
zh_Hans: 相似度强度
human_description:
en_US: How similar the generated image should be to the input image (0.0 to 1.0)
zh_Hans: 生成的图像应该与输入图像的相似程度0.0到1.0
form: form
- name: mask_prompt
type: string
required: false
label:
en_US: Mask Prompt
zh_Hans: 蒙版提示词
human_description:
en_US: Text description to generate mask for inpainting/outpainting
zh_Hans: 用于生成内补绘制/外补绘制蒙版的文本描述
form: llm
- name: outpainting_mode
type: string
required: false
default: DEFAULT
label:
en_US: Outpainting Mode
zh_Hans: 外补绘制模式
human_description:
en_US: Mode for outpainting (DEFAULT or other supported modes)
zh_Hans: 外补绘制的模式DEFAULT或其他支持的模式
form: form

View File

@ -0,0 +1,371 @@
import base64
import logging
import time
from io import BytesIO
from typing import Any, Optional, Union
from urllib.parse import urlparse
import boto3
from botocore.exceptions import ClientError
from PIL import Image
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool.builtin_tool import BuiltinTool
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
NOVA_REEL_DEFAULT_REGION = "us-east-1"
NOVA_REEL_DEFAULT_DIMENSION = "1280x720"
NOVA_REEL_DEFAULT_FPS = 24
NOVA_REEL_DEFAULT_DURATION = 6
NOVA_REEL_MODEL_ID = "amazon.nova-reel-v1:0"
NOVA_REEL_STATUS_CHECK_INTERVAL = 5
# Image requirements
NOVA_REEL_REQUIRED_IMAGE_WIDTH = 1280
NOVA_REEL_REQUIRED_IMAGE_HEIGHT = 720
NOVA_REEL_REQUIRED_IMAGE_MODE = "RGB"
class NovaReelTool(BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
Invoke AWS Bedrock Nova Reel model for video generation.
Args:
user_id: The ID of the user making the request
tool_parameters: Dictionary containing the tool parameters
Returns:
ToolInvokeMessage containing either the video content or status information
"""
try:
# Validate and extract parameters
params = self._validate_and_extract_parameters(tool_parameters)
if isinstance(params, ToolInvokeMessage):
return params
# Initialize AWS clients
bedrock, s3_client = self._initialize_aws_clients(params["aws_region"])
# Prepare model input
model_input = self._prepare_model_input(params, s3_client)
if isinstance(model_input, ToolInvokeMessage):
return model_input
# Start video generation
invocation = self._start_video_generation(bedrock, model_input, params["video_output_s3uri"])
invocation_arn = invocation["invocationArn"]
# Handle async/sync mode
return self._handle_generation_mode(bedrock, s3_client, invocation_arn, params["async_mode"])
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "Unknown")
error_message = e.response.get("Error", {}).get("Message", str(e))
logger.exception(f"AWS API error: {error_code} - {error_message}")
return self.create_text_message(f"AWS service error: {error_code} - {error_message}")
except Exception as e:
logger.error(f"Unexpected error in video generation: {str(e)}", exc_info=True)
return self.create_text_message(f"Failed to generate video: {str(e)}")
def _validate_and_extract_parameters(
self, tool_parameters: dict[str, Any]
) -> Union[dict[str, Any], ToolInvokeMessage]:
"""Validate and extract parameters from the input dictionary."""
prompt = tool_parameters.get("prompt", "")
video_output_s3uri = tool_parameters.get("video_output_s3uri", "").strip()
# Validate required parameters
if not prompt:
return self.create_text_message("Please provide a text prompt for video generation.")
if not video_output_s3uri:
return self.create_text_message("Please provide an S3 URI for video output.")
# Validate S3 URI format
if not video_output_s3uri.startswith("s3://"):
return self.create_text_message("Invalid S3 URI format. Must start with 's3://'")
# Ensure S3 URI ends with '/'
video_output_s3uri = video_output_s3uri if video_output_s3uri.endswith("/") else video_output_s3uri + "/"
return {
"prompt": prompt,
"video_output_s3uri": video_output_s3uri,
"image_input_s3uri": tool_parameters.get("image_input_s3uri", "").strip(),
"aws_region": tool_parameters.get("aws_region", NOVA_REEL_DEFAULT_REGION),
"dimension": tool_parameters.get("dimension", NOVA_REEL_DEFAULT_DIMENSION),
"seed": int(tool_parameters.get("seed", 0)),
"fps": int(tool_parameters.get("fps", NOVA_REEL_DEFAULT_FPS)),
"duration": int(tool_parameters.get("duration", NOVA_REEL_DEFAULT_DURATION)),
"async_mode": bool(tool_parameters.get("async", True)),
}
def _initialize_aws_clients(self, region: str) -> tuple[Any, Any]:
"""Initialize AWS Bedrock and S3 clients."""
bedrock = boto3.client(service_name="bedrock-runtime", region_name=region)
s3_client = boto3.client("s3", region_name=region)
return bedrock, s3_client
def _prepare_model_input(self, params: dict[str, Any], s3_client: Any) -> Union[dict[str, Any], ToolInvokeMessage]:
"""Prepare the input for the Nova Reel model."""
model_input = {
"taskType": "TEXT_VIDEO",
"textToVideoParams": {"text": params["prompt"]},
"videoGenerationConfig": {
"durationSeconds": params["duration"],
"fps": params["fps"],
"dimension": params["dimension"],
"seed": params["seed"],
},
}
# Add image if provided
if params["image_input_s3uri"]:
try:
image_data = self._get_image_from_s3(s3_client, params["image_input_s3uri"])
if not image_data:
return self.create_text_message("Failed to retrieve image from S3")
# Process and validate image
processed_image = self._process_and_validate_image(image_data)
if isinstance(processed_image, ToolInvokeMessage):
return processed_image
# Convert processed image to base64
img_buffer = BytesIO()
processed_image.save(img_buffer, format="PNG")
img_buffer.seek(0)
input_image_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
model_input["textToVideoParams"]["images"] = [
{"format": "png", "source": {"bytes": input_image_base64}}
]
except Exception as e:
logger.error(f"Error processing input image: {str(e)}", exc_info=True)
return self.create_text_message(f"Failed to process input image: {str(e)}")
return model_input
def _process_and_validate_image(self, image_data: bytes) -> Union[Image.Image, ToolInvokeMessage]:
"""
Process and validate the input image according to Nova Reel requirements.
Requirements:
- Must be 1280x720 pixels
- Must be RGB format (8 bits per channel)
- If PNG, alpha channel must not have transparent/translucent pixels
"""
try:
# Open image
img = Image.open(BytesIO(image_data))
# Convert RGBA to RGB if needed, ensuring no transparency
if img.mode == "RGBA":
# Check for transparency
if img.getchannel("A").getextrema()[0] < 255:
return self.create_text_message(
"PNG image contains transparent or translucent pixels, which is not supported. "
"Please provide an image without transparency."
)
# Convert to RGB
img = img.convert("RGB")
elif img.mode != "RGB":
# Convert any other mode to RGB
img = img.convert("RGB")
# Validate/adjust dimensions
if img.size != (NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT):
logger.warning(
f"Image dimensions {img.size} do not match required dimensions "
f"({NOVA_REEL_REQUIRED_IMAGE_WIDTH}x{NOVA_REEL_REQUIRED_IMAGE_HEIGHT}). Resizing..."
)
img = img.resize(
(NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT), Image.Resampling.LANCZOS
)
# Validate bit depth
if img.mode != NOVA_REEL_REQUIRED_IMAGE_MODE:
return self.create_text_message(
f"Image must be in {NOVA_REEL_REQUIRED_IMAGE_MODE} mode with 8 bits per channel"
)
return img
except Exception as e:
logger.error(f"Error processing image: {str(e)}", exc_info=True)
return self.create_text_message(
"Failed to process image. Please ensure the image is a valid JPEG or PNG file."
)
def _get_image_from_s3(self, s3_client: Any, s3_uri: str) -> Optional[bytes]:
"""Download and return image data from S3."""
parsed_uri = urlparse(s3_uri)
bucket = parsed_uri.netloc
key = parsed_uri.path.lstrip("/")
response = s3_client.get_object(Bucket=bucket, Key=key)
return response["Body"].read()
def _start_video_generation(self, bedrock: Any, model_input: dict[str, Any], output_s3uri: str) -> dict[str, Any]:
"""Start the async video generation process."""
return bedrock.start_async_invoke(
modelId=NOVA_REEL_MODEL_ID,
modelInput=model_input,
outputDataConfig={"s3OutputDataConfig": {"s3Uri": output_s3uri}},
)
def _handle_generation_mode(
self, bedrock: Any, s3_client: Any, invocation_arn: str, async_mode: bool
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""Handle async or sync video generation mode."""
invocation_response = bedrock.get_async_invoke(invocationArn=invocation_arn)
video_path = invocation_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"]
video_uri = f"{video_path}/output.mp4"
if async_mode:
return self.create_text_message(
f"Video generation started.\nInvocation ARN: {invocation_arn}\n"
f"Video will be available at: {video_uri}"
)
return self._wait_for_completion(bedrock, s3_client, invocation_arn)
def _wait_for_completion(self, bedrock: Any, s3_client: Any, invocation_arn: str) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""Wait for video generation completion and handle the result."""
while True:
status_response = bedrock.get_async_invoke(invocationArn=invocation_arn)
status = status_response["status"]
video_path = status_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"]
if status == "Completed":
return self._handle_completed_video(s3_client, video_path)
elif status == "Failed":
failure_message = status_response.get("failureMessage", "Unknown error")
return self.create_text_message(f"Video generation failed.\nError: {failure_message}")
elif status == "InProgress":
time.sleep(NOVA_REEL_STATUS_CHECK_INTERVAL)
else:
return self.create_text_message(f"Unexpected status: {status}")
def _handle_completed_video(self, s3_client: Any, video_path: str) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""Handle completed video generation and return the result."""
parsed_uri = urlparse(video_path)
bucket = parsed_uri.netloc
key = parsed_uri.path.lstrip("/") + "/output.mp4"
try:
response = s3_client.get_object(Bucket=bucket, Key=key)
video_content = response["Body"].read()
return [
self.create_text_message(f"Video is available at: {video_path}/output.mp4"),
self.create_blob_message(blob=video_content, meta={"mime_type": "video/mp4"}, save_as="output.mp4"),
]
except Exception as e:
logger.error(f"Error downloading video: {str(e)}", exc_info=True)
return self.create_text_message(
f"Video generation completed but failed to download video: {str(e)}\n"
f"Video is available at: s3://{bucket}/{key}"
)
def get_runtime_parameters(self) -> list[ToolParameter]:
"""Define the tool's runtime parameters."""
parameters = [
ToolParameter(
name="prompt",
label=I18nObject(en_US="Prompt", zh_Hans="提示词"),
type=ToolParameter.ToolParameterType.STRING,
required=True,
form=ToolParameter.ToolParameterForm.LLM,
human_description=I18nObject(
en_US="Text description of the video you want to generate", zh_Hans="您想要生成的视频的文本描述"
),
llm_description="Describe the video you want to generate",
),
ToolParameter(
name="video_output_s3uri",
label=I18nObject(en_US="Output S3 URI", zh_Hans="输出S3 URI"),
type=ToolParameter.ToolParameterType.STRING,
required=True,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(
en_US="S3 URI where the generated video will be stored", zh_Hans="生成的视频将存储的S3 URI"
),
),
ToolParameter(
name="dimension",
label=I18nObject(en_US="Dimension", zh_Hans="尺寸"),
type=ToolParameter.ToolParameterType.STRING,
required=False,
default=NOVA_REEL_DEFAULT_DIMENSION,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(en_US="Video dimensions (width x height)", zh_Hans="视频尺寸(宽 x 高)"),
),
ToolParameter(
name="duration",
label=I18nObject(en_US="Duration", zh_Hans="时长"),
type=ToolParameter.ToolParameterType.NUMBER,
required=False,
default=NOVA_REEL_DEFAULT_DURATION,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(en_US="Video duration in seconds", zh_Hans="视频时长(秒)"),
),
ToolParameter(
name="seed",
label=I18nObject(en_US="Seed", zh_Hans="种子值"),
type=ToolParameter.ToolParameterType.NUMBER,
required=False,
default=0,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(en_US="Random seed for video generation", zh_Hans="视频生成的随机种子"),
),
ToolParameter(
name="fps",
label=I18nObject(en_US="FPS", zh_Hans="帧率"),
type=ToolParameter.ToolParameterType.NUMBER,
required=False,
default=NOVA_REEL_DEFAULT_FPS,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(
en_US="Frames per second for the generated video", zh_Hans="生成视频的每秒帧数"
),
),
ToolParameter(
name="async",
label=I18nObject(en_US="Async Mode", zh_Hans="异步模式"),
type=ToolParameter.ToolParameterType.BOOLEAN,
required=False,
default=True,
form=ToolParameter.ToolParameterForm.LLM,
human_description=I18nObject(
en_US="Whether to run in async mode (return immediately) or sync mode (wait for completion)",
zh_Hans="是否以异步模式运行(立即返回)或同步模式(等待完成)",
),
),
ToolParameter(
name="aws_region",
label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"),
type=ToolParameter.ToolParameterType.STRING,
required=False,
default=NOVA_REEL_DEFAULT_REGION,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"),
),
ToolParameter(
name="image_input_s3uri",
label=I18nObject(en_US="Input Image S3 URI", zh_Hans="输入图像S3 URI"),
type=ToolParameter.ToolParameterType.STRING,
required=False,
form=ToolParameter.ToolParameterForm.LLM,
human_description=I18nObject(
en_US="S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame",
zh_Hans="用作第一帧的输入图像1280x720 JPEG/PNG的S3 URI",
),
),
]
return parameters

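Likewise, the default text-to-video `model_input` built by `_prepare_model_input` has this shape (illustrative prompt; an optional 1280x720 first-frame image is attached under `textToVideoParams["images"]`):

model_input = {
    "taskType": "TEXT_VIDEO",
    "textToVideoParams": {"text": "waves rolling onto a beach at sunset"},
    "videoGenerationConfig": {
        "durationSeconds": 6,
        "fps": 24,
        "dimension": "1280x720",
        "seed": 0,
    },
}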
View File

@ -0,0 +1,124 @@
identity:
name: nova_reel
author: AWS
label:
en_US: AWS Bedrock Nova Reel
zh_Hans: AWS Bedrock Nova Reel
icon: icon.svg
description:
human:
en_US: A tool for generating videos using AWS Bedrock's Nova Reel model. Supports text-to-video generation and image-to-video generation with customizable parameters like duration, FPS, and dimensions. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html
zh_Hans: 使用 AWS Bedrock 的 Nova Reel 模型生成视频的工具。支持文本生成视频和图像生成视频功能,可自定义持续时间、帧率和尺寸等参数。输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html
llm: Generate videos using AWS Bedrock's Nova Reel model with support for both text-to-video and image-to-video generation, allowing customization of video properties like duration, frame rate, and resolution.
parameters:
- name: prompt
type: string
required: true
label:
en_US: Prompt
zh_Hans: 提示词
human_description:
en_US: Text description of the video you want to generate
zh_Hans: 您想要生成的视频的文本描述
llm_description: Describe the video you want to generate
form: llm
- name: video_output_s3uri
type: string
required: true
label:
en_US: Output S3 URI
zh_Hans: 输出S3 URI
human_description:
en_US: S3 URI where the generated video will be stored
zh_Hans: 生成的视频将存储的S3 URI
form: form
- name: dimension
type: string
required: false
default: 1280x720
label:
en_US: Dimension
zh_Hans: 尺寸
human_description:
en_US: Video dimensions (width x height)
zh_Hans: 视频尺寸(宽 x 高)
form: form
- name: duration
type: number
required: false
default: 6
label:
en_US: Duration
zh_Hans: 时长
human_description:
en_US: Video duration in seconds
zh_Hans: 视频时长(秒)
form: form
- name: seed
type: number
required: false
default: 0
label:
en_US: Seed
zh_Hans: 种子值
human_description:
en_US: Random seed for video generation
zh_Hans: 视频生成的随机种子
form: form
- name: fps
type: number
required: false
default: 24
label:
en_US: FPS
zh_Hans: 帧率
human_description:
en_US: Frames per second for the generated video
zh_Hans: 生成视频的每秒帧数
form: form
- name: async
type: boolean
required: false
default: true
label:
en_US: Async Mode
zh_Hans: 异步模式
human_description:
en_US: Whether to run in async mode (return immediately) or sync mode (wait for completion)
zh_Hans: 是否以异步模式运行(立即返回)或同步模式(等待完成)
form: llm
- name: aws_region
type: string
required: false
default: us-east-1
label:
en_US: AWS Region
zh_Hans: AWS 区域
human_description:
en_US: AWS region for Bedrock service
zh_Hans: Bedrock 服务的 AWS 区域
form: form
- name: image_input_s3uri
type: string
required: false
label:
en_US: Input Image S3 URI
zh_Hans: 输入图像S3 URI
human_description:
en_US: S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame
zh_Hans: 用作第一帧的输入图像1280x720 JPEG/PNG的S3 URI
form: llm
development:
dependencies:
- boto3
- pillow

View File

@ -0,0 +1,80 @@
from typing import Any, Union
from urllib.parse import urlparse
import boto3
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class S3Operator(BuiltinTool):
s3_client: Any = None
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
"""
try:
# Initialize S3 client if not already done
if not self.s3_client:
aws_region = tool_parameters.get("aws_region")
if aws_region:
self.s3_client = boto3.client("s3", region_name=aws_region)
else:
self.s3_client = boto3.client("s3")
# Parse S3 URI
s3_uri = tool_parameters.get("s3_uri")
if not s3_uri:
return self.create_text_message("s3_uri parameter is required")
parsed_uri = urlparse(s3_uri)
if parsed_uri.scheme != "s3":
return self.create_text_message("Invalid S3 URI format. Must start with 's3://'")
bucket = parsed_uri.netloc
# Remove leading slash from key
key = parsed_uri.path.lstrip("/")
operation_type = tool_parameters.get("operation_type", "read")
generate_presign_url = tool_parameters.get("generate_presign_url", False)
presign_expiry = int(tool_parameters.get("presign_expiry", 3600)) # default 1 hour
if operation_type == "write":
text_content = tool_parameters.get("text_content")
if not text_content:
return self.create_text_message("text_content parameter is required for write operation")
# Write content to S3
self.s3_client.put_object(Bucket=bucket, Key=key, Body=text_content.encode("utf-8"))
result = f"s3://{bucket}/{key}"
# Generate presigned URL for the written object if requested
if generate_presign_url:
result = self.s3_client.generate_presigned_url(
"get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry
)
else: # read operation
# Get object from S3
response = self.s3_client.get_object(Bucket=bucket, Key=key)
result = response["Body"].read().decode("utf-8")
# Generate presigned URL if requested
if generate_presign_url:
result = self.s3_client.generate_presigned_url(
"get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry
)
return self.create_text_message(text=result)
except self.s3_client.exceptions.NoSuchBucket:
return self.create_text_message(f"Bucket '{bucket}' does not exist")
except self.s3_client.exceptions.NoSuchKey:
return self.create_text_message(f"Object '{key}' does not exist in bucket '{bucket}'")
except Exception as e:
return self.create_text_message(f"Exception: {str(e)}")

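A hypothetical `tool_parameters` payload exercising the write path followed by a presigned read link (bucket and key are examples):

tool_parameters = {
    "operation_type": "write",
    "s3_uri": "s3://my-bucket/notes/hello.txt",
    "text_content": "hello from dify",
    "generate_presign_url": True,
    "presign_expiry": 600,  # seconds
    "aws_region": "us-east-1",
}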
View File

@ -0,0 +1,98 @@
identity:
name: s3_operator
author: AWS
label:
en_US: AWS S3 Operator
zh_Hans: AWS S3 读写器
pt_BR: AWS S3 Operator
icon: icon.svg
description:
human:
en_US: AWS S3 Writer and Reader
zh_Hans: 读写S3 bucket中的文件
pt_BR: AWS S3 Writer and Reader
llm: AWS S3 Writer and Reader
parameters:
- name: text_content
type: string
required: false
label:
en_US: The text to write
zh_Hans: 待写入的文本
pt_BR: The text to write
human_description:
en_US: The text to write
zh_Hans: 待写入的文本
pt_BR: The text to write
llm_description: The text to write
form: llm
- name: s3_uri
type: string
required: true
label:
en_US: s3 uri
zh_Hans: s3 uri
pt_BR: s3 uri
human_description:
en_US: s3 uri
zh_Hans: s3 uri
pt_BR: s3 uri
llm_description: s3 uri
form: llm
- name: aws_region
type: string
required: true
label:
en_US: region of bucket
zh_Hans: bucket 所在的region
pt_BR: region of bucket
human_description:
en_US: region of bucket
zh_Hans: bucket 所在的region
pt_BR: region of bucket
llm_description: region of bucket
form: form
- name: operation_type
type: select
required: true
label:
en_US: operation type
zh_Hans: 操作类型
pt_BR: operation type
human_description:
en_US: operation type
zh_Hans: 操作类型
pt_BR: operation type
default: read
options:
- value: read
label:
en_US: read
zh_Hans: 读取
- value: write
label:
en_US: write
zh_Hans: 写入
form: form
- name: generate_presign_url
type: boolean
required: false
label:
en_US: Generate presigned URL
zh_Hans: 生成预签名URL
human_description:
en_US: Whether to generate a presigned URL for the S3 object
zh_Hans: 是否生成S3对象的预签名URL
default: false
form: form
- name: presign_expiry
type: number
required: false
label:
en_US: Presigned URL expiration time
zh_Hans: 预签名URL有效期
human_description:
en_US: Expiration time in seconds for the presigned URL
zh_Hans: 预签名URL的有效期
default: 3600
form: form

View File

@ -1,32 +1,8 @@
from typing import Any
from core.file import FileTransferMethod, FileType
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from factories import file_factory
class VectorizerProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
mapping = {
"transfer_method": FileTransferMethod.TOOL_FILE,
"type": FileType.IMAGE,
"id": "test_id",
"url": "https://cloud.dify.ai/logo/logo-site.png",
}
test_img = file_factory.build_from_mapping(
mapping=mapping,
tenant_id="__test_123",
)
try:
VectorizerTool().fork_tool_runtime(
runtime={
"credentials": credentials,
}
).invoke(
user_id="",
tool_parameters={"mode": "test", "image": test_img},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
return

View File

@ -210,7 +210,7 @@ class ApiTool(Tool):
)
return response
else:
raise ValueError(f"Invalid http method {self.method}")
raise ValueError(f"Invalid http method {method}")
def _convert_body_property_any_of(
self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10

View File

@ -8,9 +8,10 @@ from mimetypes import guess_extension, guess_type
from typing import Optional, Union
from uuid import uuid4
from httpx import get
import httpx
from configs import dify_config
from core.helper import ssrf_proxy
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import MessageFile
@ -94,12 +95,11 @@ class ToolFileManager:
) -> ToolFile:
# try to download image
try:
response = get(file_url)
response = ssrf_proxy.get(file_url)
response.raise_for_status()
blob = response.content
except Exception as e:
logger.exception(f"Failed to download file from {file_url}")
raise
except httpx.TimeoutException as e:
raise ValueError(f"timeout when downloading file from {file_url}")
mimetype = guess_type(file_url)[0] or "octet/stream"
extension = guess_extension(mimetype) or ".bin"
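
The hunk above swaps the bare httpx get for the project's ssrf_proxy wrapper and surfaces timeouts as a ValueError. A minimal sketch of that error-handling shape, assuming a plain httpx client stands in for ssrf_proxy:

import httpx

def download(url: str) -> bytes:
    # Raise a descriptive ValueError on timeout instead of letting the
    # raw httpx exception propagate.
    try:
        response = httpx.get(url)
        response.raise_for_status()
        return response.content
    except httpx.TimeoutException:
        raise ValueError(f"timeout when downloading file from {url}")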

View File

@ -21,6 +21,7 @@ from .variables import (
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
ArrayVariable,
FileVariable,
FloatVariable,
IntegerVariable,
@ -43,6 +44,7 @@ __all__ = [
"ArraySegment",
"ArrayStringSegment",
"ArrayStringVariable",
"ArrayVariable",
"FileSegment",
"FileVariable",
"FloatSegment",

View File

@ -10,6 +10,7 @@ from .segments import (
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArraySegment,
ArrayStringSegment,
FileSegment,
FloatSegment,
@ -52,19 +53,23 @@ class ObjectVariable(ObjectSegment, Variable):
pass
class ArrayAnyVariable(ArrayAnySegment, Variable):
class ArrayVariable(ArraySegment, Variable):
pass
class ArrayStringVariable(ArrayStringSegment, Variable):
class ArrayAnyVariable(ArrayAnySegment, ArrayVariable):
pass
class ArrayNumberVariable(ArrayNumberSegment, Variable):
class ArrayStringVariable(ArrayStringSegment, ArrayVariable):
pass
class ArrayObjectVariable(ArrayObjectSegment, Variable):
class ArrayNumberVariable(ArrayNumberSegment, ArrayVariable):
pass
class ArrayObjectVariable(ArrayObjectSegment, ArrayVariable):
pass
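
The point of inserting ArrayVariable into the hierarchy is that every typed array variable now shares one common ancestor, so callers (such as the iteration node later in this diff) can use a single isinstance check. A simplified sketch, with the segment mixin bases elided:

class Segment: ...
class ArraySegment(Segment): ...
class Variable(Segment): ...

class ArrayVariable(ArraySegment, Variable): ...
class ArrayStringVariable(ArrayVariable): ...
class ArrayNumberVariable(ArrayVariable): ...

# One check now covers every array-typed variable:
assert isinstance(ArrayStringVariable(), ArrayVariable)
assert isinstance(ArrayNumberVariable(), ArrayVariable)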

View File

@ -33,7 +33,7 @@ class GraphRunSucceededEvent(BaseGraphEvent):
class GraphRunFailedEvent(BaseGraphEvent):
error: str = Field(..., description="failed reason")
exceptions_count: Optional[int] = Field(description="exception count", default=0)
exceptions_count: int = Field(description="exception count", default=0)
class GraphRunPartialSucceededEvent(BaseGraphEvent):
@ -97,11 +97,10 @@ class NodeInIterationFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeRunRetryEvent(BaseNodeEvent):
class NodeRunRetryEvent(NodeRunStartedEvent):
error: str = Field(..., description="error")
retry_index: int = Field(..., description="which retry attempt is about to be performed")
start_at: datetime = Field(..., description="retry start time")
start_index: int = Field(..., description="retry start index")
###########################################

View File

@ -4,6 +4,7 @@ from typing import Any, Optional, cast
from pydantic import BaseModel, Field
from configs import dify_config
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes import NodeType
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
@ -170,7 +171,9 @@ class Graph(BaseModel):
for parallel in parallel_mapping.values():
if parallel.parent_parallel_id:
cls._check_exceed_parallel_limit(
parallel_mapping=parallel_mapping, level_limit=3, parent_parallel_id=parallel.parent_parallel_id
parallel_mapping=parallel_mapping,
level_limit=dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
parent_parallel_id=parallel.parent_parallel_id,
)
# init answer stream generate routes
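
The hardcoded level_limit=3 becomes the configurable WORKFLOW_PARALLEL_DEPTH_LIMIT. A sketch of how such a recursive depth check can work (the real _check_exceed_parallel_limit may differ; the names and dict-shaped mapping here are illustrative):

def check_parallel_depth(parallel_mapping: dict, parallel_id: str,
                         level_limit: int, depth: int = 1) -> None:
    # Walk parent links upward; fail once nesting exceeds the limit.
    if depth > level_limit:
        raise ValueError(f"parallel nesting exceeds limit {level_limit}")
    parent_id = parallel_mapping[parallel_id].get("parent_parallel_id")
    if parent_id:
        check_parallel_depth(parallel_mapping, parent_id, level_limit, depth + 1)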

View File

@ -641,7 +641,6 @@ class GraphEngine:
run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
if node_instance.should_retry and retries < max_retries:
retries += 1
self.graph_runtime_state.node_run_steps += 1
route_node_state.node_run_result = run_result
yield NodeRunRetryEvent(
id=node_instance.id,
@ -649,14 +648,14 @@ class GraphEngine:
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
error=run_result.error,
retry_index=retries,
predecessor_node_id=node_instance.previous_node_id,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
error=run_result.error,
retry_index=retries,
start_at=retry_start_at,
start_index=self.graph_runtime_state.node_run_steps,
)
time.sleep(retry_interval)
continue
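
Stripped of events and graph state, the retry loop above reduces to a familiar pattern: rerun on failure until the attempt budget is spent, sleeping between attempts. A generic sketch with illustrative names:

import time

def run_with_retries(run, succeeded, max_retries: int, retry_interval: float):
    # `run` executes the node once; `succeeded` decides whether its result
    # is final. Retry until success or the attempt budget is spent.
    retries = 0
    while True:
        result = run()
        if succeeded(result) or retries >= max_retries:
            return result
        retries += 1
        time.sleep(retry_interval)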

View File

@ -147,6 +147,8 @@ class AnswerStreamGeneratorRouter:
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
for edge in reverse_edges:
source_node_id = edge.source_node_id
if source_node_id not in node_id_config_mapping:
continue
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
source_node_data = node_id_config_mapping[source_node_id].get("data", {})
if (

View File

@ -60,7 +60,6 @@ class AnswerStreamProcessor(StreamProcessor):
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
# remove unreachable nodes
self._remove_unreachable_nodes(event)
# generate stream outputs

View File

@ -1,3 +1,4 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator
@ -5,6 +6,8 @@ from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph
logger = logging.getLogger(__name__)
class StreamProcessor(ABC):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
@ -31,13 +34,22 @@ class StreamProcessor(ABC):
if run_result.edge_source_handle:
reachable_node_ids = []
unreachable_first_node_ids = []
if finished_node_id not in self.graph.edge_mapping:
logger.warning(f"node {finished_node_id} has no edge mapping")
return
for edge in self.graph.edge_mapping[finished_node_id]:
if (
edge.run_condition
and edge.run_condition.branch_identify
and run_result.edge_source_handle == edge.run_condition.branch_identify
):
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
# remove unreachable nodes
# FIXME: branches can merge back together further downstream, so removing
# "unreachable" nodes here may accidentally cut off an answer node.
# Keep this call commented out until a better solution lands; doing so
# has no effect on the answer node or the workflow. Issues: #11542 #9560 #10638 #10564
# reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
continue
else:
unreachable_first_node_ids.append(edge.target_node_id)
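
The new guard turns a would-be KeyError into a logged warning and an early return. The shape of that defensive lookup, as a standalone sketch:

import logging

logger = logging.getLogger(__name__)

def outgoing_edges(edge_mapping: dict, node_id: str) -> list:
    # Warn and return an empty edge list instead of raising KeyError
    # when a finished node has no entry in the mapping.
    if node_id not in edge_mapping:
        logger.warning("node %s has no edge mapping", node_id)
        return []
    return edge_mapping[node_id]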

View File

@ -1,4 +1,4 @@
class BaseNodeError(Exception):
class BaseNodeError(ValueError):
"""Base class for node errors."""
pass

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union
from typing import Any, Optional
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
@ -59,7 +59,7 @@ class CodeNode(BaseNode[CodeNodeData]):
)
# Transform result
result = self._transform_result(result, self.node_data.outputs)
result = self._transform_result(result=result, output_schema=self.node_data.outputs)
except (CodeExecutionError, CodeNodeError) as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
@ -67,18 +67,17 @@ class CodeNode(BaseNode[CodeNodeData]):
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
def _check_string(self, value: str, variable: str) -> str:
def _check_string(self, value: str | None, variable: str) -> str | None:
"""
Check string
:param value: value
:param variable: variable
:return:
"""
if value is None:
return None
if not isinstance(value, str):
if value is None:
return None
else:
raise OutputValidationError(f"Output variable `{variable}` must be a string")
raise OutputValidationError(f"Output variable `{variable}` must be a string")
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
raise OutputValidationError(
@ -88,18 +87,17 @@ class CodeNode(BaseNode[CodeNodeData]):
return value.replace("\x00", "")
def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]:
def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
"""
Check number
:param value: value
:param variable: variable
:return:
"""
if value is None:
return None
if not isinstance(value, int | float):
if value is None:
return None
else:
raise OutputValidationError(f"Output variable `{variable}` must be a number")
raise OutputValidationError(f"Output variable `{variable}` must be a number")
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
raise OutputValidationError(
@ -118,14 +116,12 @@ class CodeNode(BaseNode[CodeNodeData]):
return value
def _transform_result(
self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1
) -> dict:
"""
Transform result
:param result: result
:param output_schema: output schema
:return:
"""
self,
result: Mapping[str, Any],
output_schema: Optional[dict[str, CodeNodeData.Output]],
prefix: str = "",
depth: int = 1,
):
if depth > dify_config.CODE_MAX_DEPTH:
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
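
The reordered checks above make the intent explicit: None is a valid (absent) output and passes straight through; only genuine non-strings raise. A self-contained sketch of the same shape, with an illustrative constant standing in for dify_config.CODE_MAX_STRING_LENGTH:

MAX_STRING_LENGTH = 80_000  # stand-in for dify_config.CODE_MAX_STRING_LENGTH

def check_string(value: str | None, variable: str) -> str | None:
    if value is None:
        return None
    if not isinstance(value, str):
        raise ValueError(f"Output variable `{variable}` must be a string")
    if len(value) > MAX_STRING_LENGTH:
        raise ValueError(f"`{variable}` must be less than {MAX_STRING_LENGTH} characters")
    return value.replace("\x00", "")

assert check_string(None, "x") is None
assert check_string("ok", "x") == "ok"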

View File

@ -1,6 +1,7 @@
import csv
import io
import json
import logging
import os
import tempfile
@ -8,12 +9,6 @@ import docx
import pandas as pd
import pypdfium2 # type: ignore
import yaml # type: ignore
from unstructured.partition.api import partition_via_api
from unstructured.partition.email import partition_email
from unstructured.partition.epub import partition_epub
from unstructured.partition.msg import partition_msg
from unstructured.partition.ppt import partition_ppt
from unstructured.partition.pptx import partition_pptx
from configs import dify_config
from core.file import File, FileTransferMethod, file_manager
@ -28,6 +23,8 @@ from models.workflow import WorkflowNodeExecutionStatus
from .entities import DocumentExtractorNodeData
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
logger = logging.getLogger(__name__)
class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
"""
@ -183,10 +180,43 @@ def _extract_text_from_pdf(file_content: bytes) -> str:
def _extract_text_from_doc(file_content: bytes) -> str:
"""
Extract text from a DOC/DOCX file.
For now, only paragraphs and tables are supported; add more if needed.
"""
try:
doc_file = io.BytesIO(file_content)
doc = docx.Document(doc_file)
return "\n".join([paragraph.text for paragraph in doc.paragraphs])
text = []
# Process paragraphs
for paragraph in doc.paragraphs:
if paragraph.text.strip():
text.append(paragraph.text)
# Process tables
for table in doc.tables:
# Table header
try:
# a malformed table may raise an error, so skip it and continue.
if len(table.rows) > 0 and table.rows[0].cells is not None:
# Check if any cell in the table has text
has_content = False
for row in table.rows:
if any(cell.text.strip() for cell in row.cells):
has_content = True
break
if has_content:
markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n"
markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n"
for row in table.rows[1:]:
markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n"
text.append(markdown_table)
except Exception as e:
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
continue
return "\n".join(text)
except Exception as e:
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
@ -256,6 +286,8 @@ def _extract_text_from_excel(file_content: bytes) -> str:
def _extract_text_from_ppt(file_content: bytes) -> str:
from unstructured.partition.ppt import partition_ppt
try:
with io.BytesIO(file_content) as file:
elements = partition_ppt(file=file)
@ -265,6 +297,9 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
def _extract_text_from_pptx(file_content: bytes) -> str:
from unstructured.partition.api import partition_via_api
from unstructured.partition.pptx import partition_pptx
try:
if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
@ -287,6 +322,8 @@ def _extract_text_from_pptx(file_content: bytes) -> str:
def _extract_text_from_epub(file_content: bytes) -> str:
from unstructured.partition.epub import partition_epub
try:
with io.BytesIO(file_content) as file:
elements = partition_epub(file=file)
@ -296,6 +333,8 @@ def _extract_text_from_epub(file_content: bytes) -> str:
def _extract_text_from_eml(file_content: bytes) -> str:
from unstructured.partition.email import partition_email
try:
with io.BytesIO(file_content) as file:
elements = partition_email(file=file)
@ -305,6 +344,8 @@ def _extract_text_from_eml(file_content: bytes) -> str:
def _extract_text_from_msg(file_content: bytes) -> str:
from unstructured.partition.msg import partition_msg
try:
with io.BytesIO(file_content) as file:
elements = partition_msg(file=file)
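
Two changes above are worth separating: the unstructured imports move inside the extractors that need them (so the heavy dependency loads lazily), and DOCX tables are now rendered as Markdown. The table rendering, as a standalone sketch:

def rows_to_markdown(rows: list[list[str]]) -> str:
    # First row becomes the header, then a separator row, then the body --
    # the same layout the DOCX extractor builds cell by cell.
    header = "| " + " | ".join(rows[0]) + " |\n"
    separator = "| " + " | ".join(["---"] * len(rows[0])) + " |\n"
    body = "".join("| " + " | ".join(row) + " |\n" for row in rows[1:])
    return header + separator + body

print(rows_to_markdown([["name", "qty"], ["apple", "3"], ["pear", "5"]]))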

View File

@ -135,6 +135,8 @@ class EndStreamGeneratorRouter:
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
for edge in reverse_edges:
source_node_id = edge.source_node_id
if source_node_id not in node_id_config_mapping:
continue
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
if source_node_type in {
NodeType.IF_ELSE.value,

View File

@ -35,4 +35,4 @@ class FailBranchSourceHandle(StrEnum):
CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
RETRY_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.TOOL, NodeType.HTTP_REQUEST]
RETRY_ON_ERROR_NODE_TYPE = CONTINUE_ON_ERROR_NODE_TYPE
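
One subtlety of the assignment above: it aliases the two names to the same list object rather than copying, so the retry list automatically tracks any future edits to the continue-on-error list. Illustrated with stand-in values:

CONTINUE_ON_ERROR_NODE_TYPE = ["llm", "code", "tool", "http-request"]
RETRY_ON_ERROR_NODE_TYPE = CONTINUE_ON_ERROR_NODE_TYPE

# Both names reference one list; mutating either is visible through both.
CONTINUE_ON_ERROR_NODE_TYPE.append("some-new-type")
assert "some-new-type" in RETRY_ON_ERROR_NODE_TYPE
assert RETRY_ON_ERROR_NODE_TYPE is CONTINUE_ON_ERROR_NODE_TYPE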

View File

@ -39,15 +39,9 @@ class RunRetryEvent(BaseModel):
start_at: datetime = Field(..., description="Retry start time")
class SingleStepRetryEvent(BaseModel):
class SingleStepRetryEvent(NodeRunResult):
"""Single step retry event"""
status: str = WorkflowNodeExecutionStatus.RETRY.value
inputs: dict | None = Field(..., description="input")
error: str = Field(..., description="error")
outputs: dict = Field(..., description="output")
retry_index: int = Field(..., description="Retry attempt number")
error: str = Field(..., description="error")
elapsed_time: float = Field(..., description="elapsed time")
execution_metadata: dict | None = Field(..., description="execution metadata")

View File

@ -16,3 +16,7 @@ class InvalidHttpMethodError(HttpRequestNodeError):
class ResponseSizeError(HttpRequestNodeError):
"""Raised when the response size exceeds the allowed threshold."""
class RequestBodyError(HttpRequestNodeError):
"""Raised when the request body is invalid."""

View File

@ -23,6 +23,7 @@ from .exc import (
FileFetchError,
HttpRequestNodeError,
InvalidHttpMethodError,
RequestBodyError,
ResponseSizeError,
)
@ -143,13 +144,19 @@ class Executor:
case "none":
self.content = ""
case "raw-text":
if len(data) != 1:
raise RequestBodyError("raw-text body type should have exactly one item")
self.content = self.variable_pool.convert_template(data[0].value).text
case "json":
if len(data) != 1:
raise RequestBodyError("json body type should have exactly one item")
json_string = self.variable_pool.convert_template(data[0].value).text
json_object = json.loads(json_string, strict=False)
self.json = json_object
# self.json = self._parse_object_contains_variables(json_object)
case "binary":
if len(data) != 1:
raise RequestBodyError("binary body type should have exactly one item")
file_selector = data[0].file
file_variable = self.variable_pool.get_file(file_selector)
if file_variable is None:
@ -249,7 +256,7 @@ class Executor:
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
response = getattr(ssrf_proxy, self.method)(**request_args)
except ssrf_proxy.MaxRetriesExceededError as e:
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
raise HttpRequestNodeError(str(e))
return response
@ -317,6 +324,8 @@ class Executor:
elif self.json:
body = json.dumps(self.json)
elif self.node_data.body.type == "raw-text":
if len(self.node_data.body.data) != 1:
raise RequestBodyError("raw-text body type should have exactly one item")
body = self.node_data.body.data[0].value
if body:
raw += f"Content-Length: {len(body)}\r\n"

View File

@ -20,7 +20,7 @@ from .entities import (
HttpRequestNodeTimeout,
Response,
)
from .exc import HttpRequestNodeError
from .exc import HttpRequestNodeError, RequestBodyError
HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
@ -136,9 +136,13 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
data = node_data.body.data
match body_type:
case "binary":
if len(data) != 1:
raise RequestBodyError("invalid body data, should have only one item")
selector = data[0].file
selectors.append(VariableSelector(variable="#" + ".".join(selector) + "#", value_selector=selector))
case "json" | "raw-text":
if len(data) != 1:
raise RequestBodyError("invalid body data, should have only one item")
selectors += variable_template_parser.extract_selectors_from_template(data[0].key)
selectors += variable_template_parser.extract_selectors_from_template(data[0].value)
case "x-www-form-urlencoded":

View File

@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
from flask import Flask, current_app
from configs import dify_config
from core.variables import IntegerVariable
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
from core.workflow.entities.node_entities import (
NodeRunMetadataKey,
NodeRunResult,
@ -75,12 +75,15 @@ class IterationNode(BaseNode[IterationNodeData]):
"""
Run the node.
"""
iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
if not iterator_list_segment:
raise IteratorVariableNotFoundError(f"Iterator variable {self.node_data.iterator_selector} not found")
if not variable:
raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")
if len(iterator_list_segment.value) == 0:
if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
if isinstance(variable, NoneVariable) or len(variable.value) == 0:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -89,7 +92,7 @@ class IterationNode(BaseNode[IterationNodeData]):
)
return
iterator_list_value = iterator_list_segment.to_object()
iterator_list_value = variable.to_object()
if not isinstance(iterator_list_value, list):
raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")

View File

@ -50,6 +50,7 @@ class PromptConfig(BaseModel):
class LLMNodeChatModelMessage(ChatModelMessage):
text: str = ""
jinja2_text: Optional[str] = None

View File

@ -145,8 +145,8 @@ class LLMNode(BaseNode[LLMNodeData]):
query = query_variable.text
prompt_messages, stop = self._fetch_prompt_messages(
user_query=query,
user_files=files,
sys_query=query,
sys_files=files,
context=context,
memory=memory,
model_config=model_config,
@ -545,8 +545,8 @@ class LLMNode(BaseNode[LLMNodeData]):
def _fetch_prompt_messages(
self,
*,
user_query: str | None = None,
user_files: Sequence["File"],
sys_query: str | None = None,
sys_files: Sequence["File"],
context: str | None = None,
memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity,
@ -562,7 +562,7 @@ class LLMNode(BaseNode[LLMNodeData]):
if isinstance(prompt_template, list):
# For chat model
prompt_messages.extend(
_handle_list_messages(
self._handle_list_messages(
messages=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
@ -581,14 +581,14 @@ class LLMNode(BaseNode[LLMNodeData]):
prompt_messages.extend(memory_messages)
# Add current query to the prompt messages
if user_query:
if sys_query:
message = LLMNodeChatModelMessage(
text=user_query,
text=sys_query,
role=PromptMessageRole.USER,
edition_type="basic",
)
prompt_messages.extend(
_handle_list_messages(
self._handle_list_messages(
messages=[message],
context="",
jinja2_variables=[],
@ -635,24 +635,27 @@ class LLMNode(BaseNode[LLMNodeData]):
raise ValueError("Invalid prompt content type")
# Add current query to the prompt message
if user_query:
if sys_query:
if prompt_content_type == str:
prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query)
prompt_messages[0].content = prompt_content
elif prompt_content_type == list:
for content_item in prompt_content:
if content_item.type == PromptMessageContentType.TEXT:
content_item.data = user_query + "\n" + content_item.data
content_item.data = sys_query + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
else:
raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
if vision_enabled and user_files:
# The sys_files will be deprecated later
if vision_enabled and sys_files:
file_prompts = []
for file in user_files:
for file in sys_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
# If last prompt is a user prompt, add files into its contents,
# otherwise append a new user prompt
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
@ -662,7 +665,7 @@ class LLMNode(BaseNode[LLMNodeData]):
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# Filter prompt messages
# Remove empty messages and filter unsupported content
filtered_prompt_messages = []
for prompt_message in prompt_messages:
if isinstance(prompt_message.content, list):
@ -846,6 +849,68 @@ class LLMNode(BaseNode[LLMNodeData]):
},
}
def _handle_list_messages(
self,
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
) -> Sequence[PromptMessage]:
prompt_messages: list[PromptMessage] = []
for message in messages:
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
jinjia2_variables=jinja2_variables,
variable_pool=variable_pool,
)
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)], role=message.role
)
prompt_messages.append(prompt_message)
else:
# Get segment group from basic message
if context:
template = message.text.replace("{#context#}", context)
else:
template = message.text
segment_group = variable_pool.convert_template(template)
# Process segments for images
file_contents = []
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
elif isinstance(segment, FileSegment):
file = segment.value
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
# Create message with text from all segments
plain_text = segment_group.text
if plain_text:
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=plain_text)], role=message.role
)
prompt_messages.append(prompt_message)
if file_contents:
# Create message with image contents
prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
prompt_messages.append(prompt_message)
return prompt_messages
def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
match role:
@ -880,68 +945,6 @@ def _render_jinja2_message(
return result_text
def _handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
) -> Sequence[PromptMessage]:
prompt_messages = []
for message in messages:
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
jinjia2_variables=jinja2_variables,
variable_pool=variable_pool,
)
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)], role=message.role
)
prompt_messages.append(prompt_message)
else:
# Get segment group from basic message
if context:
template = message.text.replace("{#context#}", context)
else:
template = message.text
segment_group = variable_pool.convert_template(template)
# Process segments for images
file_contents = []
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
if isinstance(segment, FileSegment):
file = segment.value
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
# Create message with text from all segments
plain_text = segment_group.text
if plain_text:
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=plain_text)], role=message.role
)
prompt_messages.append(prompt_message)
if file_contents:
# Create message with image contents
prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
prompt_messages.append(prompt_message)
return prompt_messages
def _calculate_rest_token(
*, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> int:

View File

@ -179,6 +179,15 @@ class ParameterExtractorNode(LLMNode):
error=str(e),
metadata={},
)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=inputs,
process_data=process_data,
outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)},
error=str(e),
metadata={},
)
error = None

View File

@ -1,10 +1,8 @@
import json
import logging
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.llm_generator.output_parser.errors import OutputParserError
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
@ -86,37 +84,38 @@ class QuestionClassifierNode(LLMNode):
)
prompt_messages, stop = self._fetch_prompt_messages(
prompt_template=prompt_template,
user_query=query,
sys_query=query,
memory=memory,
model_config=model_config,
user_files=files,
sys_files=files,
vision_enabled=node_data.vision.enabled,
vision_detail=node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=[],
)
# handle invoke result
generator = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
)
result_text = ""
usage = LLMUsage.empty_usage()
finish_reason = None
for event in generator:
if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
usage = event.usage
finish_reason = event.finish_reason
break
category_name = node_data.classes[0].name
category_id = node_data.classes[0].id
try:
# handle invoke result
generator = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
)
for event in generator:
if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
usage = event.usage
finish_reason = event.finish_reason
break
category_name = node_data.classes[0].name
category_id = node_data.classes[0].id
result_text_json = parse_and_check_json_markdown(result_text, [])
# result_text_json = json.loads(result_text.strip('```JSON\n'))
if "category_name" in result_text_json and "category_id" in result_text_json:
@ -127,10 +126,6 @@ class QuestionClassifierNode(LLMNode):
if category_id_result in category_ids:
category_name = classes_map[category_id_result]
category_id = category_id_result
except OutputParserError:
logging.exception(f"Failed to parse result text: {result_text}")
try:
process_data = {
"model_mode": model_config.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
@ -154,7 +149,6 @@ class QuestionClassifierNode(LLMNode):
},
llm_usage=usage,
)
except ValueError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,

View File

@ -1,5 +1,6 @@
from collections.abc import Mapping, Sequence
from typing import Any
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -231,6 +232,10 @@ class ToolNode(BaseNode[ToolNodeData]):
url = str(response.message)
transfer_method = FileTransferMethod.TOOL_FILE
tool_file_id = url.split("/")[-1].split(".")[0]
try:
UUID(tool_file_id)
except ValueError:
raise ToolFileError(f"cannot extract tool file id from url {url}")
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
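
The guard above validates the extracted id before it reaches the database. The same extraction-plus-validation step as a standalone sketch:

from uuid import UUID

def extract_tool_file_id(url: str) -> str:
    # Last path segment, minus its extension, must parse as a UUID.
    tool_file_id = url.split("/")[-1].split(".")[0]
    try:
        UUID(tool_file_id)
    except ValueError:
        raise ValueError(f"cannot extract tool file id from url {url}")
    return tool_file_id

print(extract_tool_file_id(
    "https://example.com/files/123e4567-e89b-12d3-a456-426614174000.png"))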

View File

@ -1,4 +1,4 @@
class VariableOperatorNodeError(Exception):
class VariableOperatorNodeError(ValueError):
"""Base error type, don't use directly."""
pass

View File

@ -1,18 +1,5 @@
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import MetaData
from dify_app import DifyApp
POSTGRES_INDEXES_NAMING_CONVENTION = {
"ix": "%(column_0_label)s_idx",
"uq": "%(table_name)s_%(column_0_name)s_key",
"ck": "%(table_name)s_%(constraint_name)s_check",
"fk": "%(table_name)s_%(column_0_name)s_fkey",
"pk": "%(table_name)s_pkey",
}
metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION)
db = SQLAlchemy(metadata=metadata)
from models import db
def init_app(app: DifyApp):

View File

@ -3,4 +3,3 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
from events import event_handlers # noqa: F401
from models import account, dataset, model, source, task, tool, tools, web # noqa: F401

View File

@ -67,7 +67,9 @@ class AwsS3Storage(BaseStorage):
yield from response["Body"].iter_chunks()
except ClientError as ex:
if ex.response["Error"]["Code"] == "NoSuchKey":
raise FileNotFoundError("File not found")
raise FileNotFoundError("file not found")
elif "reached max retries" in str(ex):
raise ValueError("please do not request the same file too frequently")
else:
raise

View File

@ -116,8 +116,11 @@ def _build_from_local_file(
tenant_id: str,
transfer_method: FileTransferMethod,
) -> File:
upload_file_id = mapping.get("upload_file_id")
if not upload_file_id:
raise ValueError("Invalid upload file id")
stmt = select(UploadFile).where(
UploadFile.id == mapping.get("upload_file_id"),
UploadFile.id == upload_file_id,
UploadFile.tenant_id == tenant_id,
)
@ -139,6 +142,7 @@ def _build_from_local_file(
remote_url=row.source_url,
related_id=mapping.get("upload_file_id"),
size=row.size,
storage_key=row.key,
)
@ -168,6 +172,7 @@ def _build_from_remote_url(
mime_type=mime_type,
extension=extension,
size=file_size,
storage_key="",
)
@ -220,6 +225,7 @@ def _build_from_tool_file(
extension=extension,
mime_type=tool_file.mimetype,
size=tool_file.size,
storage_key=tool_file.file_key,
)

View File

@ -85,7 +85,7 @@ message_detail_fields = {
}
feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer}
status_count_fields = {"success": fields.Integer, "failed": fields.Integer, "partial_success": fields.Integer}
model_config_fields = {
"opening_statement": fields.String,
"suggested_questions": fields.Raw,
@ -166,6 +166,7 @@ conversation_with_summary_fields = {
"message_count": fields.Integer,
"user_feedback_stats": fields.Nested(feedback_stat_fields),
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
"status_count": fields.Nested(status_count_fields),
}
conversation_with_summary_pagination_fields = {

View File

@ -82,13 +82,15 @@ workflow_run_detail_fields = {
}
retry_event_field = {
"elapsed_time": fields.Float,
"status": fields.String,
"inputs": fields.Raw(attribute="inputs"),
"process_data": fields.Raw(attribute="process_data"),
"outputs": fields.Raw(attribute="outputs"),
"metadata": fields.Raw(attribute="metadata"),
"llm_usage": fields.Raw(attribute="llm_usage"),
"error": fields.String,
"retry_index": fields.Integer,
"inputs": fields.Raw(attribute="inputs"),
"elapsed_time": fields.Float,
"execution_metadata": fields.Raw(attribute="execution_metadata_dict"),
"status": fields.String,
"outputs": fields.Raw(attribute="outputs"),
}
@ -112,7 +114,6 @@ workflow_run_node_execution_fields = {
"created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
"finished_at": TimestampField,
"retry_events": fields.List(fields.Nested(retry_event_field)),
}
workflow_run_node_execution_list_fields = {

View File

@ -13,7 +13,7 @@ from typing import Any, Optional, Union, cast
from zoneinfo import available_timezones
from flask import Response, stream_with_context
from flask_restful import fields # type: ignore
from flask_restful import fields
from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator

View File

@ -27,7 +27,7 @@ def parse_json_markdown(json_string: str) -> dict:
extracted_content = json_string[start_index:end_index].strip()
parsed = json.loads(extracted_content)
else:
raise Exception("Could not find JSON block in the output.")
raise ValueError("could not find json block in the output.")
return parsed
@ -36,10 +36,10 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict:
try:
json_obj = parse_json_markdown(text)
except json.JSONDecodeError as e:
raise OutputParserError(f"Got invalid JSON object. Error: {e}")
raise OutputParserError(f"got invalid json object. error: {e}")
for key in expected_keys:
if key not in json_obj:
raise OutputParserError(
f"Got invalid return object. Expected key `{key}` to be present, but got {json_obj}"
f"got invalid return object. expected key `{key}` to be present, but got {json_obj}"
)
return json_obj
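
Both messages above also shift the failure mode from a bare Exception to ValueError, which the node error hierarchy (BaseNodeError(ValueError) earlier in this diff) can catch uniformly. A simplified sketch of the extract-then-parse step:

import json

def parse_json_markdown(text: str) -> dict:
    # Find the fenced json block and parse only its contents.
    start = text.find("```json")
    end = text.rfind("```")
    if start == -1 or end <= start:
        raise ValueError("could not find json block in the output.")
    return json.loads(text[start + len("```json"):end].strip())

print(parse_json_markdown('```json\n{"category_id": "1"}\n```'))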

View File

@ -1,33 +0,0 @@
"""add retry_index field to node-execution model
Revision ID: 348cb0a93d53
Revises: cf8f4fc45278
Create Date: 2024-12-16 01:23:13.093432
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '348cb0a93d53'
down_revision = 'cf8f4fc45278'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
batch_op.add_column(sa.Column('retry_index', sa.Integer(), server_default=sa.text('0'), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
batch_op.drop_column('retry_index')
# ### end Alembic commands ###

View File

@ -0,0 +1,39 @@
"""remove unused tool_providers
Revision ID: 11b07f66c737
Revises: cf8f4fc45278
Create Date: 2024-12-19 17:46:25.780116
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '11b07f66c737'
down_revision = 'cf8f4fc45278'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('tool_providers')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tool_providers',
sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False),
sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False),
sa.Column('tool_name', sa.VARCHAR(length=40), autoincrement=False, nullable=False),
sa.Column('encrypted_credentials', sa.TEXT(), autoincrement=False, nullable=True),
sa.Column('is_enabled', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False),
sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False),
sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
)
# ### end Alembic commands ###

View File

@ -0,0 +1,37 @@
"""add retry_index field to node-execution model
Revision ID: e1944c35e15e
Revises: 11b07f66c737
Create Date: 2024-12-20 06:28:30.287197
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'e1944c35e15e'
down_revision = '11b07f66c737'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
# We don't need these fields anymore, but this file is already merged into the main branch,
# so we need to keep this file for the sake of history, and this change will be reverted in the next migration.
# with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
# batch_op.add_column(sa.Column('retry_index', sa.Integer(), server_default=sa.text('0'), nullable=True))
pass
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
# with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
# batch_op.drop_column('retry_index')
pass
# ### end Alembic commands ###

View File

@ -0,0 +1,34 @@
"""remove workflow_node_executions.retry_index if exists
Revision ID: d7999dfa4aae
Revises: e1944c35e15e
Create Date: 2024-12-23 11:54:15.344543
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy import inspect
# revision identifiers, used by Alembic.
revision = 'd7999dfa4aae'
down_revision = 'e1944c35e15e'
branch_labels = None
depends_on = None
def upgrade():
# Check if column exists before attempting to remove it
conn = op.get_bind()
inspector = inspect(conn)
has_column = 'retry_index' in [col['name'] for col in inspector.get_columns('workflow_node_executions')]
if has_column:
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
batch_op.drop_column('retry_index')
def downgrade():
# No downgrade needed as we don't want to restore the column
pass

View File

@ -1,53 +1,187 @@
from .account import Account, AccountIntegrate, InvitationCode, Tenant
from .dataset import Dataset, DatasetProcessRule, Document, DocumentSegment
from .account import (
Account,
AccountIntegrate,
AccountStatus,
InvitationCode,
Tenant,
TenantAccountJoin,
TenantAccountJoinRole,
TenantAccountRole,
TenantStatus,
)
from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from .dataset import (
AppDatasetJoin,
Dataset,
DatasetCollectionBinding,
DatasetKeywordTable,
DatasetPermission,
DatasetPermissionEnum,
DatasetProcessRule,
DatasetQuery,
Document,
DocumentSegment,
Embedding,
ExternalKnowledgeApis,
ExternalKnowledgeBindings,
TidbAuthBinding,
Whitelist,
)
from .engine import db
from .enums import CreatedByRole, UserFrom, WorkflowRunTriggeredFrom
from .model import (
ApiRequest,
ApiToken,
App,
AppAnnotationHitHistory,
AppAnnotationSetting,
AppMode,
AppModelConfig,
Conversation,
DatasetRetrieverResource,
DifySetup,
EndUser,
IconType,
InstalledApp,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
OperationLog,
RecommendedApp,
Site,
Tag,
TagBinding,
TraceAppConfig,
UploadFile,
)
from .source import DataSourceOauthBinding
from .tools import ToolFile
from .provider import (
LoadBalancingModelConfig,
Provider,
ProviderModel,
ProviderModelSetting,
ProviderOrder,
ProviderQuotaType,
ProviderType,
TenantDefaultModel,
TenantPreferredModelProvider,
)
from .source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
from .task import CeleryTask, CeleryTaskSet
from .tools import (
ApiToolProvider,
BuiltinToolProvider,
PublishedAppTool,
ToolConversationVariables,
ToolFile,
ToolLabelBinding,
ToolModelInvoke,
WorkflowToolProvider,
)
from .web import PinnedConversation, SavedMessage
from .workflow import (
ConversationVariable,
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
WorkflowRunStatus,
WorkflowType,
)
__all__ = [
"APIBasedExtension",
"APIBasedExtensionPoint",
"Account",
"AccountIntegrate",
"AccountStatus",
"ApiRequest",
"ApiToken",
"ApiToolProvider", # Added
"App",
"AppAnnotationHitHistory",
"AppAnnotationSetting",
"AppDatasetJoin",
"AppMode",
"AppModelConfig",
"BuiltinToolProvider", # Added
"CeleryTask",
"CeleryTaskSet",
"Conversation",
"ConversationVariable",
"CreatedByRole",
"DataSourceApiKeyAuthBinding",
"DataSourceOauthBinding",
"Dataset",
"DatasetCollectionBinding",
"DatasetKeywordTable",
"DatasetPermission",
"DatasetPermissionEnum",
"DatasetProcessRule",
"DatasetQuery",
"DatasetRetrieverResource",
"DifySetup",
"Document",
"DocumentSegment",
"Embedding",
"EndUser",
"ExternalKnowledgeApis",
"ExternalKnowledgeBindings",
"IconType",
"InstalledApp",
"InvitationCode",
"LoadBalancingModelConfig",
"Message",
"MessageAgentThought",
"MessageAnnotation",
"MessageChain",
"MessageFeedback",
"MessageFile",
"OperationLog",
"PinnedConversation",
"Provider",
"ProviderModel",
"ProviderModelSetting",
"ProviderOrder",
"ProviderQuotaType",
"ProviderType",
"PublishedAppTool",
"RecommendedApp",
"SavedMessage",
"Site",
"Tag",
"TagBinding",
"Tenant",
"TenantAccountJoin",
"TenantAccountJoinRole",
"TenantAccountRole",
"TenantDefaultModel",
"TenantPreferredModelProvider",
"TenantStatus",
"TidbAuthBinding",
"ToolConversationVariables",
"ToolFile",
"ToolLabelBinding",
"ToolModelInvoke",
"TraceAppConfig",
"UploadFile",
"UserFrom",
"Whitelist",
"Workflow",
"WorkflowAppLog",
"WorkflowAppLogCreatedFrom",
"WorkflowNodeExecution",
"WorkflowNodeExecutionStatus",
"WorkflowNodeExecutionTriggeredFrom",
"WorkflowRun",
"WorkflowRunStatus",
"WorkflowRunTriggeredFrom",
"WorkflowToolProvider",
"WorkflowType",
"db",
]

View File

@ -2,9 +2,9 @@ import enum
import json
from flask_login import UserMixin
from sqlalchemy import func
from extensions.ext_database import db
from .engine import db
from .types import StringUUID
@ -31,11 +31,11 @@ class Account(UserMixin, db.Model):
timezone = db.Column(db.String(255))
last_login_at = db.Column(db.DateTime)
last_login_ip = db.Column(db.String(255))
last_active_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
last_active_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
status = db.Column(db.String(16), nullable=False, server_default=db.text("'active'::character varying"))
initialized_at = db.Column(db.DateTime)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def is_password_set(self):
@ -99,11 +99,6 @@ class Account(UserMixin, db.Model):
return db.session.query(Account).filter(Account.id == account_integrate.account_id).one_or_none()
return None
def get_integrates(self) -> list[db.Model]:
ai = db.Model
return db.session.query(ai).filter(ai.account_id == self.id).all()
# check current_user.current_tenant.current_role in ['admin', 'owner']
@property
def is_admin_or_owner(self):
return TenantAccountRole.is_privileged_role(self._current_tenant.current_role)
@ -188,8 +183,8 @@ class Tenant(db.Model):
plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying"))
status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
custom_config = db.Column(db.Text)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
def get_accounts(self) -> list[Account]:
return (
@ -229,8 +224,8 @@ class TenantAccountJoin(db.Model):
current = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
role = db.Column(db.String(16), nullable=False, server_default="normal")
invited_by = db.Column(StringUUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class AccountIntegrate(db.Model):
@ -246,8 +241,8 @@ class AccountIntegrate(db.Model):
provider = db.Column(db.String(16), nullable=False)
open_id = db.Column(db.String(255), nullable=False)
encrypted_token = db.Column(db.String(255), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class InvitationCode(db.Model):
@ -266,4 +261,4 @@ class InvitationCode(db.Model):
used_by_tenant_id = db.Column(StringUUID)
used_by_account_id = db.Column(StringUUID)
deprecated_at = db.Column(db.DateTime)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

View File

@ -1,7 +1,8 @@
import enum
from extensions.ext_database import db
from sqlalchemy import func
from .engine import db
from .types import StringUUID
@ -24,4 +25,4 @@ class APIBasedExtension(db.Model):
name = db.Column(db.String(255), nullable=False)
api_endpoint = db.Column(db.String(255), nullable=False)
api_key = db.Column(db.Text, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

View File

@ -15,10 +15,10 @@ from sqlalchemy.dialects.postgresql import JSONB
from configs import dify_config
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from extensions.ext_storage import storage
from .account import Account
from .engine import db
from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID
@ -50,9 +50,9 @@ class Dataset(db.Model):
indexing_technique = db.Column(db.String(255), nullable=True)
index_struct = db.Column(db.Text, nullable=True)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
embedding_model = db.Column(db.String(255), nullable=True)
embedding_model_provider = db.Column(db.String(255), nullable=True)
collection_binding_id = db.Column(StringUUID, nullable=True)
@ -212,7 +212,7 @@ class DatasetProcessRule(db.Model):
mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
rules = db.Column(db.Text, nullable=True)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
MODES = ["automatic", "custom"]
PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
@ -264,7 +264,7 @@ class Document(db.Model):
created_from = db.Column(db.String(255), nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_api_request_id = db.Column(StringUUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
# start processing
processing_started_at = db.Column(db.DateTime, nullable=True)
@ -303,7 +303,7 @@ class Document(db.Model):
archived_reason = db.Column(db.String(255), nullable=True)
archived_by = db.Column(StringUUID, nullable=True)
archived_at = db.Column(db.DateTime, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
doc_type = db.Column(db.String(40), nullable=True)
doc_metadata = db.Column(db.JSON, nullable=True)
doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
@ -527,9 +527,9 @@ class DocumentSegment(db.Model):
disabled_by = db.Column(StringUUID, nullable=True)
status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
indexing_at = db.Column(db.DateTime, nullable=True)
completed_at = db.Column(db.DateTime, nullable=True)
error = db.Column(db.Text, nullable=True)
@ -697,7 +697,7 @@ class Embedding(db.Model):
)
hash = db.Column(db.String(64), nullable=False)
embedding = db.Column(db.LargeBinary, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying"))
def set_embedding(self, embedding_data: list[float]):
@ -719,7 +719,7 @@ class DatasetCollectionBinding(db.Model):
model_name = db.Column(db.String(255), nullable=False)
type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
collection_name = db.Column(db.String(64), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TidbAuthBinding(db.Model):
@ -739,7 +739,7 @@ class TidbAuthBinding(db.Model):
status = db.Column(db.String(255), nullable=False, server_default=db.text("CREATING"))
account = db.Column(db.String(255), nullable=False)
password = db.Column(db.String(255), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class Whitelist(db.Model):
@ -751,7 +751,7 @@ class Whitelist(db.Model):
id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=True)
category = db.Column(db.String(255), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class DatasetPermission(db.Model):
@ -768,7 +768,7 @@ class DatasetPermission(db.Model):
account_id = db.Column(StringUUID, nullable=False)
tenant_id = db.Column(StringUUID, nullable=False)
has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class ExternalKnowledgeApis(db.Model):
@ -785,9 +785,9 @@ class ExternalKnowledgeApis(db.Model):
tenant_id = db.Column(StringUUID, nullable=False)
settings = db.Column(db.Text, nullable=True)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
def to_dict(self):
return {
@ -840,6 +840,6 @@ class ExternalKnowledgeBindings(db.Model):
dataset_id = db.Column(StringUUID, nullable=False)
external_knowledge_id = db.Column(db.Text, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
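
Throughout these model files the raw db.text("CURRENT_TIMESTAMP(0)") defaults are replaced with func.current_timestamp(), which SQLAlchemy renders per dialect instead of hardcoding PostgreSQL syntax. A minimal sketch of the new column style (the Example model is illustrative):

from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import func

db = SQLAlchemy()

class Example(db.Model):
    __tablename__ = "example"
    id = db.Column(db.Integer, primary_key=True)
    # Database-side default, emitted as CURRENT_TIMESTAMP by the dialect.
    created_at = db.Column(db.DateTime, nullable=False,
                           server_default=func.current_timestamp())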

Some files were not shown because too many files have changed in this diff.