Mirror of https://github.com/infiniflow/ragflow.git (synced 2026-05-05 09:47:47 +08:00)
# Feat: Support siliconflow.com (#13308)
### What problem does this PR solve?

Feat: Support siliconflow.com

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
```diff
@@ -34,7 +34,7 @@ from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, O
 def factories():
     try:
         fac = get_allowed_llm_factories()
-        fac = [f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI", "Builtin"]]
+        fac = [f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI", "Builtin", "siliconflow_intl"]]
         llms = LLMService.get_all()
         mdl_types = {}
         for m in llms:
```
```diff
@@ -64,13 +64,22 @@ async def set_api_key():
     # test if api key works
     chat_passed, embd_passed, rerank_passed = False, False, False
     factory = req["llm_factory"]
+    base_url = req.get("base_url", "")
+    source_factory = req.get("source_fid", factory)
     extra = {"provider": factory}
     timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10))
+    source_llms = list(LLMService.query(fid=source_factory))
+    if not source_llms:
+        msg = f"No models configured for {factory} (source: {source_factory})."
+        if req.get("verify", False):
+            return get_json_result(data={"message": msg, "success": False})
+        return get_data_error_result(message=msg)
+
     msg = ""
-    for llm in LLMService.query(fid=factory):
+    for llm in source_llms:
         if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
             assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
-            mdl = EmbeddingModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"))
+            mdl = EmbeddingModel[factory](req["api_key"], llm.llm_name, base_url=base_url)
             try:
                 arr, tc = await asyncio.wait_for(
                     asyncio.to_thread(mdl.encode, ["Test if the api key is available"]),
```
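Each probe above pushes a blocking SDK call onto a worker thread with `asyncio.to_thread` and bounds it with `asyncio.wait_for`, so one unresponsive provider cannot hang key verification past `LLM_TIMEOUT_SECONDS`. A minimal runnable sketch of that pattern, with `slow_encode` and `probe` as hypothetical stand-ins for the real model classes:

```python
import asyncio
import time


def slow_encode(texts: list[str]) -> list[list[float]]:
    """Hypothetical stand-in for a blocking embedding call such as mdl.encode()."""
    time.sleep(1)  # pretend to hit the remote embedding endpoint
    return [[0.0] * 8 for _ in texts]


async def probe(timeout_seconds: int = 10) -> bool:
    """Return True if the (stubbed) provider answers within the timeout."""
    try:
        await asyncio.wait_for(
            asyncio.to_thread(slow_encode, ["Test if the api key is available"]),
            timeout=timeout_seconds,
        )
        return True
    except asyncio.TimeoutError:
        return False


if __name__ == "__main__":
    print(asyncio.run(probe()))  # prints True: the stub responds in ~1s
```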
```diff
@@ -83,7 +92,7 @@ async def set_api_key():
                 msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." + str(e)
         elif not chat_passed and llm.model_type == LLMType.CHAT.value:
             assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
-            mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
+            mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=base_url, **extra)
             try:
                 m, tc = await asyncio.wait_for(
                     mdl.async_chat(
```
```diff
@@ -100,7 +109,7 @@ async def set_api_key():
                 msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e)
         elif not rerank_passed and llm.model_type == LLMType.RERANK.value:
             assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet."
-            mdl = RerankModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"))
+            mdl = RerankModel[factory](req["api_key"], llm.llm_name, base_url=base_url)
             try:
                 arr, tc = await asyncio.wait_for(
                     asyncio.to_thread(mdl.similarity, "What's the weather?", ["Is it sunny today?"]),
```
```diff
@@ -122,12 +131,12 @@ async def set_api_key():
     if msg:
         return get_data_error_result(message=msg)
 
-    llm_config = {"api_key": req["api_key"], "api_base": req.get("base_url", "")}
+    llm_config = {"api_key": req["api_key"], "api_base": base_url}
     for n in ["model_type", "llm_name"]:
         if n in req:
             llm_config[n] = req[n]
 
-    for llm in LLMService.query(fid=factory):
+    for llm in source_llms:
         llm_config["max_tokens"] = llm.max_tokens
         if not TenantLLMService.filter_update([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory, TenantLLM.llm_name == llm.llm_name], llm_config):
             TenantLLMService.save(
```
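Taken together, the new `source_fid` field lets one factory entry borrow another factory's model catalog: `set_api_key` probes the models queried under `source_fid`, then persists the credentials under `llm_factory`. A hedged illustration of a request body this enables; the specific factory names, key, and endpoint below are assumptions for illustration, not values confirmed by this diff:

```python
# Hypothetical /set_api_key request body after this change.
req = {
    # Assumed: "siliconflow_intl" is the hidden international entry that the
    # factories() filter above excludes from the provider listing.
    "llm_factory": "siliconflow_intl",
    # Assumed: borrow the main SiliconFlow catalog; models are queried under
    # source_fid but saved under llm_factory.
    "source_fid": "SiliconFlow",
    "api_key": "sk-...",  # placeholder key
    "base_url": "https://api.siliconflow.com/v1",  # assumed endpoint form
    # With verify=True, an empty source catalog returns a JSON success flag
    # (per the early return added above) instead of a hard error result.
    "verify": True,
}
```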