Merge branch 'main' into fix/chore-fix

This commit is contained in:
Yeuoly
2024-11-07 17:06:29 +08:00
46 changed files with 1342 additions and 1425 deletions

View File

@ -0,0 +1,22 @@
class IterationNodeError(ValueError):
"""Base class for iteration node errors."""
class IteratorVariableNotFoundError(IterationNodeError):
"""Raised when the iterator variable is not found."""
class InvalidIteratorValueError(IterationNodeError):
"""Raised when the iterator value is invalid."""
class StartNodeIdNotFoundError(IterationNodeError):
"""Raised when the start node ID is not found."""
class IterationGraphNotFoundError(IterationNodeError):
"""Raised when the iteration graph is not found."""
class IterationIndexNotFoundError(IterationNodeError):
"""Raised when the iteration index is not found."""

View File

@ -38,6 +38,15 @@ from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from models.workflow import WorkflowNodeExecutionStatus
from .exc import (
InvalidIteratorValueError,
IterationGraphNotFoundError,
IterationIndexNotFoundError,
IterationNodeError,
IteratorVariableNotFoundError,
StartNodeIdNotFoundError,
)
if TYPE_CHECKING:
from core.workflow.graph_engine.graph_engine import GraphEngine
logger = logging.getLogger(__name__)
@ -69,7 +78,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
if not iterator_list_segment:
raise ValueError(f"Iterator variable {self.node_data.iterator_selector} not found")
raise IteratorVariableNotFoundError(f"Iterator variable {self.node_data.iterator_selector} not found")
if len(iterator_list_segment.value) == 0:
yield RunCompletedEvent(
@ -83,14 +92,14 @@ class IterationNode(BaseNode[IterationNodeData]):
iterator_list_value = iterator_list_segment.to_object()
if not isinstance(iterator_list_value, list):
raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
inputs = {"iterator_selector": iterator_list_value}
graph_config = self.graph_config
if not self.node_data.start_node_id:
raise ValueError(f"field start_node_id in iteration {self.node_id} not found")
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found")
root_node_id = self.node_data.start_node_id
@ -98,7 +107,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
if not iteration_graph:
raise ValueError("iteration graph not found")
raise IterationGraphNotFoundError("iteration graph not found")
variable_pool = self.graph_runtime_state.variable_pool
@ -222,9 +231,9 @@ class IterationNode(BaseNode[IterationNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": jsonable_encoder(outputs)}
)
)
except Exception as e:
except IterationNodeError as e:
# iteration run failed
logger.exception("Iteration run failed")
logger.warning("Iteration run failed")
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
@ -272,7 +281,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id)
if not iteration_graph:
raise ValueError("iteration graph not found")
raise IterationGraphNotFoundError("iteration graph not found")
for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items():
if sub_node_config.get("data", {}).get("iteration_id") != node_id:
@ -357,7 +366,7 @@ class IterationNode(BaseNode[IterationNodeData]):
next_index = int(current_index) + 1
if current_index is None:
raise ValueError(f"iteration {self.node_id} current index not found")
raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
@ -484,8 +493,8 @@ class IterationNode(BaseNode[IterationNodeData]):
pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None,
)
except Exception as e:
logger.exception(f"Iteration run failed:{str(e)}")
except IterationNodeError as e:
logger.warning(f"Iteration run failed:{str(e)}")
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,

View File

@ -0,0 +1,18 @@
class KnowledgeRetrievalNodeError(ValueError):
"""Base class for KnowledgeRetrievalNode errors."""
class ModelNotExistError(KnowledgeRetrievalNodeError):
"""Raised when the model does not exist."""
class ModelCredentialsNotInitializedError(KnowledgeRetrievalNodeError):
"""Raised when the model credentials are not initialized."""
class ModelNotSupportedError(KnowledgeRetrievalNodeError):
"""Raised when the model is not supported."""
class ModelQuotaExceededError(KnowledgeRetrievalNodeError):
"""Raised when the model provider quota is exceeded."""

View File

