mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 02:18:08 +08:00
refactor: move the embedding to the rag module and abstract the rerank runner for extension (#9423)
This commit is contained in:
26
api/core/rag/rerank/rerank_base.py
Normal file
26
api/core/rag/rerank/rerank_base.py
Normal file
@ -0,0 +1,26 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class BaseRerankRunner(ABC):
|
||||
@abstractmethod
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
documents: list[Document],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Run rerank model
|
||||
:param query: search query
|
||||
:param documents: documents for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id if needed
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
16
api/core/rag/rerank/rerank_factory.py
Normal file
16
api/core/rag/rerank/rerank_factory.py
Normal file
@ -0,0 +1,16 @@
|
||||
from core.rag.rerank.rerank_base import BaseRerankRunner
|
||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||
from core.rag.rerank.rerank_type import RerankMode
|
||||
from core.rag.rerank.weight_rerank import WeightRerankRunner
|
||||
|
||||
|
||||
class RerankRunnerFactory:
|
||||
@staticmethod
|
||||
def create_rerank_runner(runner_type: str, *args, **kwargs) -> BaseRerankRunner:
|
||||
match runner_type:
|
||||
case RerankMode.RERANKING_MODEL.value:
|
||||
return RerankModelRunner(*args, **kwargs)
|
||||
case RerankMode.WEIGHTED_SCORE.value:
|
||||
return WeightRerankRunner(*args, **kwargs)
|
||||
case _:
|
||||
raise ValueError(f"Unknown runner type: {runner_type}")
|
||||
@ -2,9 +2,10 @@ from typing import Optional
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.rerank.rerank_base import BaseRerankRunner
|
||||
|
||||
|
||||
class RerankModelRunner:
|
||||
class RerankModelRunner(BaseRerankRunner):
|
||||
def __init__(self, rerank_model_instance: ModelInstance) -> None:
|
||||
self.rerank_model_instance = rerank_model_instance
|
||||
|
||||
|
||||
@ -4,15 +4,16 @@ from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
from core.rag.embedding.cached_embedding import CacheEmbedding
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.rerank.entity.weight import VectorSetting, Weights
|
||||
from core.rag.rerank.rerank_base import BaseRerankRunner
|
||||
|
||||
|
||||
class WeightRerankRunner:
|
||||
class WeightRerankRunner(BaseRerankRunner):
|
||||
def __init__(self, tenant_id: str, weights: Weights) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.weights = weights
|
||||
|
||||
Reference in New Issue
Block a user