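"""Knowledge retrieval node: fetches relevant sources from datasets via the RAG retrieval protocol."""
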
import logging
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal

from core.app.app_config.entities import DatasetRetrieveConfigEntity
from dify_graph.entities import GraphInitParams
from dify_graph.enums import (
    NodeType,
    WorkflowNodeExecutionMetadataKey,
    WorkflowNodeExecutionStatus,
)
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base import LLMUsageTrackingMixin
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source
from dify_graph.variables import (
    ArrayFileSegment,
    FileSegment,
    StringSegment,
)
from dify_graph.variables.segments import ArrayObjectSegment

from .entities import KnowledgeRetrievalNodeData
from .exc import (
    KnowledgeRetrievalNodeError,
    RateLimitExceededError,
)

if TYPE_CHECKING:
    from dify_graph.file.models import File
    from dify_graph.runtime import GraphRuntimeState

logger = logging.getLogger(__name__)


class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
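    """Workflow node that retrieves relevant content from the configured datasets.

    Supports single retrieval (which forwards the configured model to the
    retrieval backend) and multiple retrieval (with top_k, score threshold, and
    optional reranking), and reports the LLM usage incurred by retrieval.
    """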

    node_type = NodeType.KNOWLEDGE_RETRIEVAL

    # Instance attributes specific to KnowledgeRetrievalNode.
    # Output variable for files
    _file_outputs: list["File"]

    _llm_file_saver: LLMFileSaver

    def __init__(
        self,
        id: str,
        config: Mapping[str, Any],
        graph_init_params: "GraphInitParams",
        graph_runtime_state: "GraphRuntimeState",
        rag_retrieval: RAGRetrievalProtocol,
        *,
        llm_file_saver: LLMFileSaver | None = None,
    ):
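        """Initialize the node with its retrieval backend and an optional LLM file saver."""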
        super().__init__(
            id=id,
            config=config,
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
        )
        # LLM file outputs, used for MultiModal outputs.
        self._file_outputs = []
        self._rag_retrieval = rag_retrieval

        if llm_file_saver is None:
            llm_file_saver = FileSaverImpl(
                user_id=graph_init_params.user_id,
                tenant_id=graph_init_params.tenant_id,
            )
        self._llm_file_saver = llm_file_saver

    @classmethod
    def version(cls):
        return "1"

    def _run(self) -> NodeRunResult:
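        """Resolve the query and attachment inputs, run retrieval, and package the result.

        Known retrieval errors are logged and returned as a failed run result
        rather than raised.
        """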
        usage = LLMUsage.empty_usage()
        if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs={},
                process_data={},
                outputs={},
                metadata={},
                llm_usage=usage,
            )
        variables: dict[str, Any] = {}
        # Extract variables.
        if self._node_data.query_variable_selector:
            variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
            if not isinstance(variable, StringSegment):
                return NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs={},
                    error="Query variable is not string type.",
                )
            query = variable.value
            variables["query"] = query

        if self._node_data.query_attachment_selector:
            variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_attachment_selector)
            if not isinstance(variable, (ArrayFileSegment, FileSegment)):
                return NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs={},
                    error="Attachments variable is not array file or file type.",
                )
            if isinstance(variable, ArrayFileSegment):
                variables["attachments"] = variable.value
            else:
                variables["attachments"] = [variable.value]

        try:
            results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables)
            outputs = {"result": ArrayObjectSegment(value=[item.model_dump() for item in results])}
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=variables,
                process_data={"usage": jsonable_encoder(usage)},
                outputs=outputs,  # type: ignore
                metadata={
                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                    WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
                    WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
                },
                llm_usage=usage,
            )
        except RateLimitExceededError as e:
            logger.warning(e, exc_info=True)
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e),
                error_type=type(e).__name__,
                llm_usage=usage,
            )
        except KnowledgeRetrievalNodeError as e:
            logger.warning("Error when running knowledge retrieval node", exc_info=True)
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e),
                error_type=type(e).__name__,
                llm_usage=usage,
            )
        # Temporarily handle all exceptions from the DatasetRetrieval class here.
        except Exception as e:
            logger.warning(e, exc_info=True)
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e),
                error_type=type(e).__name__,
                llm_usage=usage,
            )

    def _fetch_dataset_retriever(
        self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
    ) -> tuple[list[Source], LLMUsage]:
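        """Build a KnowledgeRetrievalRequest from the node data and dispatch it to the retrieval backend.

        Returns the retrieved sources together with the LLM usage reported by
        the backend.
        """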
        dataset_ids = node_data.dataset_ids
        query = variables.get("query")
        attachments = variables.get("attachments")
        retrieval_resource_list = []

        metadata_filtering_mode: Literal["disabled", "automatic", "manual"] = "disabled"
        if node_data.metadata_filtering_mode is not None:
            metadata_filtering_mode = node_data.metadata_filtering_mode

        if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
            # Fetch the model config.
            if node_data.single_retrieval_config is None:
                raise ValueError("single_retrieval_config is required for single retrieval mode")
            model = node_data.single_retrieval_config.model
            retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
                request=KnowledgeRetrievalRequest(
                    tenant_id=self.tenant_id,
                    user_id=self.user_id,
                    app_id=self.app_id,
                    user_from=self.user_from.value,
                    dataset_ids=dataset_ids,
                    retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value,
                    completion_params=model.completion_params,
                    model_provider=model.provider,
                    model_mode=model.mode,
                    model_name=model.name,
                    metadata_model_config=node_data.metadata_model_config,
                    metadata_filtering_conditions=node_data.metadata_filtering_conditions,
                    metadata_filtering_mode=metadata_filtering_mode,
                    query=query,
                )
            )
        elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
            if node_data.multiple_retrieval_config is None:
                raise ValueError("multiple_retrieval_config is required")
            reranking_model = None
            weights = None
            match node_data.multiple_retrieval_config.reranking_mode:
                case "reranking_model":
                    if node_data.multiple_retrieval_config.reranking_model:
                        reranking_model = {
                            "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
                            "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
                        }
                    else:
                        reranking_model = None
                    weights = None
                case "weighted_score":
                    if node_data.multiple_retrieval_config.weights is None:
                        raise ValueError("weights is required")
                    reranking_model = None
                    vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
                    weights = {
                        "vector_setting": {
                            "vector_weight": vector_setting.vector_weight,
                            "embedding_provider_name": vector_setting.embedding_provider_name,
                            "embedding_model_name": vector_setting.embedding_model_name,
                        },
                        "keyword_setting": {
                            "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
                        },
                    }
                case _:
                    # Handle any other reranking_mode values.
                    reranking_model = None
                    weights = None

            retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
                request=KnowledgeRetrievalRequest(
                    app_id=self.app_id,
                    tenant_id=self.tenant_id,
                    user_id=self.user_id,
                    user_from=self.user_from.value,
                    dataset_ids=dataset_ids,
                    query=query,
                    retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value,
                    top_k=node_data.multiple_retrieval_config.top_k,
                    score_threshold=node_data.multiple_retrieval_config.score_threshold
                    if node_data.multiple_retrieval_config.score_threshold is not None
                    else 0.0,
                    reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
                    reranking_model=reranking_model,
                    weights=weights,
                    reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
                    metadata_model_config=node_data.metadata_model_config,
                    metadata_filtering_conditions=node_data.metadata_filtering_conditions,
                    metadata_filtering_mode=metadata_filtering_mode,
                    attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
                )
            )

        usage = self._rag_retrieval.llm_usage
        return retrieval_resource_list, usage

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
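        """Return the mapping from this node's input keys to the variable selectors they read."""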
        # graph_config is not used in this node type.
        # Create typed NodeData from dict.
        typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)

        variable_mapping = {}
        if typed_node_data.query_variable_selector:
            variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
        if typed_node_data.query_attachment_selector:
            variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector
        return variable_mapping