Files
dify/api/dify_graph/repositories/rag_retrieval_protocol.py

109 lines
5.6 KiB
Python

from typing import Any, Literal, Protocol
from pydantic import BaseModel, Field
from dify_graph.model_runtime.entities import LLMUsage
from dify_graph.nodes.knowledge_retrieval.entities import MetadataFilteringCondition
from dify_graph.nodes.llm.entities import ModelConfig
class SourceChildChunk(BaseModel):
id: str = Field(default="", description="Child chunk ID")
content: str = Field(default="", description="Child chunk content")
position: int = Field(default=0, description="Child chunk position")
score: float = Field(default=0.0, description="Child chunk relevance score")
class SourceMetadata(BaseModel):
source: str = Field(
default="knowledge",
serialization_alias="_source",
description="Data source identifier",
)
dataset_id: str = Field(description="Dataset unique identifier")
dataset_name: str = Field(description="Dataset display name")
document_id: str = Field(description="Document unique identifier")
document_name: str = Field(description="Document display name")
data_source_type: str = Field(description="Type of data source")
segment_id: str | None = Field(default=None, description="Segment unique identifier")
retriever_from: str = Field(default="workflow", description="Retriever source context")
score: float = Field(default=0.0, description="Retrieval relevance score")
child_chunks: list[SourceChildChunk] = Field(default=[], description="List of child chunks")
segment_hit_count: int | None = Field(default=0, description="Number of times segment was retrieved")
segment_word_count: int | None = Field(default=0, description="Word count of the segment")
segment_position: int | None = Field(default=0, description="Position of segment in document")
segment_index_node_hash: str | None = Field(default=None, description="Hash of index node for the segment")
doc_metadata: dict[str, Any] | None = Field(default=None, description="Additional document metadata")
position: int | None = Field(default=0, description="Position of the document in the dataset")
class Config:
populate_by_name = True
class Source(BaseModel):
metadata: SourceMetadata = Field(description="Source metadata information")
title: str = Field(description="Document title")
files: list[Any] | None = Field(default=None, description="Associated file references")
content: str | None = Field(description="Segment content text")
summary: str | None = Field(default=None, description="Content summary if available")
class KnowledgeRetrievalRequest(BaseModel):
tenant_id: str = Field(description="Tenant unique identifier")
user_id: str = Field(description="User unique identifier")
app_id: str = Field(description="Application unique identifier")
user_from: str = Field(description="Source of the user request (e.g., 'workflow', 'api')")
dataset_ids: list[str] = Field(description="List of dataset IDs to retrieve from")
query: str | None = Field(default=None, description="Query text for knowledge retrieval")
retrieval_mode: str = Field(description="Retrieval strategy: 'single' or 'multiple'")
model_provider: str | None = Field(default=None, description="Model provider name (e.g., 'openai', 'anthropic')")
completion_params: dict[str, Any] | None = Field(
default=None, description="Model completion parameters (e.g., temperature, max_tokens)"
)
model_mode: str | None = Field(default=None, description="Model mode (e.g., 'chat', 'completion')")
model_name: str | None = Field(default=None, description="Model name (e.g., 'gpt-4', 'claude-3-opus')")
metadata_model_config: ModelConfig | None = Field(
default=None, description="Model config for metadata-based filtering"
)
metadata_filtering_conditions: MetadataFilteringCondition | None = Field(
default=None, description="Conditions for filtering by metadata"
)
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] = Field(
default="disabled", description="Metadata filtering mode: 'disabled', 'automatic', or 'manual'"
)
top_k: int = Field(default=0, description="Number of top results to return")
score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold")
reranking_mode: str = Field(default="reranking_model", description="Reranking strategy")
reranking_model: dict | None = Field(default=None, description="Reranking model configuration")
weights: dict[str, Any] | None = Field(default=None, description="Weights for weighted score reranking")
reranking_enable: bool = Field(default=True, description="Whether reranking is enabled")
attachment_ids: list[str] | None = Field(default=None, description="List of attachment file IDs for retrieval")
class RAGRetrievalProtocol(Protocol):
"""Protocol for RAG-based knowledge retrieval implementations.
Implementations of this protocol handle knowledge retrieval from datasets
including rate limiting, dataset filtering, and document retrieval.
"""
@property
def llm_usage(self) -> LLMUsage:
"""Return accumulated LLM usage for retrieval operations."""
...
def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]:
"""Retrieve knowledge from datasets based on the provided request.
Args:
request: Knowledge retrieval request with search parameters
Returns:
List of sources matching the search criteria
Raises:
RateLimitExceededError: If rate limit is exceeded
ModelNotExistError: If specified model doesn't exist
"""
...