Merge branch 'main' into feat/support-extractor-tools

jyong
2024-11-05 16:31:19 +08:00
117 changed files with 3103 additions and 935 deletions

View File

@ -1,8 +1,7 @@
from collections.abc import Mapping
from typing import Any
from core.file.models import FileExtraConfig
from models import FileUploadConfig
from core.file import FileExtraConfig
class FileUploadConfigManager:
@ -43,6 +42,6 @@ class FileUploadConfigManager:
if not config.get("file_upload"):
config["file_upload"] = {}
else:
FileUploadConfig.model_validate(config["file_upload"])
FileExtraConfig.model_validate(config["file_upload"])
return config, ["file_upload"]

View File

@ -20,6 +20,7 @@ from core.app.entities.queue_entities import (
QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@ -314,7 +315,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
response = self._workflow_node_finish_to_stream_response(

View File

@ -22,7 +22,10 @@ class BaseAppGenerator:
user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values
variables = app_config.variables
user_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
user_inputs = {
var.variable: self._validate_inputs(value=user_inputs.get(var.variable), variable_entity=var)
for var in variables
}
user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
# Convert files in inputs to File
entity_dictionary = {item.variable: item for item in app_config.variables}
@ -74,57 +77,66 @@ class BaseAppGenerator:
return user_inputs
def _validate_input(self, *, inputs: Mapping[str, Any], var: "VariableEntity"):
user_input_value = inputs.get(var.variable)
def _validate_inputs(
self,
*,
variable_entity: "VariableEntity",
value: Any,
):
if value is None:
if variable_entity.required:
raise ValueError(f"{variable_entity.variable} is required in input form")
return value
if not user_input_value:
if var.required:
raise ValueError(f"{var.variable} is required in input form")
else:
return None
if var.type in {
if variable_entity.type in {
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
} and not isinstance(user_input_value, str):
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
} and not isinstance(value, str):
raise ValueError(
f"(type '{variable_entity.type}') {variable_entity.variable} in input form must be a string"
)
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
if variable_entity.type == VariableEntityType.NUMBER and isinstance(value, str):
# may raise ValueError if value is not a valid number
try:
if "." in user_input_value:
return float(user_input_value)
if "." in value:
return float(value)
else:
return int(user_input_value)
return int(value)
except ValueError:
raise ValueError(f"{var.variable} in input form must be a valid number")
raise ValueError(f"{variable_entity.variable} in input form must be a valid number")
match var.type:
match variable_entity.type:
case VariableEntityType.SELECT:
if user_input_value not in var.options:
raise ValueError(f"{var.variable} in input form must be one of the following: {var.options}")
if value not in variable_entity.options:
raise ValueError(
f"{variable_entity.variable} in input form must be one of the following: "
f"{variable_entity.options}"
)
case VariableEntityType.TEXT_INPUT | VariableEntityType.PARAGRAPH:
if var.max_length and len(user_input_value) > var.max_length:
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
if variable_entity.max_length and len(value) > variable_entity.max_length:
raise ValueError(
f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} "
"characters"
)
case VariableEntityType.FILE:
if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File):
raise ValueError(f"{var.variable} in input form must be a file")
if not isinstance(value, dict) and not isinstance(value, File):
raise ValueError(f"{variable_entity.variable} in input form must be a file")
case VariableEntityType.FILE_LIST:
# if number of files exceeds the limit, raise ValueError
if not (
isinstance(user_input_value, list)
and (
all(isinstance(item, dict) for item in user_input_value)
or all(isinstance(item, File) for item in user_input_value)
)
isinstance(value, list)
and (all(isinstance(item, dict) for item in value) or all(isinstance(item, File) for item in value))
):
raise ValueError(f"{var.variable} in input form must be a list of files")
raise ValueError(f"{variable_entity.variable} in input form must be a list of files")
if var.max_length and len(user_input_value) > var.max_length:
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} files")
if variable_entity.max_length and len(value) > variable_entity.max_length:
raise ValueError(
f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files"
)
return user_input_value
return value
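
A minimal standalone sketch of the NUMBER coercion branch above (the helper name `coerce_number` is illustrative; the real `_validate_inputs` also covers the SELECT, TEXT_INPUT, PARAGRAPH, FILE, and FILE_LIST branches shown here):

def coerce_number(name: str, value):
    # mirrors _validate_inputs: numeric strings become int or float
    if isinstance(value, str):
        try:
            return float(value) if "." in value else int(value)
        except ValueError:
            raise ValueError(f"{name} in input form must be a valid number")
    return value

assert coerce_number("age", "42") == 42
assert coerce_number("ratio", "3.14") == 3.14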
def _sanitize_value(self, value: Any) -> Any:
if isinstance(value, str):

View File

@ -16,6 +16,7 @@ from core.app.entities.queue_entities import (
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@ -275,7 +276,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
response = self._workflow_node_finish_to_stream_response(

View File

@ -9,6 +9,7 @@ from core.app.entities.queue_entities import (
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@ -30,6 +31,7 @@ from core.workflow.graph_engine.entities.event import (
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeInIterationFailedEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
@ -193,6 +195,7 @@ class WorkflowBasedAppRunner(AppRunner):
node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
parallel_mode_run_id=event.parallel_mode_run_id,
)
)
elif isinstance(event, NodeRunSucceededEvent):
@ -246,9 +249,40 @@ class WorkflowBasedAppRunner(AppRunner):
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
QueueNodeInIterationFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
error=event.error,
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
self._publish_event(
QueueTextChunkEvent(
@ -326,6 +360,7 @@ class WorkflowBasedAppRunner(AppRunner):
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_iteration_output,
parallel_mode_run_id=event.parallel_mode_run_id,
)
)
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):

View File

@ -107,7 +107,8 @@ class QueueIterationNextEvent(AppQueueEvent):
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""
node_run_index: int
output: Optional[Any] = None # output for the current iteration
@ -273,6 +274,8 @@ class QueueNodeStartedEvent(AppQueueEvent):
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""
class QueueNodeSucceededEvent(AppQueueEvent):
@ -306,6 +309,37 @@ class QueueNodeSucceededEvent(AppQueueEvent):
error: Optional[str] = None
class QueueNodeInIterationFailedEvent(AppQueueEvent):
"""
QueueNodeInIterationFailedEvent entity
"""
event: QueueEvent = QueueEvent.NODE_FAILED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
error: str
class QueueNodeFailedEvent(AppQueueEvent):
"""
QueueNodeFailedEvent entity
@ -332,6 +366,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
error: str

View File

@ -244,6 +244,7 @@ class NodeStartStreamResponse(StreamResponse):
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
parallel_run_id: Optional[str] = None
event: StreamEvent = StreamEvent.NODE_STARTED
workflow_run_id: str
@ -432,6 +433,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
extras: dict = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parallel_mode_run_id: Optional[str] = None
event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str

View File

@ -12,6 +12,7 @@ from core.app.entities.queue_entities import (
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@ -35,6 +36,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
@ -251,6 +253,12 @@ class WorkflowCycleManage:
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.execution_metadata = json.dumps(
{
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
}
)
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
session.add(workflow_node_execution)
@ -305,7 +313,9 @@ class WorkflowCycleManage:
return workflow_node_execution
def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution:
def _handle_workflow_node_execution_failed(
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param event: queue node failed event
@ -318,16 +328,19 @@ class WorkflowCycleManage:
outputs = WorkflowEntry.handle_special_values(event.outputs)
finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
)
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
{
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
WorkflowNodeExecution.error: event.error,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
WorkflowNodeExecution.finished_at: finished_at,
WorkflowNodeExecution.elapsed_time: elapsed_time,
WorkflowNodeExecution.execution_metadata: execution_metadata,
}
)
@ -342,6 +355,7 @@ class WorkflowCycleManage:
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata
self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)
@ -448,6 +462,7 @@ class WorkflowCycleManage:
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
parallel_run_id=event.parallel_mode_run_id,
),
)
@ -464,7 +479,7 @@ class WorkflowCycleManage:
def _workflow_node_finish_to_stream_response(
self,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
@ -608,6 +623,7 @@ class WorkflowCycleManage:
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
),
)
@ -633,7 +649,9 @@ class WorkflowCycleManage:
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
status=WorkflowNodeExecutionStatus.SUCCEEDED,
status=WorkflowNodeExecutionStatus.SUCCEEDED
if event.error is None
else WorkflowNodeExecutionStatus.FAILED,
error=None,
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,

View File

@ -598,7 +598,7 @@ class IndexingRunner:
rules = DatasetProcessRule.AUTOMATIC_RULES
else:
rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
document_text = CleanProcessor.clean(text, rules)
document_text = CleanProcessor.clean(text, {"rules": rules})
return document_text
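
The fix wraps the parsed rules in a `{"rules": ...}` payload before calling the cleaner. A sketch of the assumed shape, based on the dataset process-rule format (the `pre_processing_rules` entries are illustrative and may differ from the actual schema):

process_rule = {
    "rules": {
        "pre_processing_rules": [
            {"id": "remove_extra_spaces", "enabled": True},
            {"id": "remove_urls_emails", "enabled": False},
        ],
    }
}
document_text = CleanProcessor.clean(text, process_rule)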

View File

@ -1,3 +1,4 @@
- claude-3-5-haiku-20241022
- claude-3-5-sonnet-20241022
- claude-3-5-sonnet-20240620
- claude-3-haiku-20240307

View File

@ -0,0 +1,39 @@
model: claude-3-5-haiku-20241022
label:
en_US: claude-3-5-haiku-20241022
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
pricing:
input: '1.00'
output: '5.00'
unit: '0.000001'
currency: USD
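
Assuming the runtime computes cost as tokens × price × unit (an assumption; the pricing engine itself is not part of this diff), the entry above works out to $1.00 per million input tokens and $5.00 per million output tokens:

cost_per_1m_input = 1_000_000 * 1.00 * 0.000001   # = $1.00
cost_per_1m_output = 1_000_000 * 5.00 * 0.000001  # = $5.00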

View File

@ -0,0 +1,61 @@
model: anthropic.claude-3-5-haiku-20241022-v1:0
label:
en_US: Claude 3.5 Haiku
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 8192
min: 1
max: 8192
help:
zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
# docs: https://docs.anthropic.com/claude/docs/system-prompts
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或 top_p,但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.001'
output: '0.005'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,61 @@
model: us.anthropic.claude-3-5-haiku-20241022-v1:0
label:
en_US: Claude 3.5 Haiku (US Cross-Region Inference)
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
# docs: https://docs.anthropic.com/claude/docs/system-prompts
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或 top_p,但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.001'
output: '0.005'
unit: '0.001'
currency: USD

View File

@ -1,6 +1,7 @@
import logging
from core.model_runtime.entities.model_entities import ModelType
import requests
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
@ -16,8 +17,18 @@ class GiteeAIProvider(ModelProvider):
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(model="Qwen2-7B-Instruct", credentials=credentials)
api_key = credentials.get("api_key")
if not api_key:
raise CredentialsValidateFailedError("Credentials validation failed: api_key not given")
# send a get request to validate the credentials
headers = {"Authorization": f"Bearer {api_key}"}
response = requests.get("https://ai.gitee.com/api/base/account/me", headers=headers, timeout=(10, 300))
if response.status_code != 200:
raise CredentialsValidateFailedError(
f"Credentials validation failed with status code {response.status_code}"
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" fill="currentColor" viewBox="0 0 24 24" aria-hidden="true" class="" focusable="false" style="fill:currentColor;height:28px;width:28px"><path d="m3.005 8.858 8.783 12.544h3.904L6.908 8.858zM6.905 15.825 3 21.402h3.907l1.951-2.788zM16.585 2l-6.75 9.64 1.953 2.79L20.492 2zM17.292 7.965v13.437h3.2V3.395z"></path></svg>


View File

@ -0,0 +1,63 @@
model: grok-beta
label:
en_US: Grok beta
model_type: llm
features:
- multi-tool-call
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
label:
en_US: "Temperature"
zh_Hans: "采样温度"
type: float
default: 0.7
min: 0.0
max: 2.0
precision: 1
required: true
help:
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_p
label:
en_US: "Top P"
zh_Hans: "Top P"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: frequency_penalty
use_template: frequency_penalty
label:
en_US: "Frequency Penalty"
zh_Hans: "频率惩罚"
type: float
default: 0
min: 0
max: 2.0
precision: 1
required: false
help:
en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim."
zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。"
- name: user
use_template: text
label:
en_US: "User"
zh_Hans: "用户"
type: string
required: false
help:
en_US: "Used to track and differentiate conversation requests from different users."
zh_Hans: "用于追踪和区分不同用户的对话请求。"

View File

@ -0,0 +1,37 @@
from collections.abc import Generator
from typing import Optional, Union
from yarl import URL
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
@staticmethod
def _add_custom_parameters(credentials) -> None:
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"])) or "https://api.x.ai/v1"
credentials["mode"] = LLMMode.CHAT.value
credentials["function_calling_type"] = "tool_call"

View File

@ -0,0 +1,25 @@
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class XAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(model="grok-beta", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

View File

@ -0,0 +1,38 @@
provider: x
label:
en_US: xAI
description:
en_US: xAI is a company working on building artificial intelligence to accelerate human scientific discovery. We are guided by our mission to advance our collective understanding of the universe.
icon_small:
en_US: x-ai-logo.svg
icon_large:
en_US: x-ai-logo.svg
help:
title:
en_US: Get your token from xAI
zh_Hans: 从 xAI 获取 token
url:
en_US: https://x.ai/api
supported_model_types:
- llm
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: endpoint_url
label:
en_US: API Base
type: text-input
required: false
default: https://api.x.ai/v1
placeholder:
zh_Hans: 在此输入您的 API Base
en_US: Enter your API Base

View File

@ -14,6 +14,7 @@ import requests
from docx import Document as DocxDocument
from configs import dify_config
from core.helper import ssrf_proxy
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
@ -86,7 +87,7 @@ class WordExtractor(BaseExtractor):
image_count += 1
if rel.is_external:
url = rel.reltype
response = requests.get(url, stream=True)
response = ssrf_proxy.get(url, stream=True)
if response.status_code == 200:
image_ext = mimetypes.guess_extension(response.headers["Content-Type"])
file_uuid = str(uuid.uuid4())
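
The swap from `requests.get` to `ssrf_proxy.get` matters because the external-relationship URL comes from the uploaded document itself, so a crafted .docx could otherwise point the server at internal addresses. A sketch of the guarded pattern (assuming `ssrf_proxy.get` mirrors the `requests.get` call shape, as the drop-in replacement above suggests):

# document-controlled URL: never fetch directly; route through the egress proxy
response = ssrf_proxy.get(url, stream=True)
if response.status_code == 200:
    ...  # persist the image only after a successful, filtered fetch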

View File

@ -4,7 +4,7 @@ from hmac import new as hmac_new
from json import loads as json_loads
from threading import Lock
from time import sleep, time
from typing import Any, Optional
from typing import Any
from httpx import get, post
from requests import get as requests_get
@ -15,27 +15,27 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter,
from core.tools.tool.builtin_tool import BuiltinTool
class AIPPTGenerateTool(BuiltinTool):
class AIPPTGenerateToolAdapter:
"""
A tool for generating a ppt
"""
_api_base_url = URL("https://co.aippt.cn/api")
_api_token_cache = {}
_api_token_cache_lock: Optional[Lock] = None
_style_cache = {}
_style_cache_lock: Optional[Lock] = None
_api_token_cache_lock = Lock()
_style_cache_lock = Lock()
_task = {}
_task_type_map = {
"auto": 1,
"markdown": 7,
}
_tool: BuiltinTool
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self._api_token_cache_lock = Lock()
self._style_cache_lock = Lock()
def __init__(self, tool: BuiltinTool | None = None):
self._tool = tool
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
"""
@ -51,11 +51,11 @@ class AIPPTGenerateTool(BuiltinTool):
"""
title = tool_parameters.get("title", "")
if not title:
return self.create_text_message("Please provide a title for the ppt")
return self._tool.create_text_message("Please provide a title for the ppt")
model = tool_parameters.get("model", "aippt")
if not model:
return self.create_text_message("Please provide a model for the ppt")
return self._tool.create_text_message("Please provide a model for the ppt")
outline = tool_parameters.get("outline", "")
@ -68,8 +68,8 @@ class AIPPTGenerateTool(BuiltinTool):
)
# get suit
color = tool_parameters.get("color")
style = tool_parameters.get("style")
color: str = tool_parameters.get("color")
style: str = tool_parameters.get("style")
if color == "__default__":
color_id = ""
@ -93,9 +93,9 @@ class AIPPTGenerateTool(BuiltinTool):
# generate ppt
_, ppt_url = self._generate_ppt(task_id=task_id, suit_id=suit_id, user_id=user_id)
return self.create_text_message(
return self._tool.create_text_message(
"""the ppt has been created successfully,"""
f"""the ppt url is {ppt_url}"""
f"""the ppt url is {ppt_url} ."""
"""please give the ppt url to user and direct user to download it."""
)
@ -111,8 +111,8 @@ class AIPPTGenerateTool(BuiltinTool):
"""
headers = {
"x-channel": "",
"x-api-key": self.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
"x-api-key": self._tool.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
}
response = post(
str(self._api_base_url / "ai" / "chat" / "v2" / "task"),
@ -139,8 +139,8 @@ class AIPPTGenerateTool(BuiltinTool):
headers = {
"x-channel": "",
"x-api-key": self.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
"x-api-key": self._tool.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
}
response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60))
@ -183,8 +183,8 @@ class AIPPTGenerateTool(BuiltinTool):
headers = {
"x-channel": "",
"x-api-key": self.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
"x-api-key": self._tool.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
}
response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60))
@ -236,14 +236,15 @@ class AIPPTGenerateTool(BuiltinTool):
"""
headers = {
"x-channel": "",
"x-api-key": self.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
"x-api-key": self._tool.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
}
response = post(
str(self._api_base_url / "design" / "v2" / "save"),
headers=headers,
data={"task_id": task_id, "template_id": suit_id},
timeout=(10, 60),
)
if response.status_code != 200:
@ -350,11 +351,13 @@ class AIPPTGenerateTool(BuiltinTool):
return token
@classmethod
def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str:
@staticmethod
def _calculate_sign(access_key: str, secret_key: str, timestamp: int) -> str:
return b64encode(
hmac_new(
key=secret_key.encode("utf-8"), msg=f"GET@/api/grant/token/@{timestamp}".encode(), digestmod=sha1
key=secret_key.encode("utf-8"),
msg=f"GET@/api/grant/token/@{timestamp}".encode(),
digestmod=sha1,
).digest()
).decode("utf-8")
@ -419,10 +422,12 @@ class AIPPTGenerateTool(BuiltinTool):
:param credentials: the credentials
:return: Tuple[list[dict[id, color]], list[dict[id, style]]]
"""
if not self.runtime.credentials.get("aippt_access_key") or not self.runtime.credentials.get("aippt_secret_key"):
if not self._tool.runtime.credentials.get("aippt_access_key") or not self._tool.runtime.credentials.get(
"aippt_secret_key"
):
raise Exception("Please provide aippt credentials")
return self._get_styles(credentials=self.runtime.credentials, user_id=user_id)
return self._get_styles(credentials=self._tool.runtime.credentials, user_id=user_id)
def _get_suit(self, style_id: int, colour_id: int) -> int:
"""
@ -430,8 +435,8 @@ class AIPPTGenerateTool(BuiltinTool):
"""
headers = {
"x-channel": "",
"x-api-key": self.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id="__dify_system__"),
"x-api-key": self._tool.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id="__dify_system__"),
}
response = get(
str(self._api_base_url / "template_component" / "suit" / "search"),
@ -496,3 +501,18 @@ class AIPPTGenerateTool(BuiltinTool):
],
),
]
class AIPPTGenerateTool(BuiltinTool):
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
return AIPPTGenerateToolAdapter(self)._invoke(user_id, tool_parameters)
def get_runtime_parameters(self) -> list[ToolParameter]:
return AIPPTGenerateToolAdapter(self).get_runtime_parameters()
@classmethod
def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str:
return AIPPTGenerateToolAdapter()._get_api_token(credentials, user_id)
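
The refactor extracts all AIPPT API logic into `AIPPTGenerateToolAdapter`, which borrows a `BuiltinTool` only for runtime credentials and message construction, while the thin `AIPPTGenerateTool` subclass delegates to it. Moving the cache locks from per-instance to class level also means the class-level `_api_token_cache` and `_style_cache` dicts are now guarded by one shared lock across all instances. A delegation sketch (parameter values are illustrative):

tool: AIPPTGenerateTool = ...             # constructed by the tool framework
adapter = AIPPTGenerateToolAdapter(tool)  # reads tool.runtime.credentials
adapter._invoke(user_id="user-1", tool_parameters={"title": "Q3 review", "model": "aippt"})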

View File

@ -23,6 +23,7 @@ class NodeRunMetadataKey(str, Enum):
PARALLEL_START_NODE_ID = "parallel_start_node_id"
PARENT_PARALLEL_ID = "parent_parallel_id"
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
class NodeRunResult(BaseModel):

View File

@ -59,6 +59,7 @@ class BaseNodeEvent(GraphEngineEvent):
class NodeRunStartedEvent(BaseNodeEvent):
predecessor_node_id: Optional[str] = None
"""predecessor node id"""
parallel_mode_run_id: Optional[str] = None
@ -81,6 +82,10 @@ class NodeRunFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeInIterationFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
###########################################
# Parallel Branch Events
###########################################
@ -129,6 +134,8 @@ class BaseIterationEvent(GraphEngineEvent):
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""
class IterationRunStartedEvent(BaseIterationEvent):

View File

@ -4,6 +4,7 @@ import time
import uuid
from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy, deepcopy
from typing import Any, Optional
from flask import Flask, current_app
@ -724,6 +725,16 @@ class GraphEngine:
"""
return time.perf_counter() - start_at > max_execution_time
def create_copy(self):
"""
create a graph engine copy
:return: with a new variable pool instance of graph engine
"""
new_instance = copy(self)
new_instance.graph_runtime_state = copy(self.graph_runtime_state)
new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool)
return new_instance
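
A sketch of the isolation `create_copy` buys for parallel iteration: each copy shares the compiled graph but writes to its own variable pool, so concurrent iterations cannot clobber each other's variables (node IDs below are illustrative; `variable_pool.add` is the same API the iteration node uses):

copies = [graph_engine.create_copy() for _ in range(3)]
copies[0].graph_runtime_state.variable_pool.add(["iteration_node", "item"], "a")
# the original engine and the other copies are unaffected: graph_runtime_state
# is shallow-copied and only its variable_pool is deep-copied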
class GraphRunFailedError(Exception):
def __init__(self, error: str):

View File

@ -12,6 +12,12 @@ from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
from .exc import (
CodeNodeError,
DepthLimitError,
OutputValidationError,
)
class CodeNode(BaseNode[CodeNodeData]):
_node_data_cls = CodeNodeData
@ -60,7 +66,7 @@ class CodeNode(BaseNode[CodeNodeData]):
# Transform result
result = self._transform_result(result, self.node_data.outputs)
except (CodeExecutionError, ValueError) as e:
except (CodeExecutionError, CodeNodeError) as e:
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
@ -76,10 +82,10 @@ class CodeNode(BaseNode[CodeNodeData]):
if value is None:
return None
else:
raise ValueError(f"Output variable `{variable}` must be a string")
raise OutputValidationError(f"Output variable `{variable}` must be a string")
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
raise ValueError(
raise OutputValidationError(
f"The length of output variable `{variable}` must be"
f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters"
)
@ -97,10 +103,10 @@ class CodeNode(BaseNode[CodeNodeData]):
if value is None:
return None
else:
raise ValueError(f"Output variable `{variable}` must be a number")
raise OutputValidationError(f"Output variable `{variable}` must be a number")
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
raise ValueError(
raise OutputValidationError(
f"Output variable `{variable}` is out of range,"
f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}."
)
@ -108,7 +114,7 @@ class CodeNode(BaseNode[CodeNodeData]):
if isinstance(value, float):
# raise error if precision is too high
if len(str(value).split(".")[1]) > dify_config.CODE_MAX_PRECISION:
raise ValueError(
raise OutputValidationError(
f"Output variable `{variable}` has too high precision,"
f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
)
@ -125,7 +131,7 @@ class CodeNode(BaseNode[CodeNodeData]):
:return:
"""
if depth > dify_config.CODE_MAX_DEPTH:
raise ValueError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
transformed_result = {}
if output_schema is None:
@ -177,14 +183,14 @@ class CodeNode(BaseNode[CodeNodeData]):
depth=depth + 1,
)
else:
raise ValueError(
raise OutputValidationError(
f"Output {prefix}.{output_name} is not a valid array."
f" make sure all elements are of the same type."
)
elif output_value is None:
pass
else:
raise ValueError(f"Output {prefix}.{output_name} is not a valid type.")
raise OutputValidationError(f"Output {prefix}.{output_name} is not a valid type.")
return result
@ -192,7 +198,7 @@ class CodeNode(BaseNode[CodeNodeData]):
for output_name, output_config in output_schema.items():
dot = "." if prefix else ""
if output_name not in result:
raise ValueError(f"Output {prefix}{dot}{output_name} is missing.")
raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.")
if output_config.type == "object":
# check if output is object
@ -200,7 +206,7 @@ class CodeNode(BaseNode[CodeNodeData]):
if isinstance(result.get(output_name), type(None)):
transformed_result[output_name] = None
else:
raise ValueError(
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} is not an object,"
f" got {type(result.get(output_name))} instead."
)
@ -228,13 +234,13 @@ class CodeNode(BaseNode[CodeNodeData]):
if isinstance(result[output_name], type(None)):
transformed_result[output_name] = None
else:
raise ValueError(
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} is not an array,"
f" got {type(result.get(output_name))} instead."
)
else:
if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
raise ValueError(
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
)
@ -249,13 +255,13 @@ class CodeNode(BaseNode[CodeNodeData]):
if isinstance(result[output_name], type(None)):
transformed_result[output_name] = None
else:
raise ValueError(
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} is not an array,"
f" got {type(result.get(output_name))} instead."
)
else:
if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
raise ValueError(
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements."
)
@ -270,13 +276,13 @@ class CodeNode(BaseNode[CodeNodeData]):
if isinstance(result[output_name], type(None)):
transformed_result[output_name] = None
else:
raise ValueError(
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} is not an array,"
f" got {type(result.get(output_name))} instead."
)
else:
if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
raise ValueError(
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements."
)
@ -286,7 +292,7 @@ class CodeNode(BaseNode[CodeNodeData]):
if value is None:
pass
else:
raise ValueError(
raise OutputValidationError(
f"Output {prefix}{dot}{output_name}[{i}] is not an object,"
f" got {type(value)} instead at index {i}."
)
@ -303,13 +309,13 @@ class CodeNode(BaseNode[CodeNodeData]):
for i, value in enumerate(result[output_name])
]
else:
raise ValueError(f"Output type {output_config.type} is not supported.")
raise OutputValidationError(f"Output type {output_config.type} is not supported.")
parameters_validated[output_name] = True
# check if all output parameters are validated
if len(parameters_validated) != len(result):
raise ValueError("Not all output parameters are validated.")
raise CodeNodeError("Not all output parameters are validated.")
return transformed_result

View File

@ -0,0 +1,16 @@
class CodeNodeError(ValueError):
"""Base class for code node errors."""
pass
class OutputValidationError(CodeNodeError):
"""Raised when there is an output validation error."""
pass
class DepthLimitError(CodeNodeError):
"""Raised when the depth limit is reached."""
pass
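
Because `CodeNodeError` subclasses `ValueError`, any pre-existing caller that caught `ValueError` keeps working, while `_run` can narrow its handler to `except (CodeExecutionError, CodeNodeError)` as shown above. A sketch using the classes just defined:

try:
    raise DepthLimitError("Depth limit 5 reached, object too deep.")
except CodeNodeError as exc:
    # also catches OutputValidationError; plain ValueErrors from unrelated
    # code no longer masquerade as code-node failures
    print(type(exc).__name__, exc)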

View File

@ -198,10 +198,8 @@ def _download_file_content(file: File) -> bytes:
response = ssrf_proxy.get(file.remote_url)
response.raise_for_status()
return response.content
elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
return file_manager.download(file)
else:
raise ValueError(f"Unsupported transfer method: {file.transfer_method}")
return file_manager.download(file)
except Exception as e:
raise FileDownloadError(f"Error downloading file: {str(e)}") from e

View File

@ -0,0 +1,18 @@
class HttpRequestNodeError(ValueError):
"""Custom error for HTTP request node."""
class AuthorizationConfigError(HttpRequestNodeError):
"""Raised when authorization config is missing or invalid."""
class FileFetchError(HttpRequestNodeError):
"""Raised when a file cannot be fetched."""
class InvalidHttpMethodError(HttpRequestNodeError):
"""Raised when an invalid HTTP method is used."""
class ResponseSizeError(HttpRequestNodeError):
"""Raised when the response size exceeds the allowed threshold."""

View File

@ -18,6 +18,12 @@ from .entities import (
HttpRequestNodeTimeout,
Response,
)
from .exc import (
AuthorizationConfigError,
FileFetchError,
InvalidHttpMethodError,
ResponseSizeError,
)
BODY_TYPE_TO_CONTENT_TYPE = {
"json": "application/json",
@ -51,7 +57,7 @@ class Executor:
# If authorization API key is present, convert the API key using the variable pool
if node_data.authorization.type == "api-key":
if node_data.authorization.config is None:
raise ValueError("authorization config is required")
raise AuthorizationConfigError("authorization config is required")
node_data.authorization.config.api_key = variable_pool.convert_template(
node_data.authorization.config.api_key
).text
@ -82,8 +88,10 @@ class Executor:
self.url = self.variable_pool.convert_template(self.node_data.url).text
def _init_params(self):
params = self.variable_pool.convert_template(self.node_data.params).text
self.params = _plain_text_to_dict(params)
params = _plain_text_to_dict(self.node_data.params)
for key in params:
params[key] = self.variable_pool.convert_template(params[key]).text
self.params = params
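
The reordering is the point of this change: the raw params text is parsed into key/value pairs first and each value is template-rendered afterwards, so a variable whose rendered text contains a newline or separator can no longer split into bogus extra parameters (the `:` separator for `_plain_text_to_dict` is an assumption from its usage here):

raw = "q:{{#sys.query#}}"
# old order: render then parse; a query of "a\nb:c" yields {"q": "a", "b": "c"}
# new order: parse then render; it yields {"q": "a\nb:c"}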
def _init_headers(self):
headers = self.variable_pool.convert_template(self.node_data.headers).text
@ -116,7 +124,7 @@ class Executor:
file_selector = data[0].file
file_variable = self.variable_pool.get_file(file_selector)
if file_variable is None:
raise ValueError(f"cannot fetch file with selector {file_selector}")
raise FileFetchError(f"cannot fetch file with selector {file_selector}")
file = file_variable.value
self.content = file_manager.download(file)
case "x-www-form-urlencoded":
@ -155,12 +163,12 @@ class Executor:
headers = deepcopy(self.headers) or {}
if self.auth.type == "api-key":
if self.auth.config is None:
raise ValueError("self.authorization config is required")
raise AuthorizationConfigError("self.authorization config is required")
if authorization.config is None:
raise ValueError("authorization config is required")
raise AuthorizationConfigError("authorization config is required")
if self.auth.config.api_key is None:
raise ValueError("api_key is required")
raise AuthorizationConfigError("api_key is required")
if not authorization.config.header:
authorization.config.header = "Authorization"
@ -183,7 +191,7 @@ class Executor:
else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE
)
if executor_response.size > threshold_size:
raise ValueError(
raise ResponseSizeError(
f'{"File" if executor_response.is_file else "Text"} size is too large,'
f' max size is {threshold_size / 1024 / 1024:.2f} MB,'
f' but current size is {executor_response.readable_size}.'
@ -196,7 +204,7 @@ class Executor:
do http request depending on api bundle
"""
if self.method not in {"get", "head", "post", "put", "delete", "patch"}:
raise ValueError(f"Invalid http method {self.method}")
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
request_args = {
"url": self.url,

View File

@ -20,6 +20,7 @@ from .entities import (
HttpRequestNodeTimeout,
Response,
)
from .exc import HttpRequestNodeError
HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
@ -77,7 +78,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
"request": http_executor.to_log(),
},
)
except Exception as e:
except HttpRequestNodeError as e:
logger.warning(f"http request node {self.node_id} failed to run: {e}")
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,

View File

@ -1,3 +1,4 @@
from enum import Enum
from typing import Any, Optional
from pydantic import Field
@ -5,6 +6,12 @@ from pydantic import Field
from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData
class ErrorHandleMode(str, Enum):
TERMINATED = "terminated"
CONTINUE_ON_ERROR = "continue-on-error"
REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output"
class IterationNodeData(BaseIterationNodeData):
"""
Iteration Node Data.
@ -13,6 +20,9 @@ class IterationNodeData(BaseIterationNodeData):
parent_loop_id: Optional[str] = None # redundant field, not used currently
iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector
is_parallel: bool = False # open the parallel mode or not
parallel_nums: int = 10 # the numbers of parallel
error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED # how to handle the error
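
How the three modes shape an iteration over `[a, b, c]` whose second run fails, per the handling in `IterationNode._run_single_iter` later in this commit (a summary sketch, not code from the diff):

# ErrorHandleMode.TERMINATED             -> the whole iteration fails with the node's error
# ErrorHandleMode.CONTINUE_ON_ERROR      -> outputs == [out_a, None, out_c]
# ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT -> outputs == [out_a, out_c]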
class IterationStartNodeData(BaseNodeData):

View File

@ -1,12 +1,20 @@
import logging
import uuid
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, wait
from datetime import datetime, timezone
from typing import Any, cast
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Optional, cast
from flask import Flask, current_app
from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from core.variables import IntegerSegment
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.node_entities import (
NodeRunMetadataKey,
NodeRunResult,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
BaseNodeEvent,
@ -17,6 +25,9 @@ from core.workflow.graph_engine.entities.event import (
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeInIterationFailedEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
@ -24,9 +35,11 @@ from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from models.workflow import WorkflowNodeExecutionStatus
if TYPE_CHECKING:
from core.workflow.graph_engine.graph_engine import GraphEngine
logger = logging.getLogger(__name__)
@ -38,6 +51,17 @@ class IterationNode(BaseNode[IterationNodeData]):
_node_data_cls = IterationNodeData
_node_type = NodeType.ITERATION
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
return {
"type": "iteration",
"config": {
"is_parallel": False,
"parallel_nums": 10,
"error_handle_mode": ErrorHandleMode.TERMINATED.value,
},
}
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
Run the node.
@ -83,7 +107,7 @@ class IterationNode(BaseNode[IterationNodeData]):
variable_pool.add([self.node_id, "item"], iterator_list_value[0])
# init graph engine
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
@ -123,108 +147,64 @@ class IterationNode(BaseNode[IterationNodeData]):
index=0,
pre_iteration_output=None,
)
outputs: list[Any] = []
try:
for _ in range(len(iterator_list_value)):
# run workflow
rst = graph_engine.run()
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
if (
isinstance(event, BaseNodeEvent)
and event.node_type == NodeType.ITERATION_START
and not isinstance(event, NodeRunStreamChunkEvent)
):
if self.node_data.is_parallel:
futures: list[Future] = []
q = Queue()
thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100)
for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit(
self._run_single_iter_parallel,
current_app._get_current_object(),
q,
iterator_list_value,
inputs,
outputs,
start_at,
graph_engine,
iteration_graph,
index,
item,
)
future.add_done_callback(thread_pool.task_done_callback)
futures.append(future)
succeeded_count = 0
while True:
try:
event = q.get(timeout=1)
if event is None:
break
if isinstance(event, IterationRunNextEvent):
succeeded_count += 1
if succeeded_count == len(futures):
q.put(None)
yield event
if isinstance(event, RunCompletedEvent):
q.put(None)
for f in futures:
if not f.done():
f.cancel()
yield event
if isinstance(event, IterationRunFailedEvent):
q.put(None)
yield event
except Empty:
continue
if isinstance(event, NodeRunSucceededEvent):
if event.route_node_state.node_run_result:
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
index_variable = variable_pool.get([self.node_id, "index"])
if not isinstance(index_variable, IntegerSegment):
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=f"Invalid index variable type: {type(index_variable)}",
)
)
return
metadata[NodeRunMetadataKey.ITERATION_INDEX] = index_variable.value
event.route_node_state.node_run_result.metadata = metadata
yield event
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
)
)
return
else:
event = cast(InNodeEvent, event)
yield event
# append to iteration output variable list
current_iteration_output_variable = variable_pool.get(self.node_data.output_selector)
if current_iteration_output_variable is None:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=f"Iteration output variable {self.node_data.output_selector} not found",
)
# wait for all threads to finish
wait(futures)
else:
for _ in range(len(iterator_list_value)):
yield from self._run_single_iter(
iterator_list_value,
variable_pool,
inputs,
outputs,
start_at,
graph_engine,
iteration_graph,
)
return
current_iteration_output = current_iteration_output_variable.to_object()
outputs.append(current_iteration_output)
# remove all node outputs from the variable pool
for node_id in iteration_graph.node_ids:
variable_pool.remove([node_id])
# move to next iteration
current_index_variable = variable_pool.get([self.node_id, "index"])
if not isinstance(current_index_variable, IntegerSegment):
raise ValueError(f"iteration {self.node_id} current index not found")
next_index = current_index_variable.value + 1
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index,
pre_iteration_output=jsonable_encoder(current_iteration_output),
)
yield IterationRunSucceededEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
@ -330,3 +310,231 @@ class IterationNode(BaseNode[IterationNodeData]):
}
return variable_mapping
def _handle_event_metadata(
self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str
) -> NodeRunStartedEvent | BaseNodeEvent:
"""
add iteration metadata to event.
"""
if not isinstance(event, BaseNodeEvent):
return event
if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
event.parallel_mode_run_id = parallel_mode_run_id
return event
if event.route_node_state.node_run_result:
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
if self.node_data.is_parallel:
metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
else:
metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index
event.route_node_state.node_run_result.metadata = metadata
return event
def _run_single_iter(
self,
iterator_list_value: list[str],
variable_pool: VariablePool,
inputs: dict[str, list],
outputs: list,
start_at: datetime,
graph_engine: "GraphEngine",
iteration_graph: Graph,
parallel_mode_run_id: Optional[str] = None,
) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
run single iteration
"""
try:
rst = graph_engine.run()
# get current iteration index; check for None before dereferencing
index_variable = variable_pool.get([self.node_id, "index"])
if index_variable is None:
raise ValueError(f"iteration {self.node_id} current index not found")
current_index = index_variable.value
next_index = int(current_index) + 1
for event in rst:
if isinstance(event, BaseNodeEvent | BaseParallelBranchEvent) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
if (
isinstance(event, BaseNodeEvent)
and event.node_type == NodeType.ITERATION_START
and not isinstance(event, NodeRunStreamChunkEvent)
):
continue
if isinstance(event, NodeRunSucceededEvent):
yield self._handle_event_metadata(event, current_index, parallel_mode_run_id)
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
if self.node_data.is_parallel:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
else:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
)
)
return
else:
event = cast(InNodeEvent, event)
metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id)
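# What happens next depends on node_data.error_handle_mode when a node
# inside the iteration fails:
#   CONTINUE_ON_ERROR      -> record None as this round's output and advance
#   REMOVE_ABNORMAL_OUTPUT -> record nothing for this round and advance
#   TERMINATED             -> fail the whole iteration immediately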
if isinstance(event, NodeRunFailedEvent):
if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
outputs.insert(current_index, None)
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
)
return
elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
)
return
elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": None},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
yield metadata_event
output_variable = variable_pool.get(self.node_data.output_selector)
if output_variable is None:
    raise ValueError(f"iteration {self.node_id} output variable {self.node_data.output_selector} not found")
current_iteration_output = output_variable.value
outputs.insert(current_index, current_iteration_output)
# remove all nodes outputs from variable pool
for node_id in iteration_graph.node_ids:
variable_pool.remove([node_id])
# move to next iteration
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output is not None else None,
)
except Exception as e:
logger.exception("Iteration run failed: %s", str(e))
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": None},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=str(e),
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
)
def _run_single_iter_parallel(
self,
flask_app: Flask,
q: Queue,
iterator_list_value: list[str],
inputs: dict[str, list],
outputs: list,
start_at: datetime,
graph_engine: "GraphEngine",
iteration_graph: Graph,
index: int,
item: Any,
) -> None:
"""
Run a single iteration in parallel mode, pushing its events onto the shared queue.
"""
with flask_app.app_context():
parallel_mode_run_id = uuid.uuid4().hex
graph_engine_copy = graph_engine.create_copy()
variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool
variable_pool_copy.add([self.node_id, "index"], index)
variable_pool_copy.add([self.node_id, "item"], item)
for event in self._run_single_iter(
iterator_list_value=iterator_list_value,
variable_pool=variable_pool_copy,
inputs=inputs,
outputs=outputs,
start_at=start_at,
graph_engine=graph_engine_copy,
iteration_graph=iteration_graph,
parallel_mode_run_id=parallel_mode_run_id,
):
q.put(event)
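The consumer loop that drains the queue filled by this helper sits outside this hunk. One plausible driver, sketched under the assumption of a standard ThreadPoolExecutor (the real engine may use its own pool and termination protocol):

from concurrent.futures import ThreadPoolExecutor, wait
from queue import Empty, Queue

def drain_parallel_iterations(node, flask_app, iterator_list_value, inputs, outputs,
                              start_at, graph_engine, iteration_graph):
    # Illustrative only: one producer per item, all writing to a shared queue
    # that the caller drains until every future has finished.
    q: Queue = Queue()
    with ThreadPoolExecutor() as pool:
        futures = [
            pool.submit(
                node._run_single_iter_parallel,
                flask_app, q, iterator_list_value, inputs, outputs,
                start_at, graph_engine, iteration_graph, index, item,
            )
            for index, item in enumerate(iterator_list_value)
        ]
        while not all(f.done() for f in futures) or not q.empty():
            try:
                yield q.get(timeout=1)
            except Empty:
                continue
        # mirrors the wait(futures) call in the parallel branch above
        wait(futures)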

View File

@ -157,7 +157,7 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
return lambda x: x.type
case "extension":
return lambda x: x.extension or ""
case "mimetype":
case "mime_type":
return lambda x: x.mime_type or ""
case "transfer_method":
return lambda x: x.transfer_method
@ -295,4 +295,4 @@ def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Seq
extract_func = _get_file_extract_number_func(key=order_by)
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
else:
raise ValueError(f"Invalid order key: {order_by}")
raise InvalidKeyError(f"Invalid order key: {order_by}")
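For context, the corrected key feeds the ordering helper in the same file; a minimal usage sketch (the `files` sequence is assumed to hold File objects):

# Sort by MIME type, descending; files with no MIME type sort as "" via the
# fallback in _get_file_extract_string_func.
ordered = _order_file(order="desc", order_by="mime_type", array=files)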

View File

@ -0,0 +1,26 @@
class LLMNodeError(ValueError):
"""Base class for LLM Node errors."""
class VariableNotFoundError(LLMNodeError):
"""Raised when a required variable is not found."""
class InvalidContextStructureError(LLMNodeError):
"""Raised when the context structure is invalid."""
class InvalidVariableTypeError(LLMNodeError):
"""Raised when the variable type is invalid."""
class ModelNotExistError(LLMNodeError):
"""Raised when the specified model does not exist."""
class LLMModeRequiredError(LLMNodeError):
"""Raised when LLM mode is required but not provided."""
class NoPromptFoundError(LLMNodeError):
"""Raised when no prompt is found in the LLM configuration."""

View File

@ -56,6 +56,15 @@ from .entities import (
LLMNodeData,
ModelConfig,
)
from .exc import (
InvalidContextStructureError,
InvalidVariableTypeError,
LLMModeRequiredError,
LLMNodeError,
ModelNotExistError,
NoPromptFoundError,
VariableNotFoundError,
)
if TYPE_CHECKING:
from core.file.models import File
@ -103,7 +112,7 @@ class LLMNode(BaseNode[LLMNodeData]):
yield event
if context:
node_inputs["#context#"] = context # type: ignore
node_inputs["#context#"] = context
# fetch model config
model_instance, model_config = self._fetch_model_config(self.node_data.model)
@ -115,7 +124,7 @@ class LLMNode(BaseNode[LLMNodeData]):
if self.node_data.memory:
query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
if not query:
raise ValueError("Query not found")
raise VariableNotFoundError("Query not found")
query = query.text
else:
query = None
@ -161,7 +170,7 @@ class LLMNode(BaseNode[LLMNodeData]):
usage = event.usage
finish_reason = event.finish_reason
break
except Exception as e:
except LLMNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -275,7 +284,7 @@ class LLMNode(BaseNode[LLMNodeData]):
variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if variable is None:
raise ValueError(f"Variable {variable_selector.variable} not found")
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
def parse_dict(input_dict: Mapping[str, Any]) -> str:
"""
@ -325,7 +334,7 @@ class LLMNode(BaseNode[LLMNodeData]):
for variable_selector in variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if variable is None:
raise ValueError(f"Variable {variable_selector.variable} not found")
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
if isinstance(variable, NoneSegment):
    inputs[variable_selector.variable] = ""
    continue
inputs[variable_selector.variable] = variable.to_object()
@ -338,7 +347,7 @@ class LLMNode(BaseNode[LLMNodeData]):
for variable_selector in query_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if variable is None:
raise ValueError(f"Variable {variable_selector.variable} not found")
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
if isinstance(variable, NoneSegment):
continue
inputs[variable_selector.variable] = variable.to_object()
@ -355,7 +364,7 @@ class LLMNode(BaseNode[LLMNodeData]):
return variable.value
elif isinstance(variable, NoneSegment | ArrayAnySegment):
return []
raise ValueError(f"Invalid variable type: {type(variable)}")
raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
def _fetch_context(self, node_data: LLMNodeData):
if not node_data.context.enabled:
@ -376,7 +385,7 @@ class LLMNode(BaseNode[LLMNodeData]):
context_str += item + "\n"
else:
if "content" not in item:
raise ValueError(f"Invalid context structure: {item}")
raise InvalidContextStructureError(f"Invalid context structure: {item}")
context_str += item["content"] + "\n"
@ -441,7 +450,7 @@ class LLMNode(BaseNode[LLMNodeData]):
)
if provider_model is None:
raise ValueError(f"Model {model_name} not exist.")
raise ModelNotExistError(f"Model {model_name} does not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ProviderTokenNotInitError(f"Model {model_name} credentials are not initialized.")
@ -460,12 +469,12 @@ class LLMNode(BaseNode[LLMNodeData]):
# get model mode
model_mode = node_data_model.mode
if not model_mode:
raise ValueError("LLM mode is required.")
raise LLMModeRequiredError("LLM mode is required.")
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
if not model_schema:
raise ValueError(f"Model {model_name} not exist.")
raise ModelNotExistError(f"Model {model_name} does not exist.")
return model_instance, ModelConfigWithCredentialsEntity(
provider=provider_name,
@ -564,7 +573,7 @@ class LLMNode(BaseNode[LLMNodeData]):
filtered_prompt_messages.append(prompt_message)
if not filtered_prompt_messages:
raise ValueError(
raise NoPromptFoundError(
"No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding."
)
@ -636,7 +645,7 @@ class LLMNode(BaseNode[LLMNodeData]):
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
variable_selectors = variable_template_parser.extract_variable_selectors()
else:
raise ValueError(f"Invalid prompt template type: {type(prompt_template)}")
raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
variable_mapping = {}
for variable_selector in variable_selectors:

View File

@ -0,0 +1,50 @@
class ParameterExtractorNodeError(ValueError):
"""Base error for ParameterExtractorNode."""
class InvalidModelTypeError(ParameterExtractorNodeError):
"""Raised when the model is not a Large Language Model."""
class ModelSchemaNotFoundError(ParameterExtractorNodeError):
"""Raised when the model schema is not found."""
class InvalidInvokeResultError(ParameterExtractorNodeError):
"""Raised when the invoke result is invalid."""
class InvalidTextContentTypeError(ParameterExtractorNodeError):
"""Raised when the text content type is invalid."""
class InvalidNumberOfParametersError(ParameterExtractorNodeError):
"""Raised when the number of parameters is invalid."""
class RequiredParameterMissingError(ParameterExtractorNodeError):
"""Raised when a required parameter is missing."""
class InvalidSelectValueError(ParameterExtractorNodeError):
"""Raised when a select value is invalid."""
class InvalidNumberValueError(ParameterExtractorNodeError):
"""Raised when a number value is invalid."""
class InvalidBoolValueError(ParameterExtractorNodeError):
"""Raised when a bool value is invalid."""
class InvalidStringValueError(ParameterExtractorNodeError):
"""Raised when a string value is invalid."""
class InvalidArrayValueError(ParameterExtractorNodeError):
"""Raised when an array value is invalid."""
class InvalidModelModeError(ParameterExtractorNodeError):
"""Raised when the model mode is invalid."""

View File

@ -32,6 +32,21 @@ from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus
from .entities import ParameterExtractorNodeData
from .exc import (
InvalidArrayValueError,
InvalidBoolValueError,
InvalidInvokeResultError,
InvalidModelModeError,
InvalidModelTypeError,
InvalidNumberOfParametersError,
InvalidNumberValueError,
InvalidSelectValueError,
InvalidStringValueError,
InvalidTextContentTypeError,
ModelSchemaNotFoundError,
ParameterExtractorNodeError,
RequiredParameterMissingError,
)
from .prompts import (
CHAT_EXAMPLE,
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
@ -85,7 +100,7 @@ class ParameterExtractorNode(LLMNode):
model_instance, model_config = self._fetch_model_config(node_data.model)
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
raise ValueError("Model is not a Large Language Model")
raise InvalidModelTypeError("Model is not a Large Language Model")
llm_model = model_instance.model_type_instance
model_schema = llm_model.get_model_schema(
@ -93,7 +108,7 @@ class ParameterExtractorNode(LLMNode):
credentials=model_config.credentials,
)
if not model_schema:
raise ValueError("Model schema not found")
raise ModelSchemaNotFoundError("Model schema not found")
# fetch memory
memory = self._fetch_memory(
@ -155,7 +170,7 @@ class ParameterExtractorNode(LLMNode):
process_data["usage"] = jsonable_encoder(usage)
process_data["tool_call"] = jsonable_encoder(tool_call)
process_data["llm_text"] = text
except Exception as e:
except ParameterExtractorNodeError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=inputs,
@ -177,7 +192,7 @@ class ParameterExtractorNode(LLMNode):
try:
result = self._validate_result(data=node_data, result=result or {})
except Exception as e:
except ParameterExtractorNodeError as e:
error = str(e)
# transform result into standard format
@ -217,11 +232,11 @@ class ParameterExtractorNode(LLMNode):
# handle invoke result
if not isinstance(invoke_result, LLMResult):
raise ValueError(f"Invalid invoke result: {invoke_result}")
raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}")
text = invoke_result.message.content
if not isinstance(text, str):
raise ValueError(f"Invalid text content type: {type(text)}. Expected str.")
raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")
usage = invoke_result.usage
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
@ -344,7 +359,7 @@ class ParameterExtractorNode(LLMNode):
files=files,
)
else:
raise ValueError(f"Invalid model mode: {model_mode}")
raise InvalidModelModeError(f"Invalid model mode: {model_mode}")
def _generate_prompt_engineering_completion_prompt(
self,
@ -449,36 +464,36 @@ class ParameterExtractorNode(LLMNode):
Validate result.
"""
if len(data.parameters) != len(result):
raise ValueError("Invalid number of parameters")
raise InvalidNumberOfParametersError("Invalid number of parameters")
for parameter in data.parameters:
if parameter.required and parameter.name not in result:
raise ValueError(f"Parameter {parameter.name} is required")
raise RequiredParameterMissingError(f"Parameter {parameter.name} is required")
if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options:
raise ValueError(f"Invalid `select` value for parameter {parameter.name}")
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float):
raise ValueError(f"Invalid `number` value for parameter {parameter.name}")
raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}")
if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool):
raise ValueError(f"Invalid `bool` value for parameter {parameter.name}")
raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}")
if parameter.type == "string" and not isinstance(result.get(parameter.name), str):
raise ValueError(f"Invalid `string` value for parameter {parameter.name}")
raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}")
if parameter.type.startswith("array"):
parameters = result.get(parameter.name)
if not isinstance(parameters, list):
raise ValueError(f"Invalid `array` value for parameter {parameter.name}")
raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}")
nested_type = parameter.type[6:-1]
for item in parameters:
if nested_type == "number" and not isinstance(item, int | float):
raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
if nested_type == "string" and not isinstance(item, str):
raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
if nested_type == "object" and not isinstance(item, dict):
raise ValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
return result
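A worked example of what this validation accepts, using a hypothetical two-parameter schema:

# Hypothetical schema, for illustration only:
#   parameters = [
#       {"name": "priority", "type": "select", "required": True, "options": ["low", "high"]},
#       {"name": "tags", "type": "array[string]", "required": False},
#   ]
result = {"priority": "high", "tags": ["urgent", "billing"]}
# Passes: exactly two entries, "priority" is one of its options, and every
# element of "tags" is a str. Omitting "tags" would raise
# InvalidNumberOfParametersError; tags=["urgent", 3] would raise
# InvalidArrayValueError; priority="medium" would raise InvalidSelectValueError.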
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
@ -634,7 +649,7 @@ class ParameterExtractorNode(LLMNode):
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
return [system_prompt_messages, user_prompt_message]
else:
raise ValueError(f"Model mode {model_mode} not support.")
raise InvalidModelModeError(f"Model mode {model_mode} is not supported.")
def _get_prompt_engineering_prompt_template(
self,
@ -669,7 +684,7 @@ class ParameterExtractorNode(LLMNode):
.replace("}γγγ", "")
)
else:
raise ValueError(f"Model mode {model_mode} not support.")
raise InvalidModelModeError(f"Model mode {model_mode} is not supported.")
def _calculate_rest_token(
self,
@ -683,12 +698,12 @@ class ParameterExtractorNode(LLMNode):
model_instance, model_config = self._fetch_model_config(node_data.model)
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
raise ValueError("Model is not a Large Language Model")
raise InvalidModelTypeError("Model is not a Large Language Model")
llm_model = model_instance.model_type_instance
model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
if not model_schema:
raise ValueError("Model schema not found")
raise ModelSchemaNotFoundError("Model schema not found")
if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)

View File

@ -53,7 +53,7 @@ class ToolNode(BaseNode[ToolNodeData]):
)
# get parameters
tool_parameters = tool_runtime.get_runtime_parameters() or []
tool_parameters = tool_runtime.parameters or []
parameters = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,