Address API review follow-ups

This commit is contained in:
Yanli 盐粒
2026-03-18 18:31:09 +08:00
parent db7d5e30cb
commit a0017183b6
13 changed files with 123 additions and 26 deletions

View File

@ -225,7 +225,7 @@ class Vector:
start = time.time()
logger.info("start embedding %s texts %s", len(texts), start)
batch_size = 1000
total_batches = len(texts) + batch_size - 1
total_batches = (len(texts) + batch_size - 1) // batch_size
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
batch_start = time.time()
@ -244,7 +244,7 @@ class Vector:
start = time.time()
logger.info("start embedding %s files %s", len(file_documents), start)
batch_size = 1000
total_batches = len(file_documents) + batch_size - 1
total_batches = (len(file_documents) + batch_size - 1) // batch_size
for i in range(0, len(file_documents), batch_size):
batch = file_documents[i : i + batch_size]
batch_start = time.time()

View File

@ -2,7 +2,7 @@ import os
import re
import tempfile
from pathlib import Path
from typing import TypeAlias, cast
from typing import TypeAlias
from urllib.parse import unquote
from configs import dify_config
@ -120,8 +120,11 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
pdf_upload_file = cast(UploadFile, upload_file)
extractor = PdfExtractor(file_path, pdf_upload_file.tenant_id, pdf_upload_file.created_by)
extractor = PdfExtractor(
file_path,
upload_file.tenant_id if upload_file else None,
upload_file.created_by if upload_file else None,
)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = (
UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key)
@ -131,8 +134,11 @@ class ExtractProcessor:
elif file_extension in {".htm", ".html"}:
extractor = HtmlExtractor(file_path)
elif file_extension == ".docx":
docx_upload_file = cast(UploadFile, upload_file)
extractor = WordExtractor(file_path, docx_upload_file.tenant_id, docx_upload_file.created_by)
extractor = WordExtractor(
file_path,
upload_file.tenant_id if upload_file else None,
upload_file.created_by if upload_file else None,
)
elif file_extension == ".doc":
extractor = UnstructuredWordExtractor(file_path, unstructured_api_url, unstructured_api_key)
elif file_extension == ".csv":
@ -158,15 +164,21 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
pdf_upload_file = cast(UploadFile, upload_file)
extractor = PdfExtractor(file_path, pdf_upload_file.tenant_id, pdf_upload_file.created_by)
extractor = PdfExtractor(
file_path,
upload_file.tenant_id if upload_file else None,
upload_file.created_by if upload_file else None,
)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in {".htm", ".html"}:
extractor = HtmlExtractor(file_path)
elif file_extension == ".docx":
docx_upload_file = cast(UploadFile, upload_file)
extractor = WordExtractor(file_path, docx_upload_file.tenant_id, docx_upload_file.created_by)
extractor = WordExtractor(
file_path,
upload_file.tenant_id if upload_file else None,
upload_file.created_by if upload_file else None,
)
elif file_extension == ".csv":
extractor = CSVExtractor(file_path, autodetect_encoding=True)
elif file_extension == ".epub":

View File

