mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-05 01:37:46 +08:00
Fix: upload image files (#13071)
### What problem does this PR solve?

Fix: upload image files

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -125,23 +125,118 @@ class LLM(ComponentBase):
|
||||
msg.append(p)
|
||||
return msg, self.string_format(self._param.sys_prompt, args)
|
||||
|
||||
def _prepare_prompt_variables(self):
|
||||
if self._param.visual_files_var:
|
||||
self.imgs = self._canvas.get_variable_value(self._param.visual_files_var)
|
||||
if not self.imgs:
|
||||
self.imgs = []
|
||||
self.imgs = [img for img in self.imgs if img[:len("data:image/")] == "data:image/"]
|
||||
if self.imgs and TenantLLMService.llm_id2llm_type(self._param.llm_id) == LLMType.CHAT.value:
|
||||
self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT.value,
|
||||
self._param.llm_id, max_retries=self._param.max_retries,
|
||||
retry_interval=self._param.delay_after_error
|
||||
)
|
||||
@staticmethod
def _extract_data_images(value) -> list[str]:
    """Collect every ``data:image/`` URI found anywhere inside *value*.

    Strings are stripped before the prefix test. Lists, tuples and sets are
    searched element by element. For a dict, only the ``"content"`` entry is
    searched when that key exists; otherwise every value is searched. Any
    other type (including ``None``) contributes nothing.

    Returns the matches in traversal order, duplicates included.
    """

    def _iter_uris(node):
        if isinstance(node, str):
            candidate = node.strip()
            if candidate.startswith("data:image/"):
                yield candidate
        elif isinstance(node, (list, tuple, set)):
            for child in node:
                yield from _iter_uris(child)
        elif isinstance(node, dict):
            # A "content" key takes precedence over all other entries.
            children = [node.get("content")] if "content" in node else node.values()
            for child in children:
                yield from _iter_uris(child)

    return list(_iter_uris(value))
|
||||
|
||||
@staticmethod
def _uniq_images(images: list[str]) -> list[str]:
    """Return the distinct ``data:image/`` strings of *images*, first occurrence first.

    Entries that are not strings, or that do not start with ``data:image/``,
    are dropped.
    """
    eligible = (
        entry
        for entry in images
        if isinstance(entry, str) and entry.startswith("data:image/")
    )
    # dict keys keep insertion order, so this de-duplicates without reordering.
    return list(dict.fromkeys(eligible))
|
||||
|
||||
@classmethod
def _remove_data_images(cls, value):
    """Return a copy of *value* with every ``data:image/`` payload stripped out.

    * ``None`` and strings that (after stripping) start with ``data:image/``
      become ``None``; other strings are returned unchanged.
    * Lists, tuples and sets are cleaned element by element; elements that
      clean to ``None`` or to an empty container are dropped. A tuple stays a
      tuple, while both lists and sets come back as a list.
    * A dict whose ``"type"`` is ``image_url``, ``input_image`` or ``image``
      and which contains at least one data image becomes ``None``. Any other
      dict is rebuilt from its cleaned values, dropping ``None``/empty ones.
    * Every other value is returned as is.
    """

    def _is_discardable(item):
        # Nothing left after cleaning: either removed outright or an empty container.
        return item is None or (isinstance(item, (list, tuple, set, dict)) and not item)

    if value is None:
        return None

    if isinstance(value, str):
        return value if not value.strip().startswith("data:image/") else None

    if isinstance(value, (list, tuple, set)):
        survivors = []
        for element in value:
            stripped = cls._remove_data_images(element)
            if not _is_discardable(stripped):
                survivors.append(stripped)
        return tuple(survivors) if isinstance(value, tuple) else survivors

    if isinstance(value, dict):
        if value.get("type") in {"image_url", "input_image", "image"} and cls._extract_data_images(value):
            return None

        rebuilt = {}
        for key, element in value.items():
            stripped = cls._remove_data_images(element)
            if not _is_discardable(stripped):
                rebuilt[key] = stripped
        return rebuilt

    return value
