mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-05 01:37:46 +08:00
Fix: model selection rule in get_model_config_by_type_and_name (#13569)
### What problem does this PR solve? Fix: model selection rule in get_model_config_by_type_and_name ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -35,11 +35,12 @@ def get_model_config_by_id(tenant_model_id: int) -> dict:
|
||||
def get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str):
|
||||
if not model_name:
|
||||
raise Exception("Model Name is required")
|
||||
model_config = TenantLLMService.get_api_key(tenant_id, model_name)
|
||||
model_type_val = model_type.value if hasattr(model_type, "value") else model_type
|
||||
model_config = TenantLLMService.get_api_key(tenant_id, model_name, model_type_val)
|
||||
if not model_config:
|
||||
# model_name in format 'name@factory', split model_name and try again
|
||||
pure_model_name, fid = TenantLLMService.split_model_name_and_factory(model_name)
|
||||
if model_type == LLMType.EMBEDDING and fid == "Builtin" and "tei-" in os.getenv("COMPOSE_PROFILES", "") and pure_model_name == os.getenv("TEI_MODEL", ""):
|
||||
if model_type_val == LLMType.EMBEDDING.value and fid == "Builtin" and "tei-" in os.getenv("COMPOSE_PROFILES", "") and pure_model_name == os.getenv("TEI_MODEL", ""):
|
||||
# configured local embedding model
|
||||
embedding_cfg = settings.EMBEDDING_CFG
|
||||
config_dict = {
|
||||
@ -47,16 +48,22 @@ def get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_nam
|
||||
"api_key": embedding_cfg["api_key"],
|
||||
"llm_name": pure_model_name,
|
||||
"api_base": embedding_cfg["base_url"],
|
||||
"model_type": LLMType.EMBEDDING,
|
||||
"model_type": LLMType.EMBEDDING.value,
|
||||
}
|
||||
else:
|
||||
model_config = TenantLLMService.get_api_key(tenant_id, pure_model_name)
|
||||
model_config = TenantLLMService.get_api_key(tenant_id, pure_model_name, model_type_val)
|
||||
if not model_config:
|
||||
raise LookupError(f"Tenant Model with name {model_name} not found")
|
||||
raise LookupError(f"Tenant Model with name {model_name} and type {model_type_val} not found")
|
||||
config_dict = model_config.to_dict()
|
||||
else:
|
||||
# model_name without @factory
|
||||
config_dict = model_config.to_dict()
|
||||
config_model_type = config_dict.get("model_type")
|
||||
config_model_type = config_model_type.value if hasattr(config_model_type, "value") else config_model_type
|
||||
if config_model_type != model_type_val:
|
||||
raise LookupError(
|
||||
f"Tenant Model with name {model_name} has type {config_model_type}, expected {model_type_val}"
|
||||
)
|
||||
llm = LLMService.query(llm_name=config_dict["llm_name"])
|
||||
if llm:
|
||||
config_dict["is_tools"] = llm[0].is_tools
|
||||
|
||||
@ -36,12 +36,16 @@ class TenantLLMService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_api_key(cls, tenant_id, model_name):
|
||||
def get_api_key(cls, tenant_id, model_name, model_type=None):
|
||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
|
||||
model_type_val = model_type.value if hasattr(model_type, "value") else model_type
|
||||
query_kwargs = {"tenant_id": tenant_id, "llm_name": mdlnm}
|
||||
if model_type_val is not None:
|
||||
query_kwargs["model_type"] = model_type_val
|
||||
if not fid:
|
||||
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
|
||||
objs = cls.query(**query_kwargs)
|
||||
else:
|
||||
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
|
||||
objs = cls.query(**query_kwargs, llm_factory=fid)
|
||||
|
||||
if (not objs) and fid:
|
||||
if fid == "LocalAI":
|
||||
@ -52,7 +56,8 @@ class TenantLLMService(CommonService):
|
||||
mdlnm += "___OpenAI-API"
|
||||
elif fid == "VLLM":
|
||||
mdlnm += "___VLLM"
|
||||
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
|
||||
query_kwargs["llm_name"] = mdlnm
|
||||
objs = cls.query(**query_kwargs, llm_factory=fid)
|
||||
if not objs:
|
||||
return None
|
||||
return objs[0]
|
||||
@ -112,10 +117,10 @@ class TenantLLMService(CommonService):
|
||||
else:
|
||||
assert False, "LLM type error"
|
||||
|
||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
||||
model_config = cls.get_api_key(tenant_id, mdlnm, llm_type)
|
||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
|
||||
if not model_config: # for some cases seems fid mismatch
|
||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
||||
model_config = cls.get_api_key(tenant_id, mdlnm, llm_type)
|
||||
if model_config:
|
||||
model_config = model_config.to_dict()
|
||||
elif llm_type == LLMType.EMBEDDING and fid == "Builtin" and "tei-" in os.getenv("COMPOSE_PROFILES", "") and mdlnm == os.getenv("TEI_MODEL", ""):
|
||||
|
||||
Reference in New Issue
Block a user