Fix: model selection rule in get_model_config_by_type_and_name (#13569)

### What problem does this PR solve?

Fix: model selection rule in get_model_config_by_type_and_name

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Magicbook1108
2026-03-13 19:46:13 +08:00
committed by GitHub
parent cb49cd30c4
commit 161659becc
5 changed files with 65 additions and 20 deletions

View File

@ -35,11 +35,12 @@ def get_model_config_by_id(tenant_model_id: int) -> dict:
def get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str):
if not model_name:
raise Exception("Model Name is required")
model_config = TenantLLMService.get_api_key(tenant_id, model_name)
model_type_val = model_type.value if hasattr(model_type, "value") else model_type
model_config = TenantLLMService.get_api_key(tenant_id, model_name, model_type_val)
if not model_config:
# model_name in format 'name@factory', split model_name and try again
pure_model_name, fid = TenantLLMService.split_model_name_and_factory(model_name)
if model_type == LLMType.EMBEDDING and fid == "Builtin" and "tei-" in os.getenv("COMPOSE_PROFILES", "") and pure_model_name == os.getenv("TEI_MODEL", ""):
if model_type_val == LLMType.EMBEDDING.value and fid == "Builtin" and "tei-" in os.getenv("COMPOSE_PROFILES", "") and pure_model_name == os.getenv("TEI_MODEL", ""):
# configured local embedding model
embedding_cfg = settings.EMBEDDING_CFG
config_dict = {
@ -47,16 +48,22 @@ def get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_nam
"api_key": embedding_cfg["api_key"],
"llm_name": pure_model_name,
"api_base": embedding_cfg["base_url"],
"model_type": LLMType.EMBEDDING,
"model_type": LLMType.EMBEDDING.value,
}
else:
model_config = TenantLLMService.get_api_key(tenant_id, pure_model_name)
model_config = TenantLLMService.get_api_key(tenant_id, pure_model_name, model_type_val)
if not model_config:
raise LookupError(f"Tenant Model with name {model_name} not found")
raise LookupError(f"Tenant Model with name {model_name} and type {model_type_val} not found")
config_dict = model_config.to_dict()
else:
# model_name without @factory
config_dict = model_config.to_dict()
config_model_type = config_dict.get("model_type")
config_model_type = config_model_type.value if hasattr(config_model_type, "value") else config_model_type
if config_model_type != model_type_val:
raise LookupError(
f"Tenant Model with name {model_name} has type {config_model_type}, expected {model_type_val}"
)
llm = LLMService.query(llm_name=config_dict["llm_name"])
if llm:
config_dict["is_tools"] = llm[0].is_tools

View File

@ -36,12 +36,16 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def get_api_key(cls, tenant_id, model_name):
def get_api_key(cls, tenant_id, model_name, model_type=None):
mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
model_type_val = model_type.value if hasattr(model_type, "value") else model_type
query_kwargs = {"tenant_id": tenant_id, "llm_name": mdlnm}
if model_type_val is not None:
query_kwargs["model_type"] = model_type_val
if not fid:
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
objs = cls.query(**query_kwargs)
else:
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
objs = cls.query(**query_kwargs, llm_factory=fid)
if (not objs) and fid:
if fid == "LocalAI":
@ -52,7 +56,8 @@ class TenantLLMService(CommonService):
mdlnm += "___OpenAI-API"
elif fid == "VLLM":
mdlnm += "___VLLM"
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
query_kwargs["llm_name"] = mdlnm
objs = cls.query(**query_kwargs, llm_factory=fid)
if not objs:
return None
return objs[0]
@ -112,10 +117,10 @@ class TenantLLMService(CommonService):
else:
assert False, "LLM type error"
model_config = cls.get_api_key(tenant_id, mdlnm)
model_config = cls.get_api_key(tenant_id, mdlnm, llm_type)
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
if not model_config: # for some cases seems fid mismatch
model_config = cls.get_api_key(tenant_id, mdlnm)
model_config = cls.get_api_key(tenant_id, mdlnm, llm_type)
if model_config:
model_config = model_config.to_dict()
elif llm_type == LLMType.EMBEDDING and fid == "Builtin" and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv("TEI_MODEL", ""):