|
||||
|
||||
def _prepare_prompt_variables(self):
|
||||
self.imgs = []
|
||||
if self._param.visual_files_var:
|
||||
self.imgs.extend(self._extract_data_images(self._canvas.get_variable_value(self._param.visual_files_var)))
|
||||
|
||||
args = {}
|
||||
vars = self.get_input_elements() if not self._param.debug_inputs else self._param.debug_inputs
|
||||
extracted_imgs = []
|
||||
for k, o in vars.items():
|
||||
args[k] = o["value"]
|
||||
raw_value = o["value"]
|
||||
extracted_imgs.extend(self._extract_data_images(raw_value))
|
||||
args[k] = self._remove_data_images(raw_value)
|
||||
if args[k] is None:
|
||||
args[k] = ""
|
||||
if not isinstance(args[k], str):
|
||||
try:
|
||||
args[k] = json.dumps(args[k], ensure_ascii=False)
|
||||
@ -149,6 +244,13 @@ class LLM(ComponentBase):
|
||||
args[k] = str(args[k])
|
||||
self.set_input_value(k, args[k])
|
||||
|
||||
self.imgs = self._uniq_images(self.imgs + extracted_imgs)
|
||||
if self.imgs and TenantLLMService.llm_id2llm_type(self._param.llm_id) == LLMType.CHAT.value:
|
||||
self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT.value,
|
||||
self._param.llm_id, max_retries=self._param.max_retries,
|
||||
retry_interval=self._param.delay_after_error
|
||||
)
|
||||
|
||||
msg, sys_prompt = self._sys_prompt_and_msg(self._canvas.get_history(self._param.message_history_window_size)[:-1], args)
|
||||
user_defined_prompt, sys_prompt = self._extract_prompts(sys_prompt)
|
||||
if self._param.cite and self._canvas.get_reference()["chunks"]:
|
||||
|
||||
@ -180,10 +180,24 @@ class DialogService(CommonService):
|
||||
|
||||
|
||||
async def async_chat_solo(dialog, messages, stream=True):
|
||||
llm_type = TenantLLMService.llm_id2llm_type(dialog.llm_id)
|
||||
attachments = ""
|
||||
image_attachments = []
|
||||
image_files = []
|
||||
if "files" in messages[-1]:
|
||||
attachments = "\n\n".join(FileService.get_files(messages[-1]["files"]))
|
||||
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||
if llm_type == "chat":
|
||||
text_attachments, image_attachments = split_file_attachments(messages[-1]["files"])
|
||||
else:
|
||||
text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True)
|
||||
attachments = "\n\n".join(text_attachments)
|
||||
|
||||
if llm_type == "image2text":
|
||||
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
else:
|
||||
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
factory = llm_model_config.get("llm_factory", "") if llm_model_config else ""
|
||||
|
||||
if llm_type == "image2text":
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
else:
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
@ -195,8 +209,13 @@ async def async_chat_solo(dialog, messages, stream=True):
|
||||
msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
|
||||
if attachments and msg:
|
||||
msg[-1]["content"] += attachments
|
||||
if llm_type == "chat" and image_attachments:
|
||||
convert_last_user_msg_to_multimodal(msg, image_attachments, factory)
|
||||
if stream:
|
||||
stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting)
|
||||
if llm_type == "chat":
|
||||
stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting)
|
||||
else:
|
||||
stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting, images=image_files)
|
||||
async for kind, value, state in _stream_with_think_delta(stream_iter):
|
||||
if kind == "marker":
|
||||
flags = {"start_to_think": True} if value == "<think>" else {"end_to_think": True}
|
||||
@ -204,7 +223,10 @@ async def async_chat_solo(dialog, messages, stream=True):
|
||||
continue
|
||||
yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "prompt": "", "created_at": time.time(), "final": False}
|
||||
else:
|
||||
answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
|
||||
if llm_type == "chat":
|
||||
answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
|
||||
else:
|
||||
answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting, images=image_files)
|
||||
user_content = msg[-1].get("content", "[content not available]")
|
||||
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
|
||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
|
||||
@ -235,6 +257,120 @@ def get_models(dialog):
|
||||
return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl
|
||||
|
||||
|
||||
def split_file_attachments(files: list[dict] | None, raw: bool = False) -> tuple[list[str], list[str] | list[dict]]:
|
||||
if not files:
|
||||
return [], []
|
||||
|
||||
text_attachments = []
|
||||
if raw:
|
||||
file_contents, image_files = FileService.get_files(files, raw=True)
|
||||
for content in file_contents:
|
||||
if not isinstance(content, str):
|
||||
content = str(content)
|
||||
text_attachments.append(content)
|
||||
return text_attachments, image_files
|
||||
|
||||
image_attachments = []
|
||||
for content in FileService.get_files(files, raw=False):
|
||||
if not isinstance(content, str):
|
||||
content = str(content)
|
||||
if content.strip().startswith("data:"):
|
||||
image_attachments.append(content.strip())
|
||||
continue
|
||||
text_attachments.append(content)
|
||||
return text_attachments, image_attachments
|
||||
|
||||
|
||||
# Matches ``data:<mime>[;param]...;base64,<payload>``.
# Optional media-type parameters (e.g. ``;charset=utf-8``) between the mime
# type and the ``;base64`` marker are accepted and ignored; without this a
# parameterised URI would fall through and be treated as a bare payload.
_DATA_URI_RE = re.compile(r"^data:(?P<mime>[^;]+)(?:;[^;,]*)*;base64,(?P<b64>[A-Za-z0-9+/=\s]+)$")


