Mirror of https://github.com/infiniflow/ragflow.git (synced 2026-05-04 01:07:48 +08:00)
Refa: improve model verification ux (#13392)
### What problem does this PR solve?

Improve model verification UX. #13395

### Type of change

- [x] Refactoring

---------

Co-authored-by: Liu An <asiro@qq.com>
```diff
@@ -94,17 +94,21 @@ async def set_api_key():
             assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
             mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=base_url, **extra)
             try:
-                m, tc = await asyncio.wait_for(
-                    mdl.async_chat(
+                async def check_streamly():
+                    async for chunk in mdl.async_chat_streamly(
                         None,
-                        [{"role": "user", "content": "Hello! How are you doing!"}],
-                        {"temperature": 0.9, "max_tokens": 50},
-                    ),
-                    timeout=timeout_seconds,
-                )
-                if m.find("**ERROR**") >= 0:
-                    raise Exception(m)
-                chat_passed = True
+                        [{"role": "user", "content": "Hi"}],
+                        {"temperature": 0.9},
+                    ):
+                        if chunk and isinstance(chunk, str) and chunk.find("**ERROR**") < 0:
+                            return True
+                    return False
+
+                result = await asyncio.wait_for(check_streamly(), timeout=timeout_seconds)
+                if result:
+                    chat_passed = True
+                else:
+                    raise Exception("No valid response received")
             except Exception as e:
                 msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e)
         elif not rerank_passed and llm.model_type == LLMType.RERANK.value:
```
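The streaming check can be read in isolation: instead of awaiting a full completion via `async_chat`, the new code passes as soon as the first clean chunk arrives, so verification feels fast even for slow models. Below is a minimal runnable sketch of the same pattern; `StubModel` and `verify_model` are hypothetical names standing in for a ragflow `ChatModel`, and only the control flow mirrors the diff:

```python
import asyncio


class StubModel:
    """Hypothetical stand-in for a ChatModel; yields text chunks
    the way async_chat_streamly does in the diff above."""

    async def async_chat_streamly(self, system, history, gen_conf):
        for chunk in ["Hel", "lo"]:
            await asyncio.sleep(0.01)  # simulate per-chunk network latency
            yield chunk


async def verify_model(mdl, timeout_seconds=30):
    async def check_streamly():
        # Pass on the first non-error chunk; no need to wait for the
        # full completion as the old async_chat call did.
        async for chunk in mdl.async_chat_streamly(
            None,
            [{"role": "user", "content": "Hi"}],
            {"temperature": 0.9},
        ):
            if chunk and isinstance(chunk, str) and chunk.find("**ERROR**") < 0:
                return True
        return False

    # wait_for cancels the check if the provider stalls past the deadline.
    return await asyncio.wait_for(check_streamly(), timeout=timeout_seconds)


print(asyncio.run(verify_model(StubModel())))  # True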
```diff
@@ -127,7 +131,7 @@ async def set_api_key():
 
     if req.get("verify", False):
         return get_json_result(data={"message": msg, "success": len(msg.strip())==0})
 
 
     if msg:
         return get_data_error_result(message=msg)
```
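With `verify` set, validation failures come back in a normal success envelope carrying a `success` flag rather than an error result, so a client can branch on the flag and show the message inline. A hedged sketch of how a caller might consume that shape; the endpoint path, port, and request fields are assumptions for illustration, and only the `{"message": ..., "success": ...}` payload shape is taken from the diff:

```python
import requests

# Hypothetical endpoint path and request body; only the payload shape
# {"message": ..., "success": ...} comes from the diff above.
resp = requests.post(
    "http://localhost:9380/v1/llm/set_api_key",
    json={"llm_factory": "OpenAI", "api_key": "<your key>", "verify": True},
)
data = resp.json().get("data", {})
if data.get("success"):
    print("Model verified.")
else:
    print("Verification failed:", data.get("message"))
```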
```diff
@@ -260,16 +264,19 @@ async def add_llm():
                 **extra,
             )
             try:
-                m, tc = await asyncio.wait_for(
-                    mdl.async_chat(
+                async def check_streamly():
+                    async for chunk in mdl.async_chat_streamly(
                         None,
-                        [{"role": "user", "content": "Hello! How are you doing!"}],
+                        [{"role": "user", "content": "Hi"}],
                         {"temperature": 0.9},
-                    ),
-                    timeout=timeout_seconds,
-                )
-                if not tc and m.find("**ERROR**:") >= 0:
-                    raise Exception(m)
+                    ):
+                        if chunk and isinstance(chunk, str) and chunk.find("**ERROR**:") < 0:
+                            return True
+                    return False
+
+                result = await asyncio.wait_for(check_streamly(), timeout=timeout_seconds)
+                if not result:
+                    raise Exception("No valid response received")
             except Exception as e:
                 msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
 
```
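In both hunks, `asyncio.wait_for` is what turns a stalled provider into a user-visible failure: when the deadline passes it cancels `check_streamly()` and raises a timeout error, which is an `Exception` subclass and therefore lands in the same `except` branch that folds provider errors into `msg`. A small self-contained demonstration of that behavior; `stalled_check` is illustrative only:

```python
import asyncio


async def stalled_check():
    await asyncio.sleep(60)  # provider never produces a chunk
    return True


async def main():
    msg = ""
    try:
        await asyncio.wait_for(stalled_check(), timeout=0.1)
    except Exception as e:
        # asyncio's timeout error subclasses Exception, so it is caught
        # here just like a provider error would be.
        msg += "\nFail to access model: " + repr(e)
    print(msg)


asyncio.run(main())
```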
```diff
@@ -339,7 +346,7 @@ async def add_llm():
 
     if req.get("verify", False):
         return get_json_result(data={"message": msg, "success": len(msg.strip()) == 0})
 
 
     if msg:
         return get_data_error_result(message=msg)
```
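Because the pass/fail logic is a plain coroutine over an async iterator, it is easy to exercise with fakes. A unit-test-style sketch under the same chunk convention as the diff; both fake classes are illustrative and not part of the repository:

```python
import asyncio


class FakeErrorModel:
    """Illustrative fake that only streams an error marker."""

    async def async_chat_streamly(self, system, history, gen_conf):
        yield "**ERROR**: invalid api key"


class FakeOkModel:
    """Illustrative fake that streams a normal reply."""

    async def async_chat_streamly(self, system, history, gen_conf):
        yield "Hi there!"


async def check_streamly(mdl):
    # Same logic as the diff: succeed on the first non-error chunk.
    async for chunk in mdl.async_chat_streamly(None, [{"role": "user", "content": "Hi"}], {"temperature": 0.9}):
        if chunk and isinstance(chunk, str) and chunk.find("**ERROR**") < 0:
            return True
    return False


assert asyncio.run(check_streamly(FakeErrorModel())) is False
assert asyncio.run(check_streamly(FakeOkModel())) is True
```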