View File

@ -846,6 +846,9 @@ def test_retrieval_test_branch_matrix_unit(monkeypatch):
return {"id": "kg-2", "content_with_weight": ""}
monkeypatch.setattr(module, "LLMBundle", lambda *args, **kwargs: llm_calls.append((args, kwargs)) or SimpleNamespace())
monkeypatch.setattr(module, "get_model_config_by_type_and_name", lambda *_args, **_kwargs: {"llm_name": "stub-model", "model_type": "chat"})
monkeypatch.setattr(module, "get_tenant_default_model_by_type", lambda *_args, **_kwargs: {"llm_name": "stub-model", "model_type": "chat"})
monkeypatch.setattr(module, "get_model_config_by_id", lambda *_args, **_kwargs: {"llm_name": "stub-model", "model_type": "embedding"})
monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kb_ids: [{"meta": "v"}], raising=False)
monkeypatch.setattr(module, "apply_meta_data_filter", _apply_filter)
monkeypatch.setattr(module.SearchService, "get_detail", lambda _sid: {"search_config": {"meta_data_filter": {"method": "auto"}, "chat_id": "chat-1"}}, raising=False)

View File

@ -180,6 +180,16 @@ async def _read_sse_text(response):
return "".join(chunks)
@pytest.fixture(scope="session")
def auth():
return "unit-auth"
@pytest.fixture(scope="session", autouse=True)
def set_tenant_info():
return None
@pytest.mark.p2
def test_set_conversation_update_create_and_errors(monkeypatch):
module = _load_conversation_module(monkeypatch)
@ -532,13 +542,13 @@ def test_sequence2txt_validation_and_transcription_paths(monkeypatch):
wav_file = _DummyUploadedFile("audio.wav")
monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": wav_file}))
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (False, None))
monkeypatch.setattr(module, "get_tenant_default_model_by_type", lambda *_args, **_kwargs: (_ for _ in ()).throw(LookupError("Tenant not found")))
res = _run(module.sequence2txt())
assert res["message"] == "Tenant not found"
wav_file = _DummyUploadedFile("audio.wav")
monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": wav_file}))
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (True, SimpleNamespace(tenant_id="tenant-1", asr_id="")))
monkeypatch.setattr(module, "get_tenant_default_model_by_type", lambda *_args, **_kwargs: (_ for _ in ()).throw(Exception("No default speech2text model is set.")))
res = _run(module.sequence2txt())
assert res["message"] == "No default speech2text model is set."
@ -551,8 +561,11 @@ def test_sequence2txt_validation_and_transcription_paths(monkeypatch):
wav_file = _DummyUploadedFile("audio.wav")
monkeypatch.setattr(module, "request", _DummyRequest(form={"stream": "false"}, files={"file": wav_file}))
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (True, SimpleNamespace(tenant_id="tenant-1", asr_id="asr-model")))
monkeypatch.setattr(module.TenantLLMService, "get_api_key", lambda tenant_id, model_name: SimpleNamespace(to_dict=lambda: {"llm_factory": "test", "llm_name": "asr-model"}))
monkeypatch.setattr(
module,
"get_tenant_default_model_by_type",
lambda *_args, **_kwargs: {"llm_factory": "test", "llm_name": "asr-model", "model_type": module.LLMType.SPEECH2TEXT.value},
)
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _SyncAsr())
monkeypatch.setattr(module.os, "remove", lambda _path: (_ for _ in ()).throw(RuntimeError("remove failed")))
res = _run(module.sequence2txt())
@ -593,11 +606,11 @@ def test_sequence2txt_validation_and_transcription_paths(monkeypatch):
def test_tts_request_parse_entry(monkeypatch):
module = _load_conversation_module(monkeypatch)
_set_request_json(monkeypatch, module, {"text": "A。B"})
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (False, None))
monkeypatch.setattr(module, "get_tenant_default_model_by_type", lambda *_args, **_kwargs: (_ for _ in ()).throw(LookupError("Tenant not found")))
res = _run(module.tts())
assert res["message"] == "Tenant not found"
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (True, SimpleNamespace(tenant_id="tenant-1", tts_id="")))
monkeypatch.setattr(module, "get_tenant_default_model_by_type", lambda *_args, **_kwargs: (_ for _ in ()).throw(Exception("No default tts model is set.")))
res = _run(module.tts())
assert res["message"] == "No default tts model is set."
@ -607,8 +620,11 @@ def test_tts_request_parse_entry(monkeypatch):
return []
yield f"chunk-{txt}".encode("utf-8")
monkeypatch.setattr(sys.modules["api.db.joint_services.tenant_model_service"].TenantService, "get_by_id", lambda _uid: (True, SimpleNamespace(tenant_id="tenant-1", tts_id="tts-x")))
monkeypatch.setattr(module.TenantLLMService, "get_api_key", lambda tenant_id, model_name: SimpleNamespace(to_dict=lambda: {"llm_factory": "test", "llm_name": model_name}))
monkeypatch.setattr(
module,
"get_tenant_default_model_by_type",
lambda *_args, **_kwargs: {"llm_factory": "test", "llm_name": "tts-x", "model_type": module.LLMType.TTS.value},
)
monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _TTSOk())
resp = _run(module.tts())
assert resp.mimetype == "audio/mpeg"
@ -770,7 +786,11 @@ def test_mindmap_and_related_questions_matrix_unit(monkeypatch):
monkeypatch.setattr(module, "LLMBundle", _fake_bundle)
monkeypatch.setattr(module, "load_prompt", lambda name: f"prompt-{name}")
monkeypatch.setattr(module.TenantLLMService, "get_api_key", lambda tenant_id, model_name: SimpleNamespace(to_dict=lambda: {"llm_factory": "test", "llm_name": model_name}))
monkeypatch.setattr(
module,
"get_model_config_by_type_and_name",
lambda *_args, **_kwargs: {"llm_factory": "test", "llm_name": "chat-x", "model_type": module.LLMType.CHAT.value},
)
_set_request_json(monkeypatch, module, {"question": "solar", "search_id": "search-1"})
res = _run(module.related_questions.__wrapped__())
assert res["code"] == 0

