mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-06 02:07:49 +08:00
Fix: upload image files (#13071)
### What problem does this PR solve? Fix: upload image files ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -180,10 +180,24 @@ class DialogService(CommonService):
|
||||
|
||||
|
||||
async def async_chat_solo(dialog, messages, stream=True):
|
||||
llm_type = TenantLLMService.llm_id2llm_type(dialog.llm_id)
|
||||
attachments = ""
|
||||
image_attachments = []
|
||||
image_files = []
|
||||
if "files" in messages[-1]:
|
||||
attachments = "\n\n".join(FileService.get_files(messages[-1]["files"]))
|
||||
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||
if llm_type == "chat":
|
||||
text_attachments, image_attachments = split_file_attachments(messages[-1]["files"])
|
||||
else:
|
||||
text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True)
|
||||
attachments = "\n\n".join(text_attachments)
|
||||
|
||||
if llm_type == "image2text":
|
||||
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
else:
|
||||
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
factory = llm_model_config.get("llm_factory", "") if llm_model_config else ""
|
||||
|
||||
if llm_type == "image2text":
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
else:
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
@ -195,8 +209,13 @@ async def async_chat_solo(dialog, messages, stream=True):
|
||||
msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
|
||||
if attachments and msg:
|
||||
msg[-1]["content"] += attachments
|
||||
if llm_type == "chat" and image_attachments:
|
||||
convert_last_user_msg_to_multimodal(msg, image_attachments, factory)
|
||||
if stream:
|
||||
stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting)
|
||||
if llm_type == "chat":
|
||||
stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting)
|
||||
else:
|
||||
stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting, images=image_files)
|
||||
async for kind, value, state in _stream_with_think_delta(stream_iter):
|
||||
if kind == "marker":
|
||||
flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
|
||||
@ -204,7 +223,10 @@ async def async_chat_solo(dialog, messages, stream=True):
|
||||
continue
|
||||
yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "prompt": "", "created_at": time.time(), "final": False}
|
||||
else:
|
||||
answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
|
||||
if llm_type == "chat":
|
||||
answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
|
||||
else:
|
||||
answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting, images=image_files)
|
||||
user_content = msg[-1].get("content", "[content not available]")
|
||||
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
|
||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
|
||||
@ -235,6 +257,120 @@ def get_models(dialog):
|
||||
return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl
|
||||
|
||||
|
||||
def split_file_attachments(files: list[dict] | None, raw: bool = False) -> tuple[list[str], list[str] | list[dict]]:
|
||||
if not files:
|
||||
return [], []
|
||||
|
||||
text_attachments = []
|
||||
if raw:
|
||||
file_contents, image_files = FileService.get_files(files, raw=True)
|
||||
for content in file_contents:
|
||||
if not isinstance(content, str):
|
||||
content = str(content)
|
||||
text_attachments.append(content)
|
||||
return text_attachments, image_files
|
||||
|
||||
image_attachments = []
|
||||
for content in FileService.get_files(files, raw=False):
|
||||
if not isinstance(content, str):
|
||||
content = str(content)
|
||||
if content.strip().startswith("data:"):
|
||||
image_attachments.append(content.strip())
|
||||
continue
|
||||
text_attachments.append(content)
|
||||
return text_attachments, image_attachments
|
||||
|
||||
|
||||
_DATA_URI_RE = re.compile(r"^data:(?P<mime>[^;]+);base64,(?P<b64>[A-Za-z0-9+/=\s]+)$")
|
||||
|
||||
|
||||
def _parse_data_uri_or_b64(s: str, default_mime: str = "image/png") -> tuple[str, str]:
|
||||
s = (s or "").strip()
|
||||
match = _DATA_URI_RE.match(s)
|
||||
if match:
|
||||
mime = match.group("mime").strip()
|
||||
b64 = match.group("b64").strip()
|
||||
return mime, b64
|
||||
return default_mime, s
|
||||
|
||||
|
||||
def _normalize_text_from_content(content) -> str:
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
texts = []
|
||||
for blk in content:
|
||||
if isinstance(blk, dict):
|
||||
if blk.get("type") in {"text", "input_text"}:
|
||||
txt = blk.get("text")
|
||||
if txt:
|
||||
texts.append(str(txt))
|
||||
elif "text" in blk and isinstance(blk.get("text"), (str, int, float)):
|
||||
texts.append(str(blk["text"]))
|
||||
return "\n".join(texts).strip()
|
||||
return str(content)
|
||||
|
||||
|
||||
def convert_last_user_msg_to_multimodal(msg: list[dict], image_data_uris: list[str], factory: str) -> None:
|
||||
if not msg or not image_data_uris:
|
||||
return
|
||||
|
||||
factory_norm = (factory or "").strip().lower()
|
||||
|
||||
for idx in range(len(msg) - 1, -1, -1):
|
||||
if msg[idx].get("role") != "user":
|
||||
continue
|
||||
|
||||
original_content = msg[idx].get("content", "")
|
||||
text = _normalize_text_from_content(original_content)
|
||||
|
||||
if factory_norm == "gemini":
|
||||
parts = []
|
||||
if text:
|
||||
parts.append({"text": text})
|
||||
for image in image_data_uris:
|
||||
mime, b64 = _parse_data_uri_or_b64(str(image), default_mime="image/png")
|
||||
parts.append({"inline_data": {"mime_type": mime, "data": b64}})
|
||||
msg[idx]["content"] = parts
|
||||
return
|
||||
|
||||
if factory_norm == "anthropic":
|
||||
blocks = []
|
||||
if text:
|
||||
blocks.append({"type": "text", "text": text})
|
||||
for image in image_data_uris:
|
||||
mime, b64 = _parse_data_uri_or_b64(str(image), default_mime="image/png")
|
||||
blocks.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": mime, "data": b64},
|
||||
}
|
||||
)
|
||||
msg[idx]["content"] = blocks
|
||||
return
|
||||
|
||||
multimodal_content = []
|
||||
if isinstance(original_content, list):
|
||||
multimodal_content = deepcopy(original_content)
|
||||
else:
|
||||
text_content = "" if original_content is None else str(original_content)
|
||||
if text_content:
|
||||
multimodal_content.append({"type": "text", "text": text_content})
|
||||
|
||||
for data_uri in image_data_uris:
|
||||
image_url = data_uri
|
||||
if not isinstance(image_url, str):
|
||||
image_url = str(image_url)
|
||||
if not image_url.startswith("data:"):
|
||||
image_url = f"data:image/png;base64,{image_url}"
|
||||
multimodal_content.append({"type": "image_url", "image_url": {"url": image_url}})
|
||||
|
||||
msg[idx]["content"] = multimodal_content
|
||||
return
|
||||
|
||||
|
||||
BAD_CITATION_PATTERNS = [
|
||||
re.compile(r"\(\s*ID\s*[: ]*\s*(\d+)\s*\)"), # (ID: 12)
|
||||
re.compile(r"\[\s*ID\s*[: ]*\s*(\d+)\s*\]"), # [ID: 12]
|
||||
@ -281,12 +417,13 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
return
|
||||
|
||||
chat_start_ts = timer()
|
||||
|
||||
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||
llm_type = TenantLLMService.llm_id2llm_type(dialog.llm_id)
|
||||
if llm_type == "image2text":
|
||||
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
else:
|
||||
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
|
||||
factory = llm_model_config.get("llm_factory", "") if llm_model_config else ""
|
||||
max_tokens = llm_model_config.get("max_tokens", 8192)
|
||||
|
||||
check_llm_ts = timer()
|
||||
@ -316,10 +453,16 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
|
||||
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
|
||||
attachments_= ""
|
||||
image_attachments = []
|
||||
image_files = []
|
||||
if "doc_ids" in messages[-1]:
|
||||
attachments = messages[-1]["doc_ids"]
|
||||
if "files" in messages[-1]:
|
||||
attachments_ = "\n\n".join(FileService.get_files(messages[-1]["files"]))
|
||||
if llm_type == "chat":
|
||||
text_attachments, image_attachments = split_file_attachments(messages[-1]["files"])
|
||||
else:
|
||||
text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True)
|
||||
attachments_ = "\n\n".join(text_attachments)
|
||||
|
||||
prompt_config = dialog.prompt_config
|
||||
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
||||
@ -464,6 +607,8 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
prompt4citation = citation_prompt()
|
||||
msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"])
|
||||
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95))
|
||||
if llm_type == "chat" and image_attachments:
|
||||
convert_last_user_msg_to_multimodal(msg, image_attachments, factory)
|
||||
assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
|
||||
prompt = msg[0]["content"]
|
||||
|
||||
@ -555,7 +700,10 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
)
|
||||
|
||||
if stream:
|
||||
stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf)
|
||||
if llm_type == "chat":
|
||||
stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf)
|
||||
else:
|
||||
stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf, images=image_files)
|
||||
last_state = None
|
||||
async for kind, value, state in _stream_with_think_delta(stream_iter):
|
||||
last_state = state
|
||||
@ -572,7 +720,10 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
final["answer"] = ""
|
||||
yield final
|
||||
else:
|
||||
answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
|
||||
if llm_type == "chat":
|
||||
answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
|
||||
else:
|
||||
answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf, images=image_files)
|
||||
user_content = msg[-1].get("content", "[content not available]")
|
||||
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
|
||||
res = decorate_answer(answer)
|
||||
|
||||
@ -663,7 +663,7 @@ class FileService(CommonService):
|
||||
return structured(file.filename, filename_type(file.filename), file.read(), file.content_type)
|
||||
|
||||
@staticmethod
|
||||
def get_files(files: Union[None, list[dict]]) -> list[str]:
|
||||
def get_files(files: Union[None, list[dict]], raw: bool = False) -> Union[list[str], tuple[list[str], list[dict]]]:
|
||||
if not files:
|
||||
return []
|
||||
def image_to_base64(file):
|
||||
@ -671,10 +671,17 @@ class FileService(CommonService):
|
||||
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
|
||||
exe = ThreadPoolExecutor(max_workers=5)
|
||||
threads = []
|
||||
imgs = []
|
||||
for file in files:
|
||||
if file["mime_type"].find("image") >=0:
|
||||
threads.append(exe.submit(image_to_base64, file))
|
||||
if raw:
|
||||
imgs.append(FileService.get_blob(file["created_by"], file["id"]))
|
||||
else:
|
||||
threads.append(exe.submit(image_to_base64, file))
|
||||
continue
|
||||
threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
|
||||
return [th.result() for th in threads]
|
||||
|
||||
|
||||
if raw:
|
||||
return [th.result() for th in threads], imgs
|
||||
else:
|
||||
return [th.result() for th in threads]
|
||||
|
||||
Reference in New Issue
Block a user