@ -47,7 +47,7 @@ class PdfExtractor(BaseExtractor):
]
MAX_MAGIC_LEN = max(len(m) for m, _, _ in IMAGE_FORMATS)
def __init__(self, file_path: str, tenant_id: str, user_id: str, file_cache_key: str | None = None):
def __init__(self, file_path: str, tenant_id: str | None, user_id: str | None, file_cache_key: str | None = None):
"""Initialize PdfExtractor."""
self._file_path = file_path
self._tenant_id = tenant_id
@ -114,6 +114,9 @@ class PdfExtractor(BaseExtractor):
"""
image_content = []
upload_files = []
if not self._tenant_id or not self._user_id:
return ""
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
try:

View File

@ -35,7 +35,7 @@ class WordExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
def __init__(self, file_path: str, tenant_id: str, user_id: str):
def __init__(self, file_path: str, tenant_id: str | None, user_id: str | None):
"""Initialize with file path."""
self.file_path = file_path
self.tenant_id = tenant_id
@ -85,6 +85,9 @@ class WordExtractor(BaseExtractor):
return bool(parsed.netloc) and bool(parsed.scheme)
def _extract_images_from_docx(self, doc):
if not self.tenant_id or not self.user_id:
return {}
image_count = 0
image_map = {}
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL

View File

@ -14,7 +14,7 @@ class ApiKeyAuthFactory:
@staticmethod
def get_apikey_auth_factory(provider: AuthProvider) -> type[ApiKeyAuthBase]:
match provider:
match ApiKeyAuthFactory._normalize_provider(provider):
case AuthType.FIRECRAWL:
from services.auth.firecrawl.firecrawl import FirecrawlAuth
@ -29,3 +29,13 @@ class ApiKeyAuthFactory:
return JinaAuth
case _:
raise ValueError("Invalid provider")
@staticmethod
def _normalize_provider(provider: AuthProvider) -> AuthType | str:
    """Coerce *provider* to an ``AuthType`` member when possible.

    Accepts either an ``AuthType`` member or its raw string value.  A
    string that matches no member is returned unchanged so the caller's
    ``match`` statement falls through to its error branch.
    """
    if not isinstance(provider, AuthType):
        try:
            provider = AuthType(provider)
        except ValueError:
            # Unknown value: hand the raw input back for the caller to reject.
            pass
    return provider

View File

@ -2,12 +2,12 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
class FirecrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
super().__init__(credentials) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
def __init__(self, credentials: ApiKeyAuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":
raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer")

View File

@ -2,12 +2,12 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
super().__init__(credentials) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
def __init__(self, credentials: ApiKeyAuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":
raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")

View File

@ -2,12 +2,12 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
super().__init__(credentials) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
def __init__(self, credentials: ApiKeyAuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":
raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")

View File

@ -3,12 +3,12 @@ from urllib.parse import urljoin
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
class WatercrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
super().__init__(credentials) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
def __init__(self, credentials: ApiKeyAuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "x-api-key":
raise ValueError("Invalid auth type, WaterCrawl auth type must be x-api-key")

View File

@ -200,6 +200,26 @@ class TestExtractProcessorFileRouting:
with pytest.raises(AssertionError, match="upload_file is required"):
ExtractProcessor.extract(setting)
@pytest.mark.parametrize(
    ("extension", "expected_extractor"),
    [
        (".pdf", "PdfExtractor"),
        (".docx", "WordExtractor"),
    ],
)
def test_extract_routes_url_based_files_without_upload_context(self, monkeypatch, extension, expected_extractor):
    # A URL-sourced file has no UploadFile record; routing must still pick the
    # extension-matched extractor and pass (tenant_id, user_id) as None instead
    # of requiring upload_file.
    factory = _patch_all_extractors(monkeypatch)
    monkeypatch.setattr(processor_module.dify_config, "ETL_TYPE", "SelfHosted")
    setting = SimpleNamespace(datasource_type=DatasourceType.FILE, upload_file=None)
    docs = ExtractProcessor.extract(setting, file_path=f"/tmp/from-url{extension}")
    assert len(docs) == 1
    assert docs[0].page_content == f"extracted-by-{expected_extractor}"
    # factory.calls holds (extractor_name, ctor_args) per instantiation, as
    # used below — confirm against _patch_all_extractors if it changes.
    assert factory.calls[-1][0] == expected_extractor
    assert factory.calls[-1][1] == (f"/tmp/from-url{extension}", None, None)
class TestExtractProcessorDatasourceRouting:
def test_extract_routes_notion_datasource(self, monkeypatch):

View File

@ -122,6 +122,23 @@ def test_extract_images_get_objects_scenarios(mock_dependencies, get_objects_sid
assert result == ""
def test_extract_images_skips_persistence_without_upload_context(mock_dependencies):
    # With tenant_id/user_id both None, _extract_images must bail out early:
    # empty result, no storage writes, no DB adds, no commit — even though the
    # page reports an image object.
    mock_page = MagicMock()
    mock_image_obj = MagicMock()
    mock_page.get_objects.return_value = [mock_image_obj]
    extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id=None, user_id=None)
    with patch("pypdfium2.raw", autospec=True) as mock_raw:
        # Make the mocked page object classify as an image object.
        mock_raw.FPDF_PAGEOBJ_IMAGE = 1
        result = extractor._extract_images(mock_page)
    assert result == ""
    # NOTE(review): assumes mock_dependencies exposes recorded storage saves
    # and a DB-session stub — verify against the fixture definition.
    assert mock_dependencies.saves == []
    assert mock_dependencies.db.session.added == []
    assert mock_dependencies.db.session.committed is False
def test_extract_calls_extract_images(mock_dependencies, monkeypatch):
# Mock pypdfium2
mock_pdf_doc = MagicMock()

View File

@ -210,6 +210,35 @@ def test_extract_images_from_docx_uses_internal_files_url():
dify_config.INTERNAL_FILES_URL = original_internal_files_url
def test_extract_images_from_docx_skips_persistence_without_upload_context(monkeypatch):
    """Without tenant/user context, _extract_images_from_docx returns an empty
    map and never touches storage or the database session."""

    class RecordingSession:
        # Minimal db-session stand-in that records add/commit calls.
        def __init__(self):
            self.added = []
            self.committed = False

        def add(self, obj):
            self.added.append(obj)

        def commit(self):
            self.committed = True

    session_stub = RecordingSession()
    monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda key, data: None))
    monkeypatch.setattr(we, "db", SimpleNamespace(session=session_stub))
    external_rel = SimpleNamespace(is_external=True, target_ref="https://example.com/image.png")
    document = SimpleNamespace(part=SimpleNamespace(rels={"rId1": external_rel}))
    # Build the extractor without running __init__ so no file I/O happens.
    extractor = WordExtractor.__new__(WordExtractor)
    extractor.tenant_id = None
    extractor.user_id = None
    assert extractor._extract_images_from_docx(document) == {}
    assert session_stub.added == []
    assert session_stub.committed is False
def test_extract_hyperlinks(monkeypatch):
# Mock db and storage to avoid issues during image extraction (even if no images are present)
monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda k, d: None))

View File

@ -13,8 +13,11 @@ class TestApiKeyAuthFactory:
("provider", "auth_class_path"),
[
(AuthType.FIRECRAWL, "services.auth.firecrawl.firecrawl.FirecrawlAuth"),
(AuthType.FIRECRAWL.value, "services.auth.firecrawl.firecrawl.FirecrawlAuth"),
(AuthType.WATERCRAWL, "services.auth.watercrawl.watercrawl.WatercrawlAuth"),
(AuthType.WATERCRAWL.value, "services.auth.watercrawl.watercrawl.WatercrawlAuth"),
(AuthType.JINA, "services.auth.jina.jina.JinaAuth"),
(AuthType.JINA.value, "services.auth.jina.jina.JinaAuth"),
],
)
def test_get_apikey_auth_factory_valid_providers(self, provider, auth_class_path):