mirror of
https://github.com/langgenius/dify.git
synced 2026-04-20 18:57:19 +08:00
Address API review follow-ups
This commit is contained in:
@ -225,7 +225,7 @@ class Vector:
|
||||
start = time.time()
|
||||
logger.info("start embedding %s texts %s", len(texts), start)
|
||||
batch_size = 1000
|
||||
total_batches = len(texts) + batch_size - 1
|
||||
total_batches = (len(texts) + batch_size - 1) // batch_size
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
batch_start = time.time()
|
||||
@ -244,7 +244,7 @@ class Vector:
|
||||
start = time.time()
|
||||
logger.info("start embedding %s files %s", len(file_documents), start)
|
||||
batch_size = 1000
|
||||
total_batches = len(file_documents) + batch_size - 1
|
||||
total_batches = (len(file_documents) + batch_size - 1) // batch_size
|
||||
for i in range(0, len(file_documents), batch_size):
|
||||
batch = file_documents[i : i + batch_size]
|
||||
batch_start = time.time()
|
||||
|
||||
@ -2,7 +2,7 @@ import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import TypeAlias, cast
|
||||
from typing import TypeAlias
|
||||
from urllib.parse import unquote
|
||||
|
||||
from configs import dify_config
|
||||
@ -120,8 +120,11 @@ class ExtractProcessor:
|
||||
if file_extension in {".xlsx", ".xls"}:
|
||||
extractor = ExcelExtractor(file_path)
|
||||
elif file_extension == ".pdf":
|
||||
pdf_upload_file = cast(UploadFile, upload_file)
|
||||
extractor = PdfExtractor(file_path, pdf_upload_file.tenant_id, pdf_upload_file.created_by)
|
||||
extractor = PdfExtractor(
|
||||
file_path,
|
||||
upload_file.tenant_id if upload_file else None,
|
||||
upload_file.created_by if upload_file else None,
|
||||
)
|
||||
elif file_extension in {".md", ".markdown", ".mdx"}:
|
||||
extractor = (
|
||||
UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
@ -131,8 +134,11 @@ class ExtractProcessor:
|
||||
elif file_extension in {".htm", ".html"}:
|
||||
extractor = HtmlExtractor(file_path)
|
||||
elif file_extension == ".docx":
|
||||
docx_upload_file = cast(UploadFile, upload_file)
|
||||
extractor = WordExtractor(file_path, docx_upload_file.tenant_id, docx_upload_file.created_by)
|
||||
extractor = WordExtractor(
|
||||
file_path,
|
||||
upload_file.tenant_id if upload_file else None,
|
||||
upload_file.created_by if upload_file else None,
|
||||
)
|
||||
elif file_extension == ".doc":
|
||||
extractor = UnstructuredWordExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
elif file_extension == ".csv":
|
||||
@ -158,15 +164,21 @@ class ExtractProcessor:
|
||||
if file_extension in {".xlsx", ".xls"}:
|
||||
extractor = ExcelExtractor(file_path)
|
||||
elif file_extension == ".pdf":
|
||||
pdf_upload_file = cast(UploadFile, upload_file)
|
||||
extractor = PdfExtractor(file_path, pdf_upload_file.tenant_id, pdf_upload_file.created_by)
|
||||
extractor = PdfExtractor(
|
||||
file_path,
|
||||
upload_file.tenant_id if upload_file else None,
|
||||
upload_file.created_by if upload_file else None,
|
||||
)
|
||||
elif file_extension in {".md", ".markdown", ".mdx"}:
|
||||
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension in {".htm", ".html"}:
|
||||
extractor = HtmlExtractor(file_path)
|
||||
elif file_extension == ".docx":
|
||||
docx_upload_file = cast(UploadFile, upload_file)
|
||||
extractor = WordExtractor(file_path, docx_upload_file.tenant_id, docx_upload_file.created_by)
|
||||
extractor = WordExtractor(
|
||||
file_path,
|
||||
upload_file.tenant_id if upload_file else None,
|
||||
upload_file.created_by if upload_file else None,
|
||||
)
|
||||
elif file_extension == ".csv":
|
||||
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension == ".epub":
|
||||
|
||||
@ -47,7 +47,7 @@ class PdfExtractor(BaseExtractor):
|
||||
]
|
||||
MAX_MAGIC_LEN = max(len(m) for m, _, _ in IMAGE_FORMATS)
|
||||
|
||||
def __init__(self, file_path: str, tenant_id: str, user_id: str, file_cache_key: str | None = None):
|
||||
def __init__(self, file_path: str, tenant_id: str | None, user_id: str | None, file_cache_key: str | None = None):
|
||||
"""Initialize PdfExtractor."""
|
||||
self._file_path = file_path
|
||||
self._tenant_id = tenant_id
|
||||
@ -114,6 +114,9 @@ class PdfExtractor(BaseExtractor):
|
||||
"""
|
||||
image_content = []
|
||||
upload_files = []
|
||||
if not self._tenant_id or not self._user_id:
|
||||
return ""
|
||||
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
|
||||
try:
|
||||
|
||||
@ -35,7 +35,7 @@ class WordExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str, tenant_id: str, user_id: str):
|
||||
def __init__(self, file_path: str, tenant_id: str | None, user_id: str | None):
|
||||
"""Initialize with file path."""
|
||||
self.file_path = file_path
|
||||
self.tenant_id = tenant_id
|
||||
@ -85,6 +85,9 @@ class WordExtractor(BaseExtractor):
|
||||
return bool(parsed.netloc) and bool(parsed.scheme)
|
||||
|
||||
def _extract_images_from_docx(self, doc):
|
||||
if not self.tenant_id or not self.user_id:
|
||||
return {}
|
||||
|
||||
image_count = 0
|
||||
image_map = {}
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
|
||||
@ -14,7 +14,7 @@ class ApiKeyAuthFactory:
|
||||
|
||||
@staticmethod
|
||||
def get_apikey_auth_factory(provider: AuthProvider) -> type[ApiKeyAuthBase]:
|
||||
match provider:
|
||||
match ApiKeyAuthFactory._normalize_provider(provider):
|
||||
case AuthType.FIRECRAWL:
|
||||
from services.auth.firecrawl.firecrawl import FirecrawlAuth
|
||||
|
||||
@ -29,3 +29,13 @@ class ApiKeyAuthFactory:
|
||||
return JinaAuth
|
||||
case _:
|
||||
raise ValueError("Invalid provider")
|
||||
|
||||
@staticmethod
|
||||
def _normalize_provider(provider: AuthProvider) -> AuthType | str:
|
||||
if isinstance(provider, AuthType):
|
||||
return provider
|
||||
|
||||
try:
|
||||
return AuthType(provider)
|
||||
except ValueError:
|
||||
return provider
|
||||
|
||||
@ -2,12 +2,12 @@ import json
|
||||
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
|
||||
|
||||
|
||||
class FirecrawlAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
super().__init__(credentials) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
|
||||
def __init__(self, credentials: ApiKeyAuthCredentials):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "bearer":
|
||||
raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer")
|
||||
|
||||
@ -2,12 +2,12 @@ import json
|
||||
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
|
||||
|
||||
|
||||
class JinaAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
super().__init__(credentials) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
|
||||
def __init__(self, credentials: ApiKeyAuthCredentials):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "bearer":
|
||||
raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")
|
||||
|
||||
@ -2,12 +2,12 @@ import json
|
||||
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
|
||||
|
||||
|
||||
class JinaAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
super().__init__(credentials) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
|
||||
def __init__(self, credentials: ApiKeyAuthCredentials):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "bearer":
|
||||
raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")
|
||||
|
||||
@ -3,12 +3,12 @@ from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
|
||||
|
||||
|
||||
class WatercrawlAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
super().__init__(credentials) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
|
||||
def __init__(self, credentials: ApiKeyAuthCredentials):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "x-api-key":
|
||||
raise ValueError("Invalid auth type, WaterCrawl auth type must be x-api-key")
|
||||
|
||||
@ -200,6 +200,26 @@ class TestExtractProcessorFileRouting:
|
||||
with pytest.raises(AssertionError, match="upload_file is required"):
|
||||
ExtractProcessor.extract(setting)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("extension", "expected_extractor"),
|
||||
[
|
||||
(".pdf", "PdfExtractor"),
|
||||
(".docx", "WordExtractor"),
|
||||
],
|
||||
)
|
||||
def test_extract_routes_url_based_files_without_upload_context(self, monkeypatch, extension, expected_extractor):
|
||||
factory = _patch_all_extractors(monkeypatch)
|
||||
monkeypatch.setattr(processor_module.dify_config, "ETL_TYPE", "SelfHosted")
|
||||
|
||||
setting = SimpleNamespace(datasource_type=DatasourceType.FILE, upload_file=None)
|
||||
|
||||
docs = ExtractProcessor.extract(setting, file_path=f"/tmp/from-url{extension}")
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == f"extracted-by-{expected_extractor}"
|
||||
assert factory.calls[-1][0] == expected_extractor
|
||||
assert factory.calls[-1][1] == (f"/tmp/from-url{extension}", None, None)
|
||||
|
||||
|
||||
class TestExtractProcessorDatasourceRouting:
|
||||
def test_extract_routes_notion_datasource(self, monkeypatch):
|
||||
|
||||
@ -122,6 +122,23 @@ def test_extract_images_get_objects_scenarios(mock_dependencies, get_objects_sid
|
||||
assert result == ""
|
||||
|
||||
|
||||
def test_extract_images_skips_persistence_without_upload_context(mock_dependencies):
|
||||
mock_page = MagicMock()
|
||||
mock_image_obj = MagicMock()
|
||||
mock_page.get_objects.return_value = [mock_image_obj]
|
||||
|
||||
extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id=None, user_id=None)
|
||||
|
||||
with patch("pypdfium2.raw", autospec=True) as mock_raw:
|
||||
mock_raw.FPDF_PAGEOBJ_IMAGE = 1
|
||||
result = extractor._extract_images(mock_page)
|
||||
|
||||
assert result == ""
|
||||
assert mock_dependencies.saves == []
|
||||
assert mock_dependencies.db.session.added == []
|
||||
assert mock_dependencies.db.session.committed is False
|
||||
|
||||
|
||||
def test_extract_calls_extract_images(mock_dependencies, monkeypatch):
|
||||
# Mock pypdfium2
|
||||
mock_pdf_doc = MagicMock()
|
||||
|
||||
@ -210,6 +210,35 @@ def test_extract_images_from_docx_uses_internal_files_url():
|
||||
dify_config.INTERNAL_FILES_URL = original_internal_files_url
|
||||
|
||||
|
||||
def test_extract_images_from_docx_skips_persistence_without_upload_context(monkeypatch):
|
||||
monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda key, data: None))
|
||||
|
||||
class DummySession:
|
||||
def __init__(self):
|
||||
self.added = []
|
||||
self.committed = False
|
||||
|
||||
def add(self, obj):
|
||||
self.added.append(obj)
|
||||
|
||||
def commit(self):
|
||||
self.committed = True
|
||||
|
||||
db_stub = SimpleNamespace(session=DummySession())
|
||||
monkeypatch.setattr(we, "db", db_stub)
|
||||
|
||||
rel_ext = SimpleNamespace(is_external=True, target_ref="https://example.com/image.png")
|
||||
doc = SimpleNamespace(part=SimpleNamespace(rels={"rId1": rel_ext}))
|
||||
|
||||
extractor = object.__new__(WordExtractor)
|
||||
extractor.tenant_id = None
|
||||
extractor.user_id = None
|
||||
|
||||
assert extractor._extract_images_from_docx(doc) == {}
|
||||
assert db_stub.session.added == []
|
||||
assert db_stub.session.committed is False
|
||||
|
||||
|
||||
def test_extract_hyperlinks(monkeypatch):
|
||||
# Mock db and storage to avoid issues during image extraction (even if no images are present)
|
||||
monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda k, d: None))
|
||||
|
||||
@ -13,8 +13,11 @@ class TestApiKeyAuthFactory:
|
||||
("provider", "auth_class_path"),
|
||||
[
|
||||
(AuthType.FIRECRAWL, "services.auth.firecrawl.firecrawl.FirecrawlAuth"),
|
||||
(AuthType.FIRECRAWL.value, "services.auth.firecrawl.firecrawl.FirecrawlAuth"),
|
||||
(AuthType.WATERCRAWL, "services.auth.watercrawl.watercrawl.WatercrawlAuth"),
|
||||
(AuthType.WATERCRAWL.value, "services.auth.watercrawl.watercrawl.WatercrawlAuth"),
|
||||
(AuthType.JINA, "services.auth.jina.jina.JinaAuth"),
|
||||
(AuthType.JINA.value, "services.auth.jina.jina.JinaAuth"),
|
||||
],
|
||||
)
|
||||
def test_get_apikey_auth_factory_valid_providers(self, provider, auth_class_path):
|
||||
|
||||
Reference in New Issue
Block a user