mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-19 11:45:10 +08:00
refactor: introduce common normalize method in rerank base class (#12550)
### What problem does this PR solve? introduce common normalize method in rerank base class ### Type of change - [x] Refactoring
This commit is contained in:
@ -36,6 +36,22 @@ class Base(ABC):
|
||||
def similarity(self, query: str, texts: list):
|
||||
raise NotImplementedError("Please implement encode method!")
|
||||
|
||||
@staticmethod
|
||||
def _normalize_rank(rank: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Normalize rank values to the range 0 to 1.
|
||||
Avoids division by zero if all ranks are identical.
|
||||
"""
|
||||
min_rank = np.min(rank)
|
||||
max_rank = np.max(rank)
|
||||
|
||||
if not np.isclose(min_rank, max_rank, atol=1e-3):
|
||||
rank = (rank - min_rank) / (max_rank - min_rank)
|
||||
else:
|
||||
rank = np.zeros_like(rank)
|
||||
|
||||
return rank
|
||||
|
||||
|
||||
class JinaRerank(Base):
|
||||
_FACTORY_NAME = "Jina"
|
||||
@ -121,15 +137,7 @@ class LocalAIRerank(Base):
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
|
||||
# Normalize the rank values to the range 0 to 1
|
||||
min_rank = np.min(rank)
|
||||
max_rank = np.max(rank)
|
||||
|
||||
# Avoid division by zero if all ranks are identical
|
||||
if not np.isclose(min_rank, max_rank, atol=1e-3):
|
||||
rank = (rank - min_rank) / (max_rank - min_rank)
|
||||
else:
|
||||
rank = np.zeros_like(rank)
|
||||
rank = Base._normalize_rank(rank)
|
||||
|
||||
return rank, token_count
|
||||
|
||||
@ -215,15 +223,7 @@ class OpenAI_APIRerank(Base):
|
||||
except Exception as _e:
|
||||
log_exception(_e, res)
|
||||
|
||||
# Normalize the rank values to the range 0 to 1
|
||||
min_rank = np.min(rank)
|
||||
max_rank = np.max(rank)
|
||||
|
||||
# Avoid division by zero if all ranks are identical
|
||||
if not np.isclose(min_rank, max_rank, atol=1e-3):
|
||||
rank = (rank - min_rank) / (max_rank - min_rank)
|
||||
else:
|
||||
rank = np.zeros_like(rank)
|
||||
rank = Base._normalize_rank(rank)
|
||||
|
||||
return rank, token_count
|
||||
|
||||
|
||||
Reference in New Issue
Block a user