View File

@ -1064,6 +1064,11 @@ def test_unbind_task_branch_matrix(monkeypatch):
def test_check_embedding_similarity_threshold_matrix_unit(monkeypatch):
module = _load_kb_module(monkeypatch)
route = inspect.unwrap(module.check_embedding)
monkeypatch.setattr(
module,
"get_model_config_by_type_and_name",
lambda *_args, **_kwargs: {"llm_factory": "test", "llm_name": "emb-1", "model_type": module.LLMType.EMBEDDING.value},
)
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-1")))
monkeypatch.setattr(module.search, "index_name", lambda _tenant_id: "idx")
@ -1228,6 +1233,11 @@ def test_check_embedding_similarity_threshold_matrix_unit(monkeypatch):
def test_check_embedding_error_and_empty_sample_paths_unit(monkeypatch):
module = _load_kb_module(monkeypatch)
route = inspect.unwrap(module.check_embedding)
monkeypatch.setattr(
module,
"get_model_config_by_type_and_name",
lambda *_args, **_kwargs: {"llm_factory": "test", "llm_name": "emb-1", "model_type": module.LLMType.EMBEDDING.value},
)
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, SimpleNamespace(tenant_id="tenant-1")))
monkeypatch.setattr(module.search, "index_name", lambda _tenant_id: "idx")
monkeypatch.setattr(module.random, "sample", lambda population, k: list(population)[:k])