@ -8,7 +8,6 @@ from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@ -18,11 +17,19 @@ from core.variables import StringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from models.workflow import WorkflowNodeExecutionStatus
from .entities import KnowledgeRetrievalNodeData
from .exc import (
KnowledgeRetrievalNodeError,
ModelCredentialsNotInitializedError,
ModelNotExistError,
ModelNotSupportedError,
ModelQuotaExceededError,
)
logger = logging.getLogger(__name__)
default_retrieval_model = {
@ -61,8 +68,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
)
except Exception as e:
logger.exception("Error when running knowledge retrieval node")
except KnowledgeRetrievalNodeError as e:
logger.warning("Error when running knowledge retrieval node")
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
@ -295,14 +302,14 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
)
if provider_model is None:
raise ValueError(f"Model {model_name} not exist.")
raise ModelNotExistError(f"Model {model_name} not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
raise ModelCredentialsNotInitializedError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
raise ModelNotSupportedError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.")
# model config
completion_params = node_data.single_retrieval_config.model.completion_params
@ -314,12 +321,12 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
# get model mode
model_mode = node_data.single_retrieval_config.model.mode
if not model_mode:
raise ValueError("LLM mode is required.")
raise ModelNotExistError("LLM mode is required.")
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
if not model_schema:
raise ValueError(f"Model {model_name} not exist.")
raise ModelNotExistError(f"Model {model_name} not exist.")
return model_instance, ModelConfigWithCredentialsEntity(
provider=provider_name,

View File

@ -0,0 +1,6 @@
class QuestionClassifierNodeError(ValueError):
"""Base class for QuestionClassifierNode errors."""
class InvalidModelTypeError(QuestionClassifierNodeError):
"""Raised when the model is not a Large Language Model."""

View File

@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.llm_generator.output_parser.errors import OutputParserError
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
@ -24,6 +25,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown
from models.workflow import WorkflowNodeExecutionStatus
from .entities import QuestionClassifierNodeData
from .exc import InvalidModelTypeError
from .template_prompts import (
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2,
@ -124,7 +126,7 @@ class QuestionClassifierNode(LLMNode):
category_name = classes_map[category_id_result]
category_id = category_id_result
except Exception:
except OutputParserError:
logging.error(f"Failed to parse result text: {result_text}")
try:
process_data = {
@ -309,4 +311,4 @@ class QuestionClassifierNode(LLMNode):
)
else:
raise ValueError(f"Model mode {model_mode} not support.")
raise InvalidModelTypeError(f"Model mode {model_mode} not support.")

View File

@ -0,0 +1,16 @@
class ToolNodeError(ValueError):
"""Base exception for tool node errors."""
pass
class ToolParameterError(ToolNodeError):
"""Exception raised for errors in tool parameters."""
pass
class ToolFileError(ToolNodeError):
"""Exception raised for errors related to tool files."""
pass

View File

@ -6,7 +6,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.models import File, FileTransferMethod, FileType
from core.file import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
@ -19,12 +19,18 @@ from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.tools import ToolFile
from models.workflow import WorkflowNodeExecutionStatus
from .entities import ToolNodeData
from .exc import (
ToolFileError,
ToolNodeError,
ToolParameterError,
)
class ToolNode(BaseNode[ToolNodeData]):
"""
@ -49,7 +55,7 @@ class ToolNode(BaseNode[ToolNodeData]):
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
)
except Exception as e:
except ToolNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -58,7 +64,6 @@ class ToolNode(BaseNode[ToolNodeData]):
error=f"Failed to get tool runtime: {str(e)}",
)
)
return
# get parameters
tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
@ -85,7 +90,7 @@ class ToolNode(BaseNode[ToolNodeData]):
app_id=self.app_id,
# TODO: conversation id and message id
)
except Exception as e:
except ToolNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -94,7 +99,6 @@ class ToolNode(BaseNode[ToolNodeData]):
error=f"Failed to invoke tool: {str(e)}",
)
)
return
# convert tool messages
yield from self._transform_message(message_stream, tool_info, parameters_for_log)
@ -131,14 +135,13 @@ class ToolNode(BaseNode[ToolNodeData]):
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
if variable is None:
raise ValueError(f"variable {tool_input.value} not exists")
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
parameter_value = variable.value
elif tool_input.type in {"mixed", "constant"}:
segment_group = variable_pool.convert_template(str(tool_input.value))
parameter_value = segment_group.log if for_log else segment_group.text
else:
raise ValueError(f"unknown tool input type '{tool_input.type}'")
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
result[parameter_name] = parameter_value
return result
@ -187,7 +190,7 @@ class ToolNode(BaseNode[ToolNodeData]):
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists")
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
files.append(
File(
@ -212,8 +215,7 @@ class ToolNode(BaseNode[ToolNodeData]):
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists")
raise ToolFileError(f"tool file {tool_file_id} not exists")
files.append(
File(
tenant_id=self.tenant_id,