Merge branch 'main' into feat/end-user-oauth

This commit is contained in:
zhsama
2025-12-08 16:49:57 +08:00
204 changed files with 7347 additions and 5943 deletions

View File

@@ -1,4 +1,5 @@
import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import Any
@@ -23,6 +24,8 @@ from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message
logger = logging.getLogger(__name__)
class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True
@@ -400,8 +403,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
action_input=json.loads(message.tool_calls[0].function.arguments),
)
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
except:
pass
except Exception:
logger.exception("Failed to parse tool call from assistant message")
elif isinstance(message, ToolPromptMessage):
if current_scratchpad:
assert isinstance(message.content, str)

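The hunk above replaces a bare except/pass with except Exception plus logger.exception, so scratchpad parse failures are recorded instead of silently discarded. A minimal sketch of the pattern (standalone; parse_action and the argument shape are illustrative, not Dify's API):

    import json
    import logging

    logger = logging.getLogger(__name__)

    def parse_action(arguments: str) -> dict | None:
        """Parse tool-call arguments; log failures instead of swallowing them."""
        try:
            return json.loads(arguments)
        except Exception:
            # logger.exception records the message plus the full traceback
            logger.exception("Failed to parse tool call from assistant message")
            return None

A bare except also traps KeyboardInterrupt and SystemExit; except Exception leaves those alone, and logger.exception attaches the traceback at ERROR level.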
View File

@@ -35,6 +35,7 @@ from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.otel import WorkflowAppRunnerHandler, trace_span
from models import Workflow
from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation
@@ -80,6 +81,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
@trace_span(WorkflowAppRunnerHandler)
def run(self):
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)

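The new @trace_span(WorkflowAppRunnerHandler) line wraps run() in a tracing span. The real decorator lives in extensions.otel and is not shown in this diff; the sketch below is one plausible shape, assuming a standard OpenTelemetry tracer and using the handler class name as the span name:

    import functools

    from opentelemetry import trace

    tracer = trace.get_tracer(__name__)

    def trace_span(handler_cls):
        """Decorator factory: run the wrapped method inside a named span."""
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                # Span named after the handler, e.g. "WorkflowAppRunnerHandler"
                with tracer.start_as_current_span(handler_cls.__name__):
                    return func(*args, **kwargs)
            return wrapper
        return decorator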
View File

@@ -770,7 +770,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
tts_publisher.publish(None)
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
logger.debug("Conversation name generation running as daemon thread")
def _save_message(
self,

View File

@@ -99,6 +99,15 @@ class BaseAppGenerator:
if value is None:
return None
# Treat empty placeholders for optional file inputs as unset
if (
variable_entity.type in {VariableEntityType.FILE, VariableEntityType.FILE_LIST}
and not variable_entity.required
):
# Treat empty string (frontend default) or empty list as unset
if not value and isinstance(value, (str, list)):
return None
if variable_entity.type in {
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,

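This guard treats the frontend's empty-string default (or an empty list) for an optional file input as unset. The isinstance check is what keeps other falsy values such as 0 or False from being swallowed. The rule in isolation (illustrative function, not the actual method):

    def normalize_optional_file_input(value, required: bool):
        """Map the empty-string/empty-list placeholders of an optional
        file input to None; pass everything else through unchanged."""
        if value is None:
            return None
        # Only empty str/list count as "unset"; 0 or False would otherwise match
        if not required and not value and isinstance(value, (str, list)):
            return None
        return value

    assert normalize_optional_file_input("", required=False) is None
    assert normalize_optional_file_input([], required=False) is None
    assert normalize_optional_file_input("file-id", required=False) == "file-id"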
View File

@@ -156,7 +156,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
query = application_generate_entity.query or "New conversation"
conversation_name = (query[:20] + "…") if len(query) > 20 else query
with db.session.begin():
try:
if not conversation:
conversation = Conversation(
app_id=app_config.app_id,
@@ -232,7 +232,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
db.session.add_all(message_files)
db.session.commit()
return conversation, message
return conversation, message
except Exception:
db.session.rollback()
raise
def _get_conversation_introduction(self, application_generate_entity: AppGenerateEntity) -> str:
"""

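The creation path now runs inside db.session.begin() with an explicit rollback-and-reraise, so a failure while inserting the conversation or message cannot leave half-written rows in the session. The general SQLAlchemy shape (a sketch; the model instances are placeholders):

    from sqlalchemy.orm import Session

    def create_conversation_and_message(session: Session, conversation, message):
        """All-or-nothing insert: begin() commits on success and rolls
        back automatically if anything inside the block raises."""
        with session.begin():
            session.add(conversation)
            session.add(message)
            session.flush()  # surface constraint errors inside the block
        return conversation, message

Session.begin() already rolls back when the block raises, so the explicit rollback in the diff is belt-and-braces; the essential change is that the exception now propagates instead of being swallowed.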
View File

@@ -18,6 +18,7 @@ from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_redis import redis_client
from extensions.otel import WorkflowAppRunnerHandler, trace_span
from libs.datetime_utils import naive_utc_now
from models.enums import UserFrom
from models.workflow import Workflow
@@ -56,6 +57,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
@trace_span(WorkflowAppRunnerHandler)
def run(self):
"""
Run application

View File

@@ -366,7 +366,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if publisher:
publisher.publish(None)
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
logger.debug("Conversation name generation running as daemon thread")
def _save_message(self, *, session: Session, trace_manager: TraceQueueManager | None = None):
"""

View File

@@ -1,4 +1,6 @@
import hashlib
import logging
import time
from threading import Thread
from typing import Union
@@ -31,6 +33,7 @@ from core.app.entities.task_entities import (
from core.llm_generator.llm_generator import LLMGenerator
from core.tools.signature import sign_tool_file
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
from services.annotation_service import AppAnnotationService
@@ -68,6 +71,8 @@ class MessageCycleManager:
if auto_generate_conversation_name and is_first_message:
# start generate thread
# time.sleep(1) does not block other logic
time.sleep(1)
thread = Thread(
target=self._generate_conversation_name_worker,
kwargs={
@@ -76,7 +81,7 @@
"query": query,
},
)
thread.daemon = True
thread.start()
return thread
@@ -98,15 +103,23 @@
return
# generate conversation name
try:
name = LLMGenerator.generate_conversation_name(
app_model.tenant_id, query, conversation_id, conversation.app_id
)
conversation.name = name
except Exception:
if dify_config.DEBUG:
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
query_hash = hashlib.md5(query.encode()).hexdigest()[:16]
cache_key = f"conv_name:{conversation_id}:{query_hash}"
cached_name = redis_client.get(cache_key)
if cached_name:
name = cached_name.decode("utf-8")
else:
try:
name = LLMGenerator.generate_conversation_name(
app_model.tenant_id, query, conversation_id, conversation.app_id
)
redis_client.setex(cache_key, 3600, name)
except Exception:
if dify_config.DEBUG:
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
name = query[:47] + "..." if len(query) > 50 else query
conversation.name = name
db.session.commit()
db.session.close()

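The worker now checks Redis before invoking the LLM: the key combines the conversation id with a 16-hex-character MD5 digest of the query, hits are reused, misses are generated and cached for an hour (3600 s), and generation failures fall back to a truncated query. Condensed into one function (a sketch; generate stands in for LLMGenerator.generate_conversation_name):

    import hashlib

    CACHE_TTL_SECONDS = 3600

    def cached_conversation_name(redis_client, conversation_id: str, query: str, generate) -> str:
        """Return a cached name when available; otherwise generate and
        cache it, or fall back to a truncated query on failure."""
        query_hash = hashlib.md5(query.encode()).hexdigest()[:16]
        cache_key = f"conv_name:{conversation_id}:{query_hash}"

        cached = redis_client.get(cache_key)
        if cached:
            return cached.decode("utf-8")
        try:
            name = generate(query)
            redis_client.setex(cache_key, CACHE_TTL_SECONDS, name)
            return name
        except Exception:
            # Degrade gracefully instead of failing the whole request
            return query[:47] + "..." if len(query) > 50 else query

Making the thread a daemon and dropping the join() means the request path no longer waits for naming to finish; the trade-off is that an in-flight generation can be cut off at process exit.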
View File

@@ -253,7 +253,7 @@ class ProviderConfiguration(BaseModel):
try:
credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key])
except Exception:
pass
logger.exception("Failed to decrypt credential secret variable %s", key)
return self.obfuscated_credentials(
credentials=credentials,
@@ -765,7 +765,7 @@ class ProviderConfiguration(BaseModel):
try:
credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key])
except Exception:
pass
logger.exception("Failed to decrypt model credential secret variable %s", key)
current_credential_id = credential_record.id
current_credential_name = credential_record.credential_name

View File

@@ -1,3 +1,4 @@
import logging
from collections.abc import Sequence
import httpx
@@ -8,6 +9,7 @@ from core.helper.download import download_with_size_limit
from core.plugin.entities.marketplace import MarketplacePluginDeclaration
marketplace_api_url = URL(str(dify_config.MARKETPLACE_API_URL))
logger = logging.getLogger(__name__)
def get_plugin_pkg_url(plugin_unique_identifier: str) -> str:
@@ -55,7 +57,9 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
try:
result.append(MarketplacePluginDeclaration.model_validate(plugin))
except Exception:
pass
logger.exception(
"Failed to deserialize marketplace plugin manifest for %s", plugin.get("plugin_id", "unknown")
)
return result

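batch_fetch_plugin_manifests_ignore_deserialization_error keeps its skip-on-failure behavior but now names the offending plugin in the log. The same Pydantic pattern, self-contained (PluginManifest is a stand-in for MarketplacePluginDeclaration):

    import logging

    from pydantic import BaseModel

    logger = logging.getLogger(__name__)

    class PluginManifest(BaseModel):
        plugin_id: str
        version: str

    def parse_manifests(raw_items: list[dict]) -> list[PluginManifest]:
        """Keep valid entries; log the identifier of any entry that fails."""
        result: list[PluginManifest] = []
        for item in raw_items:
            try:
                result.append(PluginManifest.model_validate(item))
            except Exception:
                logger.exception(
                    "Failed to deserialize manifest for %s", item.get("plugin_id", "unknown")
                )
        return result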
View File

@@ -0,0 +1,56 @@
import json
import logging
from typing import Any
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from extensions.ext_redis import redis_client, redis_fallback
logger = logging.getLogger(__name__)
class ToolProviderListCache:
"""Cache for tool provider lists"""
CACHE_TTL = 300 # 5 minutes
@staticmethod
def _generate_cache_key(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> str:
"""Generate cache key for tool providers list"""
type_filter = typ or "all"
return f"tool_providers:tenant_id:{tenant_id}:type:{type_filter}"
@staticmethod
@redis_fallback(default_return=None)
def get_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> list[dict[str, Any]] | None:
"""Get cached tool providers"""
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
cached_data = redis_client.get(cache_key)
if cached_data:
try:
return json.loads(cached_data.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError):
logger.warning("Failed to decode cached tool providers data")
return None
return None
@staticmethod
@redis_fallback()
def set_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral, providers: list[dict[str, Any]]):
"""Cache tool providers"""
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
redis_client.setex(cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(providers))
@staticmethod
@redis_fallback()
def invalidate_cache(tenant_id: str, typ: ToolProviderTypeApiLiteral = None):
"""Invalidate cache for tool providers"""
if typ:
# Invalidate specific type cache
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
redis_client.delete(cache_key)
else:
# Invalidate all caches for this tenant
pattern = f"tool_providers:tenant_id:{tenant_id}:*"
keys = list(redis_client.scan_iter(pattern))
if keys:
redis_client.delete(*keys)

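A typical read-through use of the new ToolProviderListCache: consult the cache, rebuild on a miss, store the result, and invalidate on any mutation. A hedged usage sketch (build_providers is a placeholder for the real list assembly in ToolManager):

    def list_tool_providers(tenant_id: str, typ=None):
        cached = ToolProviderListCache.get_cached_providers(tenant_id, typ)
        if cached is not None:
            return cached
        providers = build_providers(tenant_id, typ)  # placeholder for the real assembly
        ToolProviderListCache.set_cached_providers(tenant_id, typ, providers)
        return providers

    # After any create/update/delete of a provider for this tenant:
    # ToolProviderListCache.invalidate_cache(tenant_id)

Because the accessors are wrapped in redis_fallback(default_return=None), a Redis outage degrades to a cache miss rather than an error, and tenant-wide invalidation walks the tool_providers:tenant_id:{tenant_id}:* pattern with scan_iter.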
View File

@@ -15,6 +15,8 @@ from core.llm_generator.prompts import (
LLM_MODIFY_CODE_SYSTEM,
LLM_MODIFY_PROMPT_SYSTEM,
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
SUGGESTED_QUESTIONS_MAX_TOKENS,
SUGGESTED_QUESTIONS_TEMPERATURE,
SYSTEM_STRUCTURED_OUTPUT_GENERATE,
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
)
@@ -124,7 +126,10 @@ class LLMGenerator:
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
model_parameters={"max_tokens": 256, "temperature": 0},
model_parameters={
"max_tokens": SUGGESTED_QUESTIONS_MAX_TOKENS,
"temperature": SUGGESTED_QUESTIONS_TEMPERATURE,
},
stream=False,
)

View File

@@ -1,4 +1,6 @@
# Written by YORKI MINAKO🤡, Edited by Xiaoyi, Edited by yasu-oh
import os
CONVERSATION_TITLE_PROMPT = """You are asked to generate a concise chat title by decomposing the user’s input into two parts: “Intention” and “Subject”.
1. Detect Input Language
@@ -94,7 +96,8 @@ JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = (
)
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
# Default prompt for suggested questions (can be overridden by environment variable)
_DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_PROMPT = (
"Please help me predict the three most likely questions that human would ask, "
"and keep each question under 20 characters.\n"
"MAKE SURE your output is the SAME language as the Assistant's latest response. "
@@ -102,6 +105,15 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
'["question1","question2","question3"]\n'
)
# Environment variable override for suggested questions prompt
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = os.getenv(
"SUGGESTED_QUESTIONS_PROMPT", _DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_PROMPT
)
# Configurable LLM parameters for suggested questions (can be overridden by environment variables)
SUGGESTED_QUESTIONS_MAX_TOKENS = int(os.getenv("SUGGESTED_QUESTIONS_MAX_TOKENS", "256"))
SUGGESTED_QUESTIONS_TEMPERATURE = float(os.getenv("SUGGESTED_QUESTIONS_TEMPERATURE", "0"))
GENERATOR_QA_PROMPT = (
"<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge"
" in the long text. Please think step by step."

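The prompt and the sampling parameters become deploy-time configurable while keeping the previous literals as defaults. The pattern, reduced to its essentials (names follow the diff):

    import os

    # String override, falling back to the built-in default prompt
    PROMPT = os.getenv("SUGGESTED_QUESTIONS_PROMPT", "<default prompt text>")

    # Numeric overrides must be converted: os.getenv always returns str
    MAX_TOKENS = int(os.getenv("SUGGESTED_QUESTIONS_MAX_TOKENS", "256"))
    TEMPERATURE = float(os.getenv("SUGGESTED_QUESTIONS_TEMPERATURE", "0"))

Since the int()/float() conversions run at import time, a malformed override fails fast rather than at request time.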
View File

@@ -296,7 +296,7 @@ class AliyunDataTrace(BaseTraceInstance):
node_span = self.build_workflow_task_span(trace_info, node_execution, trace_metadata)
return node_span
except Exception as e:
logger.debug("Error occurred in build_workflow_node_span: %s", e, exc_info=True)
logger.warning("Error occurred in build_workflow_node_span: %s", e, exc_info=True)
return None
def build_workflow_task_span(

View File

@@ -21,6 +21,7 @@ from opentelemetry.trace import Link, SpanContext, TraceFlags
from configs import dify_config
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
from core.ops.aliyun_trace.entities.semconv import ACS_ARMS_SERVICE_FEATURE
INVALID_SPAN_ID: Final[int] = 0x0000000000000000
INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000
@@ -48,6 +49,7 @@ class TraceClient:
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
ResourceAttributes.HOST_NAME: socket.gethostname(),
ACS_ARMS_SERVICE_FEATURE: "genai_app",
}
)
self.span_builder = SpanBuilder(self.resource)
@@ -75,10 +77,10 @@
if response.status_code == 405:
return True
else:
logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
logger.warning("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
return False
except httpx.RequestError as e:
logger.debug("AliyunTrace API check failed: %s", str(e))
logger.warning("AliyunTrace API check failed: %s", str(e))
raise ValueError(f"AliyunTrace API check failed: {str(e)}")
def get_project_url(self) -> str:
@@ -116,7 +118,7 @@
try:
self.exporter.export(spans_to_export)
except Exception as e:
logger.debug("Error exporting spans: %s", e)
logger.warning("Error exporting spans: %s", e)
def shutdown(self) -> None:
with self.condition:

View File

@@ -1,6 +1,8 @@
from enum import StrEnum
from typing import Final
ACS_ARMS_SERVICE_FEATURE: Final[str] = "acs.arms.service.feature"
# Public attributes
GEN_AI_SESSION_ID: Final[str] = "gen_ai.session.id"
GEN_AI_USER_ID: Final[str] = "gen_ai.user.id"

View File

@@ -377,20 +377,20 @@ class OpsTraceManager:
return app_model_config
@classmethod
def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: str):
def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: str | None):
"""
Update app tracing config
:param app_id: app id
:param enabled: enabled
:param tracing_provider: tracing provider
:param tracing_provider: tracing provider (None when disabling)
:return:
"""
# auth check
try:
if enabled or tracing_provider is not None:
if tracing_provider is not None:
try:
provider_config_map[tracing_provider]
except KeyError:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
except KeyError:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
app_config: App | None = db.session.query(App).where(App.id == app_id).first()
if not app_config:

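update_app_tracing_config now accepts tracing_provider=None to mean "disable tracing", and only validates the provider name when one is actually supplied. The control flow, extracted (provider_config_map keys are illustrative):

    provider_config_map = {"langfuse": object(), "langsmith": object()}  # illustrative keys

    def validate_tracing_provider(tracing_provider: str | None) -> None:
        """None means 'disable tracing' and is always accepted."""
        if tracing_provider is None:
            return
        if tracing_provider not in provider_config_map:
            raise ValueError(f"Invalid tracing provider: {tracing_provider}")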
View File

@@ -521,4 +521,4 @@ class TencentDataTrace(BaseTraceInstance):
if hasattr(self, "trace_client"):
self.trace_client.shutdown()
except Exception:
pass
logger.exception("[Tencent APM] Failed to shutdown trace client during cleanup")

View File

@@ -1,7 +1,9 @@
"""Document loader helpers."""
import concurrent.futures
from typing import NamedTuple, cast
from typing import NamedTuple
import charset_normalizer
class FileEncoding(NamedTuple):
@@ -27,14 +29,14 @@ def detect_file_encodings(file_path: str, timeout: int = 5, sample_size: int = 1
sample_size: The number of bytes to read for encoding detection. Default is 1MB.
For large files, reading only a sample is sufficient and prevents timeout.
"""
import chardet
def read_and_detect(file_path: str):
with open(file_path, "rb") as f:
# Read only a sample of the file for encoding detection
# This prevents timeout on large files while still providing accurate encoding detection
rawdata = f.read(sample_size)
return cast(list[dict], chardet.detect_all(rawdata))
def read_and_detect(filename: str):
rst = charset_normalizer.from_path(filename)
best = rst.best()
if best is None:
return []
file_encoding = FileEncoding(encoding=best.encoding, confidence=best.coherence, language=best.language)
return [file_encoding]
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(read_and_detect, file_path)

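chardet.detect_all is replaced by charset_normalizer: from_path() returns ranked candidate matches, .best() selects the top one, and its coherence score plays the role of chardet's confidence. Run inside an executor, the timeout behavior is unchanged. A self-contained sketch of the new shape:

    import concurrent.futures

    import charset_normalizer

    def detect_encoding(file_path: str, timeout: int = 5) -> str | None:
        """Detect a file's encoding, bounded by a timeout."""
        def read_and_detect(path: str) -> str | None:
            best = charset_normalizer.from_path(path).best()
            return best.encoding if best is not None else None

        with concurrent.futures.ThreadPoolExecutor() as executor:
            future = executor.submit(read_and_detect, file_path)
            try:
                return future.result(timeout=timeout)
            except concurrent.futures.TimeoutError:
                raise TimeoutError(f"Encoding detection timed out after {timeout}s")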
View File

@@ -5,7 +5,7 @@ import time
from collections.abc import Generator, Mapping
from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, cast
import sqlalchemy as sa
from sqlalchemy import select
@@ -67,6 +67,11 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class ApiProviderControllerItem(TypedDict):
provider: ApiToolProvider
controller: ApiToolProviderController
class ToolManager:
_builtin_provider_lock = Lock()
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
@@ -655,9 +660,10 @@
else:
filters.append(typ)
with db.session.no_autoflush:
# Use a single session for all database operations to reduce connection overhead
with Session(db.engine) as session:
if "builtin" in filters:
builtin_providers = cls.list_builtin_providers(tenant_id)
builtin_providers = list(cls.list_builtin_providers(tenant_id))
# key: provider name, value: provider
db_builtin_providers = {
@@ -688,57 +694,74 @@
# get db api providers
if "api" in filters:
db_api_providers = db.session.scalars(
db_api_providers = session.scalars(
select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)
).all()
api_provider_controllers: list[dict[str, Any]] = [
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
for provider in db_api_providers
]
# Batch create controllers
api_provider_controllers: list[ApiProviderControllerItem] = []
for api_provider in db_api_providers:
try:
controller = ToolTransformService.api_provider_to_controller(api_provider)
api_provider_controllers.append({"provider": api_provider, "controller": controller})
except Exception:
# Skip invalid providers but continue processing others
logger.warning("Failed to create controller for API provider %s", api_provider.id)
# get labels
labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
for api_provider_controller in api_provider_controllers:
user_provider = ToolTransformService.api_provider_to_user_provider(
provider_controller=api_provider_controller["controller"],
db_provider=api_provider_controller["provider"],
decrypt_credentials=False,
labels=labels.get(api_provider_controller["controller"].provider_id, []),
# Batch get labels for all API providers
if api_provider_controllers:
controllers = cast(
list[ToolProviderController], [item["controller"] for item in api_provider_controllers]
)
result_providers[f"api_provider.{user_provider.name}"] = user_provider
labels = ToolLabelManager.get_tools_labels(controllers)
for item in api_provider_controllers:
provider_controller = item["controller"]
db_provider = item["provider"]
provider_labels = labels.get(provider_controller.provider_id, [])
user_provider = ToolTransformService.api_provider_to_user_provider(
provider_controller=provider_controller,
db_provider=db_provider,
decrypt_credentials=False,
labels=provider_labels,
)
result_providers[f"api_provider.{user_provider.name}"] = user_provider
if "workflow" in filters:
# get workflow providers
workflow_providers = db.session.scalars(
workflow_providers = session.scalars(
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
).all()
workflow_provider_controllers: list[WorkflowToolProviderController] = []
for workflow_provider in workflow_providers:
try:
workflow_provider_controllers.append(
workflow_controller: WorkflowToolProviderController = (
ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
)
workflow_provider_controllers.append(workflow_controller)
except Exception:
# app has been deleted
pass
logger.exception("Failed to transform workflow provider %s to controller", workflow_provider.id)
continue
# Batch get labels for workflow providers
if workflow_provider_controllers:
workflow_controllers: list[ToolProviderController] = [
cast(ToolProviderController, controller) for controller in workflow_provider_controllers
]
labels = ToolLabelManager.get_tools_labels(workflow_controllers)
labels = ToolLabelManager.get_tools_labels(
[cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
)
for workflow_provider_controller in workflow_provider_controllers:
provider_labels = labels.get(workflow_provider_controller.provider_id, [])
user_provider = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=workflow_provider_controller,
labels=provider_labels,
)
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
for provider_controller in workflow_provider_controllers:
user_provider = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=provider_controller,
labels=labels.get(provider_controller.provider_id, []),
)
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
if "mcp" in filters:
with Session(db.engine) as session:
mcp_service = MCPToolManageService(session=session)
mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True)
mcp_service = MCPToolManageService(session=session)
mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True)
for mcp_provider in mcp_providers:
result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider

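Two ideas drive this refactor: a single Session(db.engine) is threaded through the api, workflow, and mcp branches instead of opening one session per branch, and the untyped dict items become an ApiProviderControllerItem TypedDict so the provider/controller pairing is statically checked. The TypedDict-plus-skip-on-failure pattern in miniature (to_controller and the provider objects are placeholders):

    import logging
    from typing import TypedDict

    logger = logging.getLogger(__name__)

    class ProviderControllerItem(TypedDict):
        provider: object    # stands in for the ApiToolProvider model
        controller: object  # stands in for its controller

    def pair_with_controllers(providers: list, to_controller) -> list[ProviderControllerItem]:
        """Pair each provider with its controller, skipping ones that fail."""
        items: list[ProviderControllerItem] = []
        for provider in providers:
            try:
                items.append({"provider": provider, "controller": to_controller(provider)})
            except Exception:
                # Skip invalid providers but keep processing the rest
                logger.warning("Failed to create controller for provider %r", provider)
        return items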
View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass
from typing import Any, cast
from urllib.parse import unquote
import chardet
import charset_normalizer
import cloudscraper
from readabilipy import simple_json_from_html_string
@@ -69,9 +69,12 @@ def get_url(url: str, user_agent: str | None = None) -> str:
if response.status_code != 200:
return f"URL returned status code {response.status_code}."
# Detect encoding using chardet
detected_encoding = chardet.detect(response.content)
encoding = detected_encoding["encoding"]
# Detect encoding using charset_normalizer
detected_encoding = charset_normalizer.from_bytes(response.content).best()
if detected_encoding:
encoding = detected_encoding.encoding
else:
encoding = "utf-8"
if encoding:
try:
content = response.content.decode(encoding)

View File

@@ -2,6 +2,7 @@
Unified event manager for collecting and emitting events.
"""
import logging
import threading
import time
from collections.abc import Generator
@@ -12,6 +13,8 @@ from core.workflow.graph_events import GraphEngineEvent
from ..layers.base import GraphEngineLayer
_logger = logging.getLogger(__name__)
@final
class ReadWriteLock:
@@ -180,5 +183,4 @@ class EventManager:
try:
layer.on_event(event)
except Exception:
# Silently ignore layer errors during collection
pass
_logger.exception("Error in layer on_event, layer_type=%s", type(layer))

View File

@@ -6,12 +6,15 @@
Supports stop, pause, and resume operations.
"""
import logging
from typing import final
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
@final
class GraphEngineManager:
@@ -57,4 +60,4 @@ class GraphEngineManager:
except Exception:
# Silently fail if Redis is unavailable
# The legacy control mechanisms will still work
pass
logger.exception("Failed to send graph engine command %s for task %s", command.__class__.__name__, task_id)

View File

@@ -7,7 +7,7 @@ import tempfile
from collections.abc import Mapping, Sequence
from typing import Any
import chardet
import charset_normalizer
import docx
import pandas as pd
import pypandoc
@@ -228,9 +228,12 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
def _extract_text_from_plain_text(file_content: bytes) -> str:
try:
# Detect encoding using chardet
result = chardet.detect(file_content)
encoding = result["encoding"]
# Detect encoding using charset_normalizer
result = charset_normalizer.from_bytes(file_content, cp_isolation=["utf_8", "latin_1", "cp1252"]).best()
if result:
encoding = result.encoding
else:
encoding = "utf-8"
# Fallback to utf-8 if detection fails
if not encoding:
@@ -247,9 +250,12 @@ def _extract_text_from_json(file_content: bytes) -> str:
def _extract_text_from_json(file_content: bytes) -> str:
try:
# Detect encoding using chardet
result = chardet.detect(file_content)
encoding = result["encoding"]
# Detect encoding using charset_normalizer
result = charset_normalizer.from_bytes(file_content).best()
if result:
encoding = result.encoding
else:
encoding = "utf-8"
# Fallback to utf-8 if detection fails
if not encoding:
@@ -269,9 +275,12 @@ def _extract_text_from_yaml(file_content: bytes) -> str:
def _extract_text_from_yaml(file_content: bytes) -> str:
"""Extract the content from yaml file"""
try:
# Detect encoding using chardet
result = chardet.detect(file_content)
encoding = result["encoding"]
# Detect encoding using charset_normalizer
result = charset_normalizer.from_bytes(file_content).best()
if result:
encoding = result.encoding
else:
encoding = "utf-8"
# Fallback to utf-8 if detection fails
if not encoding:
@@ -424,9 +433,12 @@ def _extract_text_from_file(file: File):
def _extract_text_from_csv(file_content: bytes) -> str:
try:
# Detect encoding using chardet
result = chardet.detect(file_content)
encoding = result["encoding"]
# Detect encoding using charset_normalizer
result = charset_normalizer.from_bytes(file_content).best()
if result:
encoding = result.encoding
else:
encoding = "utf-8"
# Fallback to utf-8 if detection fails
if not encoding:

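Note the plain-text extractor restricts detection with cp_isolation=["utf_8", "latin_1", "cp1252"], narrowing the candidate codecs to the ones plain-text uploads realistically use, which makes detection faster and more deterministic; the JSON, YAML, and CSV paths search the full codec space. The restricted form in isolation:

    import charset_normalizer

    def decode_plain_text(file_content: bytes) -> str:
        # Consider only the codecs plain-text uploads realistically use
        best = charset_normalizer.from_bytes(
            file_content, cp_isolation=["utf_8", "latin_1", "cp1252"]
        ).best()
        encoding = best.encoding if best else "utf-8"
        return file_content.decode(encoding)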
View File

@@ -64,7 +64,10 @@ class DifyNodeFactory(NodeFactory):
if not node_mapping:
raise ValueError(f"No class mapping found for node type: {node_type}")
node_class = node_mapping.get(LATEST_VERSION)
latest_node_class = node_mapping.get(LATEST_VERSION)
node_version = str(node_data.get("version", "1"))
matched_node_class = node_mapping.get(node_version)
node_class = matched_node_class or latest_node_class
if not node_class:
raise ValueError(f"No latest version class found for node type: {node_type}")