def _parse_data_uri_or_b64(s: str, default_mime: str = "image/png") -> tuple[str, str]:
    """Split a base64 data URI into ``(mime_type, base64_payload)``.

    Args:
        s: Either a ``data:<mime>[;param...];base64,<payload>`` URI or a bare
            base64 string. ``None`` is treated as an empty string.
        default_mime: Mime type reported when *s* is not a recognised data URI.

    Returns:
        ``(mime, payload)`` with surrounding whitespace removed from both.
        When *s* does not match the data-URI pattern, the stripped input is
        returned unchanged as the payload together with *default_mime*.
    """
    s = (s or "").strip()
    match = _DATA_URI_RE.match(s)
    if match is None:
        return default_mime, s
    return match.group("mime").strip(), match.group("b64").strip()
|
||||
|
||||
|
||||
def _normalize_text_from_content(content) -> str:
    """Flatten a message ``content`` value into plain text.

    * ``None`` gives ``""``; a string is returned unchanged.
    * A list is scanned for dict blocks. A block typed ``text`` or
      ``input_text`` contributes its ``"text"`` when that is truthy; a block
      of any other type contributes its ``"text"`` when it is a str, int or
      float. Contributions are joined with newlines and the result stripped.
    * Any other value is converted with ``str``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)

    pieces = []
    for block in content:
        if not isinstance(block, dict):
            continue
        value = block.get("text")
        if block.get("type") in {"text", "input_text"}:
            if value:
                pieces.append(str(value))
        elif "text" in block and isinstance(value, (str, int, float)):
            pieces.append(str(value))
    return "\n".join(pieces).strip()
|
||||
|
||||
|
||||
def convert_last_user_msg_to_multimodal(msg: list[dict], image_data_uris: list[str], factory: str) -> None:
    """Rewrite the most recent user message in *msg* so that it carries images.

    The message list is modified in place; nothing is returned. Nothing
    happens when *msg* or *image_data_uris* is empty, or when no message has
    ``role == "user"``.

    The layout of the new ``content`` depends on *factory* (compared
    case-insensitively, ignoring surrounding whitespace):

    * ``gemini``: a list of ``{"text": ...}`` / ``{"inline_data": ...}`` parts.
    * ``anthropic``: a list of ``text`` blocks and base64 ``image`` blocks.
    * anything else: ``text`` / ``image_url`` blocks. An existing list content
      is deep-copied and extended; bare base64 strings are prefixed with
      ``data:image/png;base64,``.
    """
    if not msg or not image_data_uris:
        return

    provider = (factory or "").strip().lower()

    # Walk backwards so that the latest user turn is the one converted.
    target = next((entry for entry in reversed(msg) if entry.get("role") == "user"), None)
    if target is None:
        return

    previous = target.get("content", "")
    plain_text = _normalize_text_from_content(previous)

    if provider == "gemini":
        payload = [{"text": plain_text}] if plain_text else []
        for image in image_data_uris:
            mime, b64 = _parse_data_uri_or_b64(str(image), default_mime="image/png")
            payload.append({"inline_data": {"mime_type": mime, "data": b64}})
    elif provider == "anthropic":
        payload = [{"type": "text", "text": plain_text}] if plain_text else []
        for image in image_data_uris:
            mime, b64 = _parse_data_uri_or_b64(str(image), default_mime="image/png")
            payload.append({"type": "image", "source": {"type": "base64", "media_type": mime, "data": b64}})
    else:
        if isinstance(previous, list):
            # Copy so blocks shared with the caller's message are not aliased.
            payload = deepcopy(previous)
        else:
            as_text = "" if previous is None else str(previous)
            payload = [{"type": "text", "text": as_text}] if as_text else []
        for uri in image_data_uris:
            url = uri if isinstance(uri, str) else str(uri)
            if not url.startswith("data:"):
                url = f"data:image/png;base64,{url}"
            payload.append({"type": "image_url", "image_url": {"url": url}})

    target["content"] = payload
|
||||
|
||||
|
||||
BAD_CITATION_PATTERNS = [
|
||||
re.compile(r"\(\s*ID\s*[: ]*\s*(\d+)\s*\)"), # (ID: 12)
|
||||
re.compile(r"\[\s*ID\s*[: ]*\s*(\d+)\s*\]"), # [ID: 12]
|
||||
@ -281,12 +417,13 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
return
|
||||
|
||||
chat_start_ts = timer()
|
||||
|
||||
if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||
llm_type = TenantLLMService.llm_id2llm_type(dialog.llm_id)
|
||||
if llm_type == "image2text":
|
||||
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
else:
|
||||
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
|
||||
factory = llm_model_config.get("llm_factory", "") if llm_model_config else ""
|
||||
max_tokens = llm_model_config.get("max_tokens", 8192)
|
||||
|
||||
check_llm_ts = timer()
|
||||
@ -316,10 +453,16 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
|
||||
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
|
||||
attachments_= ""
|
||||
image_attachments = []
|
||||
image_files = []
|
||||
if "doc_ids" in messages[-1]:
|
||||
attachments = messages[-1]["doc_ids"]
|
||||
if "files" in messages[-1]:
|
||||
attachments_ = "\n\n".join(FileService.get_files(messages[-1]["files"]))
|
||||
if llm_type == "chat":
|
||||
text_attachments, image_attachments = split_file_attachments(messages[-1]["files"])
|
||||
else:
|
||||
text_attachments, image_files = split_file_attachments(messages[-1]["files"], raw=True)
|
||||
attachments_ = "\n\n".join(text_attachments)
|
||||
|
||||
prompt_config = dialog.prompt_config
|
||||
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
||||
@ -464,6 +607,8 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
prompt4citation = citation_prompt()
|
||||
msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"])
|
||||
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95))
|
||||
if llm_type == "chat" and image_attachments:
|
||||
convert_last_user_msg_to_multimodal(msg, image_attachments, factory)
|
||||
assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
|
||||
prompt = msg[0]["content"]
|
||||
|
||||
@ -555,7 +700,10 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
)
|
||||
|
||||
if stream:
|
||||
stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf)
|
||||
if llm_type == "chat":
|
||||
stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf)
|
||||
else:
|
||||
stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf, images=image_files)
|
||||
last_state = None
|
||||
async for kind, value, state in _stream_with_think_delta(stream_iter):
|
||||
last_state = state
|
||||
@ -572,7 +720,10 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
final["answer"] = ""
|
||||
yield final
|
||||
else:
|
||||
answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
|
||||
if llm_type == "chat":
|
||||
answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
|
||||
else:
|
||||
answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf, images=image_files)
|
||||
user_content = msg[-1].get("content", "[content not available]")
|
||||
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
|
||||
res = decorate_answer(answer)
|
||||
|
||||
@ -663,7 +663,7 @@ class FileService(CommonService):
|
||||
return structured(file.filename, filename_type(file.filename), file.read(), file.content_type)
|
||||
|
||||
@staticmethod
|
||||
def get_files(files: Union[None, list[dict]]) -> list[str]:
|
||||
def get_files(files: Union[None, list[dict]], raw: bool = False) -> Union[list[str], tuple[list[str], list[dict]]]:
|
||||
if not files:
|
||||
return []
|
||||
def image_to_base64(file):
|
||||
@ -671,10 +671,17 @@ class FileService(CommonService):
|
||||
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
|
||||
exe = ThreadPoolExecutor(max_workers=5)
|
||||
threads = []
|
||||
imgs = []
|
||||
for file in files:
|
||||
if file["mime_type"].find("image") >=0:
|
||||
threads.append(exe.submit(image_to_base64, file))
|
||||
if raw:
|
||||
imgs.append(FileService.get_blob(file["created_by"], file["id"]))
|
||||
else:
|
||||
threads.append(exe.submit(image_to_base64, file))
|
||||
continue
|
||||
threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
|
||||
return [th.result() for th in threads]
|
||||
|
||||
|
||||
if raw:
|
||||
return [th.result() for th in threads], imgs
|
||||
else:
|
||||
return [th.result() for th in threads]
|
||||
|
||||
@ -51,8 +51,9 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
|
||||
}
|
||||
)
|
||||
cv_mdl = LLMBundle(tenant_id, llm_type=LLMType.IMAGE2TEXT, lang=lang)
|
||||
video_prompt = str(parser_config.get("video_prompt", "") or "")
|
||||
ans = asyncio.run(
|
||||
cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename))
|
||||
cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename, video_prompt=video_prompt))
|
||||
callback(0.8, "CV LLM respond: %s ..." % ans[:32])
|
||||
ans += "\n" + ans
|
||||
tokenize(doc, ans, eng)
|
||||
|
||||
@ -161,6 +161,7 @@ class ParserParam(ProcessParamBase):
|
||||
"mkv",
|
||||
],
|
||||
"output_format": "text",
|
||||
"prompt": "",
|
||||
},
|
||||
}
|
||||
|
||||
@ -685,7 +686,8 @@ class Parser(ProcessBase):
|
||||
self.set_output("output_format", conf["output_format"])
|
||||
|
||||
cv_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.IMAGE2TEXT, llm_name=conf["llm_id"])
|
||||
txt = asyncio.run(cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=blob, filename=name))
|
||||
video_prompt = str(conf.get("prompt", "") or "")
|
||||
txt = asyncio.run(cv_mdl.async_chat(system="", history=[], gen_conf={}, video_bytes=blob, filename=name, video_prompt=video_prompt))
|
||||
|
||||
self.set_output("text", txt)
|
||||
|
||||
|
||||
@ -67,6 +67,61 @@ class Base(ABC):
|
||||
hist.append(h)
|
||||
return hist
|
||||
|
||||
@staticmethod
def _blob_to_data_url(blob, mime_type="image/png"):
    """Turn an image payload into a URL string usable in a request.

    * A string is stripped. If it already starts with ``data:``, ``http://``,
      ``https://`` or ``file://`` it is returned as is; otherwise it is
      treated as base64 text and wrapped in a ``data:`` URL.
    * ``BytesIO``, ``memoryview``, ``bytearray`` and ``bytes`` are
      base64-encoded and wrapped in a ``data:`` URL.
    * Any other type yields ``None``.

    *mime_type* is used only when a new ``data:`` URL is built.
    """
    if isinstance(blob, str):
        text = blob.strip()
        if text.startswith(("data:", "http://", "https://", "file://")):
            return text
        return f"data:{mime_type};base64,{text}"

    # Reduce every supported binary container to plain bytes.
    if isinstance(blob, BytesIO):
        blob = blob.getvalue()
    if isinstance(blob, (memoryview, bytearray)):
        blob = bytes(blob)
    if not isinstance(blob, bytes):
        return None

    encoded = base64.b64encode(blob).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}"
|
||||
|
||||
def _normalize_image(self, image):
    """Convert one image input into a URL string for the request payload.

    A plain string goes through ``_blob_to_data_url`` with the ``image/png``
    default. For a dict, the following locations are tried in order and the
    first one that ``_blob_to_data_url`` can convert is returned:

    1. ``image["inline_data"]["data"]`` (mime from ``inline_data["mime_type"]``)
    2. ``image["image_url"]["url"]`` or a string ``image["image_url"]``
    3. ``image["url"]``
    4. ``image["blob"]``, then ``image["data"]``
       (mime from ``mime_type`` or ``media_type``)

    Everything else, including a dict in which none of the above worked, is
    handed to ``self.image2base64``.
    """
    if isinstance(image, str):
        return self._blob_to_data_url(image, "image/png")
    if not isinstance(image, dict):
        return self.image2base64(image)

    # (payload, mime) pairs in priority order.
    candidates = []

    inline = image.get("inline_data")
    if isinstance(inline, dict):
        candidates.append((inline.get("data"), inline.get("mime_type") or "image/png"))

    url_field = image.get("image_url")
    if isinstance(url_field, dict):
        candidates.append((url_field.get("url"), image.get("mime_type") or "image/png"))
    elif isinstance(url_field, str):
        candidates.append((url_field, image.get("mime_type") or "image/png"))

    if "url" in image:
        candidates.append((image.get("url"), image.get("mime_type") or "image/png"))

    fallback_mime = image.get("mime_type") or image.get("media_type") or "image/png"
    candidates.extend((image.get(key), fallback_mime) for key in ("blob", "data") if key in image)

    for payload, mime in candidates:
        converted = self._blob_to_data_url(payload, mime)
        if converted:
            return converted

    return self.image2base64(image)
|
||||
|
||||
def _image_prompt(self, text, images):
|
||||
if not images:
|
||||
return text
|
||||
@ -76,7 +131,11 @@ class Base(ABC):
|
||||
|
||||
pmpt = [{"type": "text", "text": text}]
|
||||
for img in images:
|
||||
pmpt.append({"type": "image_url", "image_url": {"url": img if isinstance(img, str) and img.startswith("data:") else f"data:image/png;base64,{img}"}})
|
||||
try:
|
||||
pmpt.append({"type": "image_url", "image_url": {"url": self._normalize_image(img)}})
|
||||
except Exception:
|
||||
logging.warning("[%s] Skip invalid image input in request payload.", self.__class__.__name__)
|
||||
continue
|
||||
return pmpt
|
||||
|
||||
async def async_chat(self, system, history, gen_conf, images=None, **kwargs):
|
||||
@ -248,51 +307,86 @@ class QWenCV(GptV4):
|
||||
base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs)
|
||||
|
||||
@staticmethod
def _extract_text_from_content(content):
    """Pull the plain text out of a message ``content`` value.

    A string is returned stripped. For a list, each dict block contributes
    ``str(block["text"])`` when either the block is typed ``text`` /
    ``input_text`` with a truthy text, or it has a ``"text"`` entry that is a
    str, int or float. Contributions are joined with newlines and stripped.
    Any other input yields ``""``.
    """
    if isinstance(content, str):
        return content.strip()
    if not isinstance(content, list):
        return ""

    def _carries_text(block):
        value = block.get("text")
        if block.get("type") in {"text", "input_text"} and value:
            return True
        return "text" in block and isinstance(value, (str, int, float))

    return "\n".join(
        str(block["text"])
        for block in content
        if isinstance(block, dict) and _carries_text(block)
    ).strip()
|
||||
|
||||
def _resolve_video_prompt(self, system, history, **kwargs):
    """Choose the text prompt to send along with a video.

    The first non-blank source wins, in this order:

    1. the ``video_prompt`` keyword, else the ``prompt`` keyword;
    2. the text of the most recent ``user`` entry in *history* that has any;
    3. the *system* prompt;
    4. the fixed default summary instruction.

    The selected text is returned stripped.
    """
    explicit = kwargs.get("video_prompt") or kwargs.get("prompt")
    if isinstance(explicit, str):
        explicit = explicit.strip()
        if explicit:
            return explicit

    # Lazy: extraction stops at the newest user turn that yields text.
    user_texts = (
        self._extract_text_from_content(turn.get("content"))
        for turn in reversed(history or [])
        if turn.get("role") == "user"
    )
    latest = next((text for text in user_texts if text), None)
    if latest:
        return latest

    if isinstance(system, str):
        trimmed = system.strip()
        if trimmed:
            return trimmed

    return "Please summarize this video in proper sentences."
|
||||
|
||||
async def async_chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
|
||||
if video_bytes:
|
||||
try:
|
||||
summary, summary_num_tokens = self._process_video(video_bytes, filename)
|
||||
summary, summary_num_tokens = self._process_video(video_bytes, filename, self._resolve_video_prompt(system, history, **kwargs))
|
||||
return summary, summary_num_tokens
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
|
||||
return "**ERROR**: Method chat not supported yet.", 0
|
||||
return await super().async_chat(system, history, gen_conf, images=images, **kwargs)
|
||||
|
||||
def _process_video(self, video_bytes, filename):
|
||||
def _process_video(self, video_bytes, filename, prompt):
|
||||
from dashscope import MultiModalConversation
|
||||
|
||||
video_suffix = Path(filename).suffix or ".mp4"
|
||||
tmp_path = None
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=video_suffix) as tmp:
|
||||
tmp.write(video_bytes)
|
||||
tmp_path = tmp.name
|
||||
|
||||
video_path = f"file://{tmp_path}"
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"video": video_path,
|
||||
"fps": 2,
|
||||
},
|
||||
{
|
||||
"text": "Please summarize this video in proper sentences.",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
video_path = f"file://{tmp_path}"
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"video": video_path,
|
||||
"fps": 2,
|
||||
},
|
||||
{
|
||||
"text": prompt,
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
def call_api():
|
||||
response = MultiModalConversation.call(
|
||||
api_key=self.api_key,
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
)
|
||||
if response.get("message"):
|
||||
raise Exception(response["message"])
|
||||
summary = response["output"]["choices"][0]["message"].content[0]["text"]
|
||||
return summary, num_tokens_from_string(summary)
|
||||
def call_api():
|
||||
response = MultiModalConversation.call(
|
||||
api_key=self.api_key,
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
)
|
||||
if response.get("message"):
|
||||
raise Exception(response["message"])
|
||||
summary = response["output"]["choices"][0]["message"].content[0]["text"]
|
||||
return summary, num_tokens_from_string(summary)
|
||||
|
||||
try:
|
||||
try:
|
||||
return call_api()
|
||||
except Exception as e1:
|
||||
@ -303,6 +397,12 @@ class QWenCV(GptV4):
|
||||
return call_api()
|
||||
except Exception as e2:
|
||||
raise RuntimeError(f"Both default and intl endpoint failed.\nFirst error: {e1}\nSecond error: {e2}")
|
||||
finally:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
try:
|
||||
os.remove(tmp_path)
|
||||
except Exception:
|
||||
logging.warning("[QWenCV] Failed to cleanup temp video file: %s", tmp_path)
|
||||
|
||||
|
||||
class HunyuanCV(GptV4):
|
||||
|
||||
Reference in New Issue
Block a user