Merge branch 'main' into fix/chore-fix

This commit is contained in:
Yeuoly
2024-11-07 17:06:29 +08:00
46 changed files with 1342 additions and 1425 deletions

View File

@ -12,6 +12,10 @@ SSRF_PROXY_ALL_URL = os.getenv("SSRF_PROXY_ALL_URL", "")
SSRF_PROXY_HTTP_URL = os.getenv("SSRF_PROXY_HTTP_URL", "")
SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "")
SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3"))
SSRF_DEFAULT_TIME_OUT = float(os.getenv("SSRF_DEFAULT_TIME_OUT", "5"))
SSRF_DEFAULT_CONNECT_TIME_OUT = float(os.getenv("SSRF_DEFAULT_CONNECT_TIME_OUT", "5"))
SSRF_DEFAULT_READ_TIME_OUT = float(os.getenv("SSRF_DEFAULT_READ_TIME_OUT", "5"))
SSRF_DEFAULT_WRITE_TIME_OUT = float(os.getenv("SSRF_DEFAULT_WRITE_TIME_OUT", "5"))
proxy_mounts = (
{
@ -32,6 +36,14 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if "follow_redirects" not in kwargs:
kwargs["follow_redirects"] = allow_redirects
if "timeout" not in kwargs:
kwargs["timeout"] = httpx.Timeout(
SSRF_DEFAULT_TIME_OUT,
connect=SSRF_DEFAULT_CONNECT_TIME_OUT,
read=SSRF_DEFAULT_READ_TIME_OUT,
write=SSRF_DEFAULT_WRITE_TIME_OUT,
)
retries = 0
while retries <= max_retries:
try:

View File

@ -1,61 +0,0 @@
model: anthropic.claude-3-5-haiku-20241022-v1:0
label:
en_US: Claude 3.5 Haiku
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 8192
min: 1
max: 8192
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
# docs: https://docs.anthropic.com/claude/docs/system-prompts
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.001'
output: '0.005'
unit: '0.001'
currency: USD

View File

@ -1,60 +0,0 @@
model: anthropic.claude-3-5-sonnet-20241022-v2:0
label:
en_US: Claude 3.5 Sonnet V2
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.003'
output: '0.015'
unit: '0.001'
currency: USD

View File

@ -1,60 +0,0 @@
model: eu.anthropic.claude-3-5-sonnet-20241022-v2:0
label:
en_US: Claude 3.5 Sonnet V2(EU.Cross Region Inference)
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.003'
output: '0.015'
unit: '0.001'
currency: USD

View File

@ -1,61 +0,0 @@
model: us.anthropic.claude-3-5-haiku-20241022-v1:0
label:
en_US: Claude 3.5 Haiku(US.Cross Region Inference)
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
# docs: https://docs.anthropic.com/claude/docs/system-prompts
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.001'
output: '0.005'
unit: '0.001'
currency: USD

View File

@ -1,60 +0,0 @@
model: us.anthropic.claude-3-5-sonnet-20241022-v2:0
label:
en_US: Claude 3.5 Sonnet V2(US.Cross Region Inference)
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.003'
output: '0.015'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,22 @@
class IterationNodeError(ValueError):
"""Base class for iteration node errors."""
class IteratorVariableNotFoundError(IterationNodeError):
"""Raised when the iterator variable is not found."""
class InvalidIteratorValueError(IterationNodeError):
"""Raised when the iterator value is invalid."""
class StartNodeIdNotFoundError(IterationNodeError):
"""Raised when the start node ID is not found."""
class IterationGraphNotFoundError(IterationNodeError):
"""Raised when the iteration graph is not found."""
class IterationIndexNotFoundError(IterationNodeError):
"""Raised when the iteration index is not found."""

View File

@ -38,6 +38,15 @@ from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from models.workflow import WorkflowNodeExecutionStatus
from .exc import (
InvalidIteratorValueError,
IterationGraphNotFoundError,
IterationIndexNotFoundError,
IterationNodeError,
IteratorVariableNotFoundError,
StartNodeIdNotFoundError,
)
if TYPE_CHECKING:
from core.workflow.graph_engine.graph_engine import GraphEngine
logger = logging.getLogger(__name__)
@ -69,7 +78,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
if not iterator_list_segment:
raise ValueError(f"Iterator variable {self.node_data.iterator_selector} not found")
raise IteratorVariableNotFoundError(f"Iterator variable {self.node_data.iterator_selector} not found")
if len(iterator_list_segment.value) == 0:
yield RunCompletedEvent(
@ -83,14 +92,14 @@ class IterationNode(BaseNode[IterationNodeData]):
iterator_list_value = iterator_list_segment.to_object()
if not isinstance(iterator_list_value, list):
raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
inputs = {"iterator_selector": iterator_list_value}
graph_config = self.graph_config
if not self.node_data.start_node_id:
raise ValueError(f"field start_node_id in iteration {self.node_id} not found")
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found")
root_node_id = self.node_data.start_node_id
@ -98,7 +107,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
if not iteration_graph:
raise ValueError("iteration graph not found")
raise IterationGraphNotFoundError("iteration graph not found")
variable_pool = self.graph_runtime_state.variable_pool
@ -222,9 +231,9 @@ class IterationNode(BaseNode[IterationNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": jsonable_encoder(outputs)}
)
)
except Exception as e:
except IterationNodeError as e:
# iteration run failed
logger.exception("Iteration run failed")
logger.warning("Iteration run failed")
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
@ -272,7 +281,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id)
if not iteration_graph:
raise ValueError("iteration graph not found")
raise IterationGraphNotFoundError("iteration graph not found")
for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items():
if sub_node_config.get("data", {}).get("iteration_id") != node_id:
@ -357,7 +366,7 @@ class IterationNode(BaseNode[IterationNodeData]):
next_index = int(current_index) + 1
if current_index is None:
raise ValueError(f"iteration {self.node_id} current index not found")
raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
@ -484,8 +493,8 @@ class IterationNode(BaseNode[IterationNodeData]):
pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None,
)
except Exception as e:
logger.exception(f"Iteration run failed:{str(e)}")
except IterationNodeError as e:
logger.warning(f"Iteration run failed:{str(e)}")
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,

View File

@ -0,0 +1,18 @@
class KnowledgeRetrievalNodeError(ValueError):
"""Base class for KnowledgeRetrievalNode errors."""
class ModelNotExistError(KnowledgeRetrievalNodeError):
"""Raised when the model does not exist."""
class ModelCredentialsNotInitializedError(KnowledgeRetrievalNodeError):
"""Raised when the model credentials are not initialized."""
class ModelNotSupportedError(KnowledgeRetrievalNodeError):
"""Raised when the model is not supported."""
class ModelQuotaExceededError(KnowledgeRetrievalNodeError):
"""Raised when the model provider quota is exceeded."""

View File

@ -8,7 +8,6 @@ from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@ -18,11 +17,19 @@ from core.variables import StringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from models.workflow import WorkflowNodeExecutionStatus
from .entities import KnowledgeRetrievalNodeData
from .exc import (
KnowledgeRetrievalNodeError,
ModelCredentialsNotInitializedError,
ModelNotExistError,
ModelNotSupportedError,
ModelQuotaExceededError,
)
logger = logging.getLogger(__name__)
default_retrieval_model = {
@ -61,8 +68,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
)
except Exception as e:
logger.exception("Error when running knowledge retrieval node")
except KnowledgeRetrievalNodeError as e:
logger.warning("Error when running knowledge retrieval node")
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
@ -295,14 +302,14 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
)
if provider_model is None:
raise ValueError(f"Model {model_name} not exist.")
raise ModelNotExistError(f"Model {model_name} not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
raise ModelCredentialsNotInitializedError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
raise ModelNotSupportedError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.")
# model config
completion_params = node_data.single_retrieval_config.model.completion_params
@ -314,12 +321,12 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
# get model mode
model_mode = node_data.single_retrieval_config.model.mode
if not model_mode:
raise ValueError("LLM mode is required.")
raise ModelNotExistError("LLM mode is required.")
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
if not model_schema:
raise ValueError(f"Model {model_name} not exist.")
raise ModelNotExistError(f"Model {model_name} not exist.")
return model_instance, ModelConfigWithCredentialsEntity(
provider=provider_name,

View File

@ -0,0 +1,6 @@
class QuestionClassifierNodeError(ValueError):
"""Base class for QuestionClassifierNode errors."""
class InvalidModelTypeError(QuestionClassifierNodeError):
"""Raised when the model is not a Large Language Model."""

View File

@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.llm_generator.output_parser.errors import OutputParserError
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
@ -24,6 +25,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown
from models.workflow import WorkflowNodeExecutionStatus
from .entities import QuestionClassifierNodeData
from .exc import InvalidModelTypeError
from .template_prompts import (
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2,
@ -124,7 +126,7 @@ class QuestionClassifierNode(LLMNode):
category_name = classes_map[category_id_result]
category_id = category_id_result
except Exception:
except OutputParserError:
logging.error(f"Failed to parse result text: {result_text}")
try:
process_data = {
@ -309,4 +311,4 @@ class QuestionClassifierNode(LLMNode):
)
else:
raise ValueError(f"Model mode {model_mode} not support.")
raise InvalidModelTypeError(f"Model mode {model_mode} not support.")

View File

@ -0,0 +1,16 @@
class ToolNodeError(ValueError):
"""Base exception for tool node errors."""
pass
class ToolParameterError(ToolNodeError):
"""Exception raised for errors in tool parameters."""
pass
class ToolFileError(ToolNodeError):
"""Exception raised for errors related to tool files."""
pass

View File

@ -6,7 +6,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.models import File, FileTransferMethod, FileType
from core.file import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
@ -19,12 +19,18 @@ from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.tools import ToolFile
from models.workflow import WorkflowNodeExecutionStatus
from .entities import ToolNodeData
from .exc import (
ToolFileError,
ToolNodeError,
ToolParameterError,
)
class ToolNode(BaseNode[ToolNodeData]):
"""
@ -49,7 +55,7 @@ class ToolNode(BaseNode[ToolNodeData]):
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
)
except Exception as e:
except ToolNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -58,7 +64,6 @@ class ToolNode(BaseNode[ToolNodeData]):
error=f"Failed to get tool runtime: {str(e)}",
)
)
return
# get parameters
tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
@ -85,7 +90,7 @@ class ToolNode(BaseNode[ToolNodeData]):
app_id=self.app_id,
# TODO: conversation id and message id
)
except Exception as e:
except ToolNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -94,7 +99,6 @@ class ToolNode(BaseNode[ToolNodeData]):
error=f"Failed to invoke tool: {str(e)}",
)
)
return
# convert tool messages
yield from self._transform_message(message_stream, tool_info, parameters_for_log)
@ -131,14 +135,13 @@ class ToolNode(BaseNode[ToolNodeData]):
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
if variable is None:
raise ValueError(f"variable {tool_input.value} not exists")
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
parameter_value = variable.value
elif tool_input.type in {"mixed", "constant"}:
segment_group = variable_pool.convert_template(str(tool_input.value))
parameter_value = segment_group.log if for_log else segment_group.text
else:
raise ValueError(f"unknown tool input type '{tool_input.type}'")
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
result[parameter_name] = parameter_value
return result
@ -187,7 +190,7 @@ class ToolNode(BaseNode[ToolNodeData]):
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists")
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
files.append(
File(
@ -212,8 +215,7 @@ class ToolNode(BaseNode[ToolNodeData]):
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists")
raise ToolFileError(f"tool file {tool_file_id} not exists")
files.append(
File(
tenant_id=self.tenant_id,