mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-02 08:17:48 +08:00
Fix: model selection rule in get_model_config_by_type_and_name (#13569)
### What problem does this PR solve? Fix: model selection rule in get_model_config_by_type_and_name ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -35,11 +35,12 @@ def get_model_config_by_id(tenant_model_id: int) -> dict:
|
||||
def get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str):
|
||||
if not model_name:
|
||||
raise Exception("Model Name is required")
|
||||
model_config = TenantLLMService.get_api_key(tenant_id, model_name)
|
||||
model_type_val = model_type.value if hasattr(model_type, "value") else model_type
|
||||
model_config = TenantLLMService.get_api_key(tenant_id, model_name, model_type_val)
|
||||
if not model_config:
|
||||
# model_name in format 'name@factory', split model_name and try again
|
||||
pure_model_name, fid = TenantLLMService.split_model_name_and_factory(model_name)
|
||||
if model_type == LLMType.EMBEDDING and fid == "Builtin" and "tei-" in os.getenv("COMPOSE_PROFILES", "") and pure_model_name == os.getenv("TEI_MODEL", ""):
|
||||
if model_type_val == LLMType.EMBEDDING.value and fid == "Builtin" and "tei-" in os.getenv("COMPOSE_PROFILES", "") and pure_model_name == os.getenv("TEI_MODEL", ""):
|
||||
# configured local embedding model
|
||||
embedding_cfg = settings.EMBEDDING_CFG
|
||||
config_dict = {
|
||||
@ -47,16 +48,22 @@ def get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_nam
|
||||
"api_key": embedding_cfg["api_key"],
|
||||
"llm_name": pure_model_name,
|
||||
"api_base": embedding_cfg["base_url"],
|
||||
"model_type": LLMType.EMBEDDING,
|
||||
"model_type": LLMType.EMBEDDING.value,
|
||||
}
|
||||
else:
|
||||
model_config = TenantLLMService.get_api_key(tenant_id, pure_model_name)
|
||||
model_config = TenantLLMService.get_api_key(tenant_id, pure_model_name, model_type_val)
|
||||
if not model_config:
|
||||
raise LookupError(f"Tenant Model with name {model_name} not found")
|
||||
raise LookupError(f"Tenant Model with name {model_name} and type {model_type_val} not found")
|
||||
config_dict = model_config.to_dict()
|
||||
else:
|
||||
# model_name without @factory
|
||||
config_dict = model_config.to_dict()
|
||||
config_model_type = config_dict.get("model_type")
|
||||
config_model_type = config_model_type.value if hasattr(config_model_type, "value") else config_model_type
|
||||
if config_model_type != model_type_val:
|
||||
raise LookupError(
|
||||
f"Tenant Model with name {model_name} has type {config_model_type}, expected {model_type_val}"
|
||||
)
|
||||
llm = LLMService.query(llm_name=config_dict["llm_name"])
|
||||
if llm:
|
||||
config_dict["is_tools"] = llm[0].is_tools
|
||||
|
||||
@ -36,12 +36,16 @@ class TenantLLMService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_api_key(cls, tenant_id, model_name):
|
||||
def get_api_key(cls, tenant_id, model_name, model_type=None):
|
||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
|
||||
model_type_val = model_type.value if hasattr(model_type, "value") else model_type
|
||||
query_kwargs = {"tenant_id": tenant_id, "llm_name": mdlnm}
|
||||
if model_type_val is not None:
|
||||
query_kwargs["model_type"] = model_type_val
|
||||
if not fid:
|
||||
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
|
||||
objs = cls.query(**query_kwargs)
|
||||
else:
|
||||
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
|
||||
objs = cls.query(**query_kwargs, llm_factory=fid)
|
||||
|
||||
if (not objs) and fid:
|
||||
if fid == "LocalAI":
|
||||
@ -52,7 +56,8 @@ class TenantLLMService(CommonService):
|
||||
mdlnm += "___OpenAI-API"
|
||||
elif fid == "VLLM":
|
||||
mdlnm += "___VLLM"
|
||||
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
|
||||
query_kwargs["llm_name"] = mdlnm
|
||||
objs = cls.query(**query_kwargs, llm_factory=fid)
|
||||
if not objs:
|
||||
return None
|
||||
return objs[0]
|
||||
@ -112,10 +117,10 @@ class TenantLLMService(CommonService):
|
||||
else:
|
||||
assert False, "LLM type error"
|
||||
|
||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
||||
model_config = cls.get_api_key(tenant_id, mdlnm, llm_type)
|
||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
|
||||
if not model_config: # for some cases seems fid mismatch
|
||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
||||
model_config = cls.get_api_key(tenant_id, mdlnm, llm_type)
|
||||
if model_config:
|
||||
model_config = model_config.to_dict()
|
||||
elif llm_type == LLMType.EMBEDDING and fid == "Builtin" and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv("TEI_MODEL", ""):
|
||||
|
||||
@ -846,6 +846,9 @@ def test_retrieval_test_branch_matrix_unit(monkeypatch):
|
||||
return {"id": "kg-2", "content_with_weight": ""}
|
||||
|
||||
monkeypatch.setattr(module, "LLMBundle", lambda *args, **kwargs: llm_calls.append((args, kwargs)) or SimpleNamespace())
|
||||
monkeypatch.setattr(module, "get_model_config_by_type_and_name", lambda *_args, **_kwargs: {"llm_name": "stub-model", "model_type": "chat"})
|
||||
monkeypatch.setattr(module, "get_tenant_default_model_by_type", lambda *_args, **_kwargs: {"llm_name": "stub-model", "model_type": "chat"})
|
||||
monkeypatch.setattr(module, "get_model_config_by_id", lambda *_args, **_kwargs: {"llm_name": "stub-model", "model_type": "embedding"})
|
||||
monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kb_ids: [{"meta": "v"}], raising=False)
|
||||
monkeypatch.setattr(module, "apply_meta_data_filter", _apply_filter)
|
||||
monkeypatch.setattr(module.SearchService, "get_detail", lambda _sid: {"search_config": {"meta_data_filter": {"method": "auto"}, "chat_id": "chat-1"}}, raising=False)
|
||||
|
||||
@ -180,6 +180,16 @@ async def _read_sse_text(response):
|
||||
return "".join(chunks)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def auth():
|
||||
return "unit-auth"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def set_tenant_info():
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_set_conversation_update_create_and_errors(monkeypatch):
|
||||
module = _load_conversation_module(monkeypatch)
|
||||
@ -532,13 +542,13 @@ def test_sequence2txt_validation_and_transcription_paths(monkeypatch):
|
||||
|
||||
wav_file = _DummyUploadedFile("audio.wav")
|
||||
monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": wav_file}))
|
||||
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (False, None))
|
||||
monkeypatch.setattr(module, "get_tenant_default_model_by_type", lambda *_args, **_kwargs: (_ for _ in ()).throw(LookupError("Tenant not found")))
|
||||
res = _run(module.sequence2txt())
|
||||
assert res["message"] == "Tenant not found"
|
||||
|
||||
wav_file = _DummyUploadedFile("audio.wav")
|
||||
monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": wav_file}))
|
||||
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (True, SimpleNamespace(tenant_id="tenant-1", asr_id="")))
|
||||
monkeypatch.setattr(module, "get_tenant_default_model_by_type", lambda *_args, **_kwargs: (_ for _ in ()).throw(Exception("No default speech2text model is set.")))
|
||||
res = _run(module.sequence2txt())
|
||||
assert res["message"] == "No default speech2text model is set."
|
||||
|
||||
@ -551,8 +561,11 @@ def test_sequence2txt_validation_and_transcription_paths(monkeypatch):
|
||||
|
||||
wav_file = _DummyUploadedFile("audio.wav")
|
||||
monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": wav_file}))
|
||||
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (True, SimpleNamespace(tenant_id="tenant-1", asr_id="asr-model")))
|
||||
monkeypatch.setattr(module.TenantLLMService, "get_api_key", lambda tenant_id, model_name: SimpleNamespace(to_dict=lambda: {"llm_factory": "test", "llm_name": "asr-model"}))
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"get_tenant_default_model_by_type",
|
||||
lambda *_args, **_kwargs: {"llm_factory": "test", "llm_name": "asr-model", "model_type": module.LLMType.SPEECH2TEXT.value},
|
||||
)
|
||||
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _SyncAsr())
|
||||
monkeypatch.setattr(module.os, "remove", lambda _path: (_ for _ in ()).throw(RuntimeError("remove failed")))
|
||||
res = _run(module.sequence2txt())
|
||||
@ -593,11 +606,11 @@ def test_sequence2txt_validation_and_transcription_paths(monkeypatch):
|
||||
def test_tts_request_parse_entry(monkeypatch):
|
||||
module = _load_conversation_module(monkeypatch)
|
||||
_set_request_json(monkeypatch, module, {"text": "A。B"})
|
||||
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (False, None))
|
||||
monkeypatch.setattr(module, "get_tenant_default_model_by_type", lambda *_args, **_kwargs: (_ for _ in ()).throw(LookupError("Tenant not found")))
|
||||
res = _run(module.tts())
|
||||
assert res["message"] == "Tenant not found"
|
||||
|
||||
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (True, SimpleNamespace(tenant_id="tenant-1", tts_id="")))
|
||||
monkeypatch.setattr(module, "get_tenant_default_model_by_type", lambda *_args, **_kwargs: (_ for _ in ()).throw(Exception("No default tts model is set.")))
|
||||
res = _run(module.tts())
|
||||
assert res["message"] == "No default tts model is set."
|
||||
|
||||
@ -607,8 +620,11 @@ def test_tts_request_parse_entry(monkeypatch):
|
||||
return []
|
||||
yield f"chunk-{txt}".encode("utf-8")
|
||||
|
||||
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (True, SimpleNamespace(tenant_id="tenant-1", tts_id="tts-x")))
|
||||
monkeypatch.setattr(module.TenantLLMService, "get_api_key", lambda tenant_id, model_name: SimpleNamespace(to_dict=lambda: {"llm_factory": "test", "llm_name": model_name}))
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"get_tenant_default_model_by_type",
|
||||
lambda *_args, **_kwargs: {"llm_factory": "test", "llm_name": "tts-x", "model_type": module.LLMType.TTS.value},
|
||||
)
|
||||
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _TTSOk())
|
||||
resp = _run(module.tts())
|
||||
assert resp.mimetype == "audio/mpeg"
|
||||
@ -770,7 +786,11 @@ def test_mindmap_and_related_questions_matrix_unit(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(module, "LLMBundle", _fake_bundle)
|
||||
monkeypatch.setattr(module, "load_prompt", lambda name: f"prompt-{name}")
|
||||
monkeypatch.setattr(module.TenantLLMService, "get_api_key", lambda tenant_id, model_name: SimpleNamespace(to_dict=lambda: {"llm_factory": "test", "llm_name": model_name}))
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"get_model_config_by_type_and_name",
|
||||
lambda *_args, **_kwargs: {"llm_factory": "test", "llm_name": "chat-x", "model_type": module.LLMType.CHAT.value},
|
||||
)
|
||||
_set_request_json(monkeypatch, module, {"question": "solar", "search_id": "search-1"})
|
||||
res = _run(module.related_questions.__wrapped__())
|
||||
assert res["code"] == 0
|
||||
|
||||
@ -1064,6 +1064,11 @@ def test_unbind_task_branch_matrix(monkeypatch):
|
||||
def test_check_embedding_similarity_threshold_matrix_unit(monkeypatch):
|
||||
module = _load_kb_module(monkeypatch)
|
||||
route = inspect.unwrap(module.check_embedding)
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"get_model_config_by_type_and_name",
|
||||
lambda *_args, **_kwargs: {"llm_factory": "test", "llm_name": "emb-1", "model_type": module.LLMType.EMBEDDING.value},
|
||||
)
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-1")))
|
||||
monkeypatch.setattr(module.search, "index_name", lambda _tenant_id: "idx")
|
||||
|
||||
@ -1228,6 +1233,11 @@ def test_check_embedding_similarity_threshold_matrix_unit(monkeypatch):
|
||||
def test_check_embedding_error_and_empty_sample_paths_unit(monkeypatch):
|
||||
module = _load_kb_module(monkeypatch)
|
||||
route = inspect.unwrap(module.check_embedding)
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"get_model_config_by_type_and_name",
|
||||
lambda *_args, **_kwargs: {"llm_factory": "test", "llm_name": "emb-1", "model_type": module.LLMType.EMBEDDING.value},
|
||||
)
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-1")))
|
||||
monkeypatch.setattr(module.search, "index_name", lambda _tenant_id: "idx")
|
||||
monkeypatch.setattr(module.random, "sample", lambda population, k: list(population)[:k])
|
||||
|
||||
Reference in New Issue
Block a user