mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 08:58:09 +08:00
Merge branch 'main' into feat/end-user-oauth
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any
|
||||
@ -23,6 +24,8 @@ from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
_is_first_iteration = True
|
||||
@ -400,8 +403,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
action_input=json.loads(message.tool_calls[0].function.arguments),
|
||||
)
|
||||
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
|
||||
except:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception("Failed to parse tool call from assistant message")
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
if current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
|
||||
@ -35,6 +35,7 @@ from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.otel import WorkflowAppRunnerHandler, trace_span
|
||||
from models import Workflow
|
||||
from models.enums import UserFrom
|
||||
from models.model import App, Conversation, Message, MessageAnnotation
|
||||
@ -80,6 +81,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
|
||||
@trace_span(WorkflowAppRunnerHandler)
|
||||
def run(self):
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(AdvancedChatAppConfig, app_config)
|
||||
|
||||
@ -770,7 +770,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
tts_publisher.publish(None)
|
||||
|
||||
if self._conversation_name_generate_thread:
|
||||
self._conversation_name_generate_thread.join()
|
||||
logger.debug("Conversation name generation running as daemon thread")
|
||||
|
||||
def _save_message(
|
||||
self,
|
||||
|
||||
@ -99,6 +99,15 @@ class BaseAppGenerator:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# Treat empty placeholders for optional file inputs as unset
|
||||
if (
|
||||
variable_entity.type in {VariableEntityType.FILE, VariableEntityType.FILE_LIST}
|
||||
and not variable_entity.required
|
||||
):
|
||||
# Treat empty string (frontend default) or empty list as unset
|
||||
if not value and isinstance(value, (str, list)):
|
||||
return None
|
||||
|
||||
if variable_entity.type in {
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
VariableEntityType.SELECT,
|
||||
|
||||
@ -156,7 +156,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
query = application_generate_entity.query or "New conversation"
|
||||
conversation_name = (query[:20] + "…") if len(query) > 20 else query
|
||||
|
||||
with db.session.begin():
|
||||
try:
|
||||
if not conversation:
|
||||
conversation = Conversation(
|
||||
app_id=app_config.app_id,
|
||||
@ -232,7 +232,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
db.session.add_all(message_files)
|
||||
|
||||
db.session.commit()
|
||||
return conversation, message
|
||||
return conversation, message
|
||||
except Exception:
|
||||
db.session.rollback()
|
||||
raise
|
||||
|
||||
def _get_conversation_introduction(self, application_generate_entity: AppGenerateEntity) -> str:
|
||||
"""
|
||||
|
||||
@ -18,6 +18,7 @@ from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.otel import WorkflowAppRunnerHandler, trace_span
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import Workflow
|
||||
@ -56,6 +57,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
|
||||
@trace_span(WorkflowAppRunnerHandler)
|
||||
def run(self):
|
||||
"""
|
||||
Run application
|
||||
|
||||
@ -366,7 +366,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
if publisher:
|
||||
publisher.publish(None)
|
||||
if self._conversation_name_generate_thread:
|
||||
self._conversation_name_generate_thread.join()
|
||||
logger.debug("Conversation name generation running as daemon thread")
|
||||
|
||||
def _save_message(self, *, session: Session, trace_manager: TraceQueueManager | None = None):
|
||||
"""
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from threading import Thread
|
||||
from typing import Union
|
||||
|
||||
@ -31,6 +33,7 @@ from core.app.entities.task_entities import (
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.tools.signature import sign_tool_file
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
@ -68,6 +71,8 @@ class MessageCycleManager:
|
||||
|
||||
if auto_generate_conversation_name and is_first_message:
|
||||
# start generate thread
|
||||
# time.sleep not block other logic
|
||||
time.sleep(1)
|
||||
thread = Thread(
|
||||
target=self._generate_conversation_name_worker,
|
||||
kwargs={
|
||||
@ -76,7 +81,7 @@ class MessageCycleManager:
|
||||
"query": query,
|
||||
},
|
||||
)
|
||||
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
|
||||
return thread
|
||||
@ -98,15 +103,23 @@ class MessageCycleManager:
|
||||
return
|
||||
|
||||
# generate conversation name
|
||||
try:
|
||||
name = LLMGenerator.generate_conversation_name(
|
||||
app_model.tenant_id, query, conversation_id, conversation.app_id
|
||||
)
|
||||
conversation.name = name
|
||||
except Exception:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
|
||||
query_hash = hashlib.md5(query.encode()).hexdigest()[:16]
|
||||
cache_key = f"conv_name:{conversation_id}:{query_hash}"
|
||||
|
||||
cached_name = redis_client.get(cache_key)
|
||||
if cached_name:
|
||||
name = cached_name.decode("utf-8")
|
||||
else:
|
||||
try:
|
||||
name = LLMGenerator.generate_conversation_name(
|
||||
app_model.tenant_id, query, conversation_id, conversation.app_id
|
||||
)
|
||||
redis_client.setex(cache_key, 3600, name)
|
||||
except Exception:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
|
||||
name = query[:47] + "..." if len(query) > 50 else query
|
||||
conversation.name = name
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
|
||||
@ -253,7 +253,7 @@ class ProviderConfiguration(BaseModel):
|
||||
try:
|
||||
credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key])
|
||||
except Exception:
|
||||
pass
|
||||
logger.exception("Failed to decrypt credential secret variable %s", key)
|
||||
|
||||
return self.obfuscated_credentials(
|
||||
credentials=credentials,
|
||||
@ -765,7 +765,7 @@ class ProviderConfiguration(BaseModel):
|
||||
try:
|
||||
credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key])
|
||||
except Exception:
|
||||
pass
|
||||
logger.exception("Failed to decrypt model credential secret variable %s", key)
|
||||
|
||||
current_credential_id = credential_record.id
|
||||
current_credential_name = credential_record.credential_name
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
import httpx
|
||||
@ -8,6 +9,7 @@ from core.helper.download import download_with_size_limit
|
||||
from core.plugin.entities.marketplace import MarketplacePluginDeclaration
|
||||
|
||||
marketplace_api_url = URL(str(dify_config.MARKETPLACE_API_URL))
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_plugin_pkg_url(plugin_unique_identifier: str) -> str:
|
||||
@ -55,7 +57,9 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
|
||||
try:
|
||||
result.append(MarketplacePluginDeclaration.model_validate(plugin))
|
||||
except Exception:
|
||||
pass
|
||||
logger.exception(
|
||||
"Failed to deserialize marketplace plugin manifest for %s", plugin.get("plugin_id", "unknown")
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
56
api/core/helper/tool_provider_cache.py
Normal file
56
api/core/helper/tool_provider_cache.py
Normal file
@ -0,0 +1,56 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
|
||||
from extensions.ext_redis import redis_client, redis_fallback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolProviderListCache:
    """Redis-backed cache for per-tenant tool provider lists.

    Entries are stored as JSON strings under keys of the form
    ``tool_providers:tenant_id:<tenant_id>:type:<type>``, where ``<type>`` is
    the provider type filter, or ``"all"`` when no filter is given. Entries
    expire after ``CACHE_TTL`` seconds.

    Every method that touches Redis is wrapped in ``redis_fallback``
    (NOTE(review): presumably it swallows Redis errors and returns
    ``default_return`` -- confirm against ``extensions.ext_redis``).
    """

    CACHE_TTL = 300  # 5 minutes

    @staticmethod
    def _generate_cache_key(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> str:
        """Build the Redis key for a tenant's provider list.

        :param tenant_id: tenant whose provider list is cached
        :param typ: provider type filter; a falsy value maps to the ``"all"`` key
        :return: the cache key string
        """
        type_filter = typ or "all"
        return f"tool_providers:tenant_id:{tenant_id}:type:{type_filter}"

    @staticmethod
    @redis_fallback(default_return=None)
    def get_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> list[dict[str, Any]] | None:
        """Return the cached provider list, or ``None`` on a miss.

        A cached value that cannot be decoded as UTF-8 JSON is treated as a
        miss (a warning is logged) so the caller can rebuild the list.

        :param tenant_id: tenant whose provider list is requested
        :param typ: provider type filter; ``None`` reads the unfiltered list
        :return: the decoded provider list, or ``None`` if absent or unreadable
        """
        cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
        cached_data = redis_client.get(cache_key)
        if cached_data:
            try:
                return json.loads(cached_data.decode("utf-8"))
            except (json.JSONDecodeError, UnicodeDecodeError):
                logger.warning("Failed to decode cached tool providers data")
                return None
        return None

    @staticmethod
    @redis_fallback()
    def set_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral, providers: list[dict[str, Any]]):
        """Store a provider list as JSON with a ``CACHE_TTL``-second expiry.

        :param tenant_id: tenant the list belongs to
        :param typ: provider type filter the list was built with
        :param providers: JSON-serializable provider dictionaries
        """
        cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
        redis_client.setex(cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(providers))

    @staticmethod
    @redis_fallback()
    def invalidate_cache(tenant_id: str, typ: ToolProviderTypeApiLiteral = None):
        """Drop cached provider lists for a tenant.

        :param tenant_id: tenant whose cache entries are removed
        :param typ: when given, remove the entry for that type *and* the
            unfiltered ``"all"`` entry; when omitted, remove every entry
            stored for the tenant
        """
        if typ:
            # The unfiltered ("all") listing also contains providers of this
            # type, so it would serve stale data until its TTL expired if only
            # the type-specific key were removed. Delete both in one call.
            redis_client.delete(
                ToolProviderListCache._generate_cache_key(tenant_id, typ),
                ToolProviderListCache._generate_cache_key(tenant_id),
            )
        else:
            # Invalidate all caches for this tenant
            pattern = f"tool_providers:tenant_id:{tenant_id}:*"
            keys = list(redis_client.scan_iter(pattern))
            if keys:
                redis_client.delete(*keys)
|
||||
@ -15,6 +15,8 @@ from core.llm_generator.prompts import (
|
||||
LLM_MODIFY_CODE_SYSTEM,
|
||||
LLM_MODIFY_PROMPT_SYSTEM,
|
||||
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
|
||||
SUGGESTED_QUESTIONS_MAX_TOKENS,
|
||||
SUGGESTED_QUESTIONS_TEMPERATURE,
|
||||
SYSTEM_STRUCTURED_OUTPUT_GENERATE,
|
||||
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
|
||||
)
|
||||
@ -124,7 +126,10 @@ class LLMGenerator:
|
||||
try:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters={"max_tokens": 256, "temperature": 0},
|
||||
model_parameters={
|
||||
"max_tokens": SUGGESTED_QUESTIONS_MAX_TOKENS,
|
||||
"temperature": SUGGESTED_QUESTIONS_TEMPERATURE,
|
||||
},
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
# Written by YORKI MINAKO🤡, Edited by Xiaoyi, Edited by yasu-oh
|
||||
import os
|
||||
|
||||
CONVERSATION_TITLE_PROMPT = """You are asked to generate a concise chat title by decomposing the user’s input into two parts: “Intention” and “Subject”.
|
||||
|
||||
1. Detect Input Language
|
||||
@ -94,7 +96,8 @@ JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = (
|
||||
)
|
||||
|
||||
|
||||
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
|
||||
# Default prompt for suggested questions (can be overridden by environment variable)
|
||||
_DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_PROMPT = (
|
||||
"Please help me predict the three most likely questions that human would ask, "
|
||||
"and keep each question under 20 characters.\n"
|
||||
"MAKE SURE your output is the SAME language as the Assistant's latest response. "
|
||||
@ -102,6 +105,15 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
|
||||
'["question1","question2","question3"]\n'
|
||||
)
|
||||
|
||||
# Environment variable override for suggested questions prompt
|
||||
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = os.getenv(
|
||||
"SUGGESTED_QUESTIONS_PROMPT", _DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_PROMPT
|
||||
)
|
||||
|
||||
# Configurable LLM parameters for suggested questions (can be overridden by environment variables)
|
||||
SUGGESTED_QUESTIONS_MAX_TOKENS = int(os.getenv("SUGGESTED_QUESTIONS_MAX_TOKENS", "256"))
|
||||
SUGGESTED_QUESTIONS_TEMPERATURE = float(os.getenv("SUGGESTED_QUESTIONS_TEMPERATURE", "0"))
|
||||
|
||||
GENERATOR_QA_PROMPT = (
|
||||
"<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge"
|
||||
" in the long text. Please think step by step."
|
||||
|
||||
@ -296,7 +296,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
node_span = self.build_workflow_task_span(trace_info, node_execution, trace_metadata)
|
||||
return node_span
|
||||
except Exception as e:
|
||||
logger.debug("Error occurred in build_workflow_node_span: %s", e, exc_info=True)
|
||||
logger.warning("Error occurred in build_workflow_node_span: %s", e, exc_info=True)
|
||||
return None
|
||||
|
||||
def build_workflow_task_span(
|
||||
|
||||
@ -21,6 +21,7 @@ from opentelemetry.trace import Link, SpanContext, TraceFlags
|
||||
|
||||
from configs import dify_config
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
|
||||
from core.ops.aliyun_trace.entities.semconv import ACS_ARMS_SERVICE_FEATURE
|
||||
|
||||
INVALID_SPAN_ID: Final[int] = 0x0000000000000000
|
||||
INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000
|
||||
@ -48,6 +49,7 @@ class TraceClient:
|
||||
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
|
||||
ResourceAttributes.HOST_NAME: socket.gethostname(),
|
||||
ACS_ARMS_SERVICE_FEATURE: "genai_app",
|
||||
}
|
||||
)
|
||||
self.span_builder = SpanBuilder(self.resource)
|
||||
@ -75,10 +77,10 @@ class TraceClient:
|
||||
if response.status_code == 405:
|
||||
return True
|
||||
else:
|
||||
logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
|
||||
logger.warning("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
|
||||
return False
|
||||
except httpx.RequestError as e:
|
||||
logger.debug("AliyunTrace API check failed: %s", str(e))
|
||||
logger.warning("AliyunTrace API check failed: %s", str(e))
|
||||
raise ValueError(f"AliyunTrace API check failed: {str(e)}")
|
||||
|
||||
def get_project_url(self) -> str:
|
||||
@ -116,7 +118,7 @@ class TraceClient:
|
||||
try:
|
||||
self.exporter.export(spans_to_export)
|
||||
except Exception as e:
|
||||
logger.debug("Error exporting spans: %s", e)
|
||||
logger.warning("Error exporting spans: %s", e)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
with self.condition:
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from enum import StrEnum
|
||||
from typing import Final
|
||||
|
||||
ACS_ARMS_SERVICE_FEATURE: Final[str] = "acs.arms.service.feature"
|
||||
|
||||
# Public attributes
|
||||
GEN_AI_SESSION_ID: Final[str] = "gen_ai.session.id"
|
||||
GEN_AI_USER_ID: Final[str] = "gen_ai.user.id"
|
||||
|
||||
@ -377,20 +377,20 @@ class OpsTraceManager:
|
||||
return app_model_config
|
||||
|
||||
@classmethod
|
||||
def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: str):
|
||||
def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: str | None):
|
||||
"""
|
||||
Update app tracing config
|
||||
:param app_id: app id
|
||||
:param enabled: enabled
|
||||
:param tracing_provider: tracing provider
|
||||
:param tracing_provider: tracing provider (None when disabling)
|
||||
:return:
|
||||
"""
|
||||
# auth check
|
||||
try:
|
||||
if enabled or tracing_provider is not None:
|
||||
if tracing_provider is not None:
|
||||
try:
|
||||
provider_config_map[tracing_provider]
|
||||
except KeyError:
|
||||
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
|
||||
except KeyError:
|
||||
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
|
||||
|
||||
app_config: App | None = db.session.query(App).where(App.id == app_id).first()
|
||||
if not app_config:
|
||||
|
||||
@ -521,4 +521,4 @@ class TencentDataTrace(BaseTraceInstance):
|
||||
if hasattr(self, "trace_client"):
|
||||
self.trace_client.shutdown()
|
||||
except Exception:
|
||||
pass
|
||||
logger.exception("[Tencent APM] Failed to shutdown trace client during cleanup")
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
"""Document loader helpers."""
|
||||
|
||||
import concurrent.futures
|
||||
from typing import NamedTuple, cast
|
||||
from typing import NamedTuple
|
||||
|
||||
import charset_normalizer
|
||||
|
||||
|
||||
class FileEncoding(NamedTuple):
|
||||
@ -27,14 +29,14 @@ def detect_file_encodings(file_path: str, timeout: int = 5, sample_size: int = 1
|
||||
sample_size: The number of bytes to read for encoding detection. Default is 1MB.
|
||||
For large files, reading only a sample is sufficient and prevents timeout.
|
||||
"""
|
||||
import chardet
|
||||
|
||||
def read_and_detect(file_path: str):
|
||||
with open(file_path, "rb") as f:
|
||||
# Read only a sample of the file for encoding detection
|
||||
# This prevents timeout on large files while still providing accurate encoding detection
|
||||
rawdata = f.read(sample_size)
|
||||
return cast(list[dict], chardet.detect_all(rawdata))
|
||||
def read_and_detect(filename: str):
|
||||
rst = charset_normalizer.from_path(filename)
|
||||
best = rst.best()
|
||||
if best is None:
|
||||
return []
|
||||
file_encoding = FileEncoding(encoding=best.encoding, confidence=best.coherence, language=best.language)
|
||||
return [file_encoding]
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(read_and_detect, file_path)
|
||||
|
||||
@ -5,7 +5,7 @@ import time
|
||||
from collections.abc import Generator, Mapping
|
||||
from os import listdir, path
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import select
|
||||
@ -67,6 +67,11 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ApiProviderControllerItem(TypedDict):
|
||||
provider: ApiToolProvider
|
||||
controller: ApiToolProviderController
|
||||
|
||||
|
||||
class ToolManager:
|
||||
_builtin_provider_lock = Lock()
|
||||
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
|
||||
@ -655,9 +660,10 @@ class ToolManager:
|
||||
else:
|
||||
filters.append(typ)
|
||||
|
||||
with db.session.no_autoflush:
|
||||
# Use a single session for all database operations to reduce connection overhead
|
||||
with Session(db.engine) as session:
|
||||
if "builtin" in filters:
|
||||
builtin_providers = cls.list_builtin_providers(tenant_id)
|
||||
builtin_providers = list(cls.list_builtin_providers(tenant_id))
|
||||
|
||||
# key: provider name, value: provider
|
||||
db_builtin_providers = {
|
||||
@ -688,57 +694,74 @@ class ToolManager:
|
||||
|
||||
# get db api providers
|
||||
if "api" in filters:
|
||||
db_api_providers = db.session.scalars(
|
||||
db_api_providers = session.scalars(
|
||||
select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)
|
||||
).all()
|
||||
|
||||
api_provider_controllers: list[dict[str, Any]] = [
|
||||
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
|
||||
for provider in db_api_providers
|
||||
]
|
||||
# Batch create controllers
|
||||
api_provider_controllers: list[ApiProviderControllerItem] = []
|
||||
for api_provider in db_api_providers:
|
||||
try:
|
||||
controller = ToolTransformService.api_provider_to_controller(api_provider)
|
||||
api_provider_controllers.append({"provider": api_provider, "controller": controller})
|
||||
except Exception:
|
||||
# Skip invalid providers but continue processing others
|
||||
logger.warning("Failed to create controller for API provider %s", api_provider.id)
|
||||
|
||||
# get labels
|
||||
labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
|
||||
|
||||
for api_provider_controller in api_provider_controllers:
|
||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||
provider_controller=api_provider_controller["controller"],
|
||||
db_provider=api_provider_controller["provider"],
|
||||
decrypt_credentials=False,
|
||||
labels=labels.get(api_provider_controller["controller"].provider_id, []),
|
||||
# Batch get labels for all API providers
|
||||
if api_provider_controllers:
|
||||
controllers = cast(
|
||||
list[ToolProviderController], [item["controller"] for item in api_provider_controllers]
|
||||
)
|
||||
result_providers[f"api_provider.{user_provider.name}"] = user_provider
|
||||
labels = ToolLabelManager.get_tools_labels(controllers)
|
||||
|
||||
for item in api_provider_controllers:
|
||||
provider_controller = item["controller"]
|
||||
db_provider = item["provider"]
|
||||
provider_labels = labels.get(provider_controller.provider_id, [])
|
||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=db_provider,
|
||||
decrypt_credentials=False,
|
||||
labels=provider_labels,
|
||||
)
|
||||
result_providers[f"api_provider.{user_provider.name}"] = user_provider
|
||||
|
||||
if "workflow" in filters:
|
||||
# get workflow providers
|
||||
workflow_providers = db.session.scalars(
|
||||
workflow_providers = session.scalars(
|
||||
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
|
||||
).all()
|
||||
|
||||
workflow_provider_controllers: list[WorkflowToolProviderController] = []
|
||||
for workflow_provider in workflow_providers:
|
||||
try:
|
||||
workflow_provider_controllers.append(
|
||||
workflow_controller: WorkflowToolProviderController = (
|
||||
ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
|
||||
)
|
||||
workflow_provider_controllers.append(workflow_controller)
|
||||
except Exception:
|
||||
# app has been deleted
|
||||
pass
|
||||
logger.exception("Failed to transform workflow provider %s to controller", workflow_provider.id)
|
||||
continue
|
||||
# Batch get labels for workflow providers
|
||||
if workflow_provider_controllers:
|
||||
workflow_controllers: list[ToolProviderController] = [
|
||||
cast(ToolProviderController, controller) for controller in workflow_provider_controllers
|
||||
]
|
||||
labels = ToolLabelManager.get_tools_labels(workflow_controllers)
|
||||
|
||||
labels = ToolLabelManager.get_tools_labels(
|
||||
[cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
|
||||
)
|
||||
for workflow_provider_controller in workflow_provider_controllers:
|
||||
provider_labels = labels.get(workflow_provider_controller.provider_id, [])
|
||||
user_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=workflow_provider_controller,
|
||||
labels=provider_labels,
|
||||
)
|
||||
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
||||
|
||||
for provider_controller in workflow_provider_controllers:
|
||||
user_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
labels=labels.get(provider_controller.provider_id, []),
|
||||
)
|
||||
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
||||
if "mcp" in filters:
|
||||
with Session(db.engine) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True)
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True)
|
||||
for mcp_provider in mcp_providers:
|
||||
result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import chardet
|
||||
import charset_normalizer
|
||||
import cloudscraper
|
||||
from readabilipy import simple_json_from_html_string
|
||||
|
||||
@ -69,9 +69,12 @@ def get_url(url: str, user_agent: str | None = None) -> str:
|
||||
if response.status_code != 200:
|
||||
return f"URL returned status code {response.status_code}."
|
||||
|
||||
# Detect encoding using chardet
|
||||
detected_encoding = chardet.detect(response.content)
|
||||
encoding = detected_encoding["encoding"]
|
||||
# Detect encoding using charset_normalizer
|
||||
detected_encoding = charset_normalizer.from_bytes(response.content).best()
|
||||
if detected_encoding:
|
||||
encoding = detected_encoding.encoding
|
||||
else:
|
||||
encoding = "utf-8"
|
||||
if encoding:
|
||||
try:
|
||||
content = response.content.decode(encoding)
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
Unified event manager for collecting and emitting events.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
@ -12,6 +13,8 @@ from core.workflow.graph_events import GraphEngineEvent
|
||||
|
||||
from ..layers.base import GraphEngineLayer
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class ReadWriteLock:
|
||||
@ -180,5 +183,4 @@ class EventManager:
|
||||
try:
|
||||
layer.on_event(event)
|
||||
except Exception:
|
||||
# Silently ignore layer errors during collection
|
||||
pass
|
||||
_logger.exception("Error in layer on_event, layer_type=%s", type(layer))
|
||||
|
||||
@ -6,12 +6,15 @@ using the new Redis command channel, without requiring user permission checks.
|
||||
Supports stop, pause, and resume operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class GraphEngineManager:
|
||||
@ -57,4 +60,4 @@ class GraphEngineManager:
|
||||
except Exception:
|
||||
# Silently fail if Redis is unavailable
|
||||
# The legacy control mechanisms will still work
|
||||
pass
|
||||
logger.exception("Failed to send graph engine command %s for task %s", command.__class__.__name__, task_id)
|
||||
|
||||
@ -7,7 +7,7 @@ import tempfile
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
import chardet
|
||||
import charset_normalizer
|
||||
import docx
|
||||
import pandas as pd
|
||||
import pypandoc
|
||||
@ -228,9 +228,12 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
|
||||
|
||||
def _extract_text_from_plain_text(file_content: bytes) -> str:
|
||||
try:
|
||||
# Detect encoding using chardet
|
||||
result = chardet.detect(file_content)
|
||||
encoding = result["encoding"]
|
||||
# Detect encoding using charset_normalizer
|
||||
result = charset_normalizer.from_bytes(file_content, cp_isolation=["utf_8", "latin_1", "cp1252"]).best()
|
||||
if result:
|
||||
encoding = result.encoding
|
||||
else:
|
||||
encoding = "utf-8"
|
||||
|
||||
# Fallback to utf-8 if detection fails
|
||||
if not encoding:
|
||||
@ -247,9 +250,12 @@ def _extract_text_from_plain_text(file_content: bytes) -> str:
|
||||
|
||||
def _extract_text_from_json(file_content: bytes) -> str:
|
||||
try:
|
||||
# Detect encoding using chardet
|
||||
result = chardet.detect(file_content)
|
||||
encoding = result["encoding"]
|
||||
# Detect encoding using charset_normalizer
|
||||
result = charset_normalizer.from_bytes(file_content).best()
|
||||
if result:
|
||||
encoding = result.encoding
|
||||
else:
|
||||
encoding = "utf-8"
|
||||
|
||||
# Fallback to utf-8 if detection fails
|
||||
if not encoding:
|
||||
@ -269,9 +275,12 @@ def _extract_text_from_json(file_content: bytes) -> str:
|
||||
def _extract_text_from_yaml(file_content: bytes) -> str:
|
||||
"""Extract the content from yaml file"""
|
||||
try:
|
||||
# Detect encoding using chardet
|
||||
result = chardet.detect(file_content)
|
||||
encoding = result["encoding"]
|
||||
# Detect encoding using charset_normalizer
|
||||
result = charset_normalizer.from_bytes(file_content).best()
|
||||
if result:
|
||||
encoding = result.encoding
|
||||
else:
|
||||
encoding = "utf-8"
|
||||
|
||||
# Fallback to utf-8 if detection fails
|
||||
if not encoding:
|
||||
@ -424,9 +433,12 @@ def _extract_text_from_file(file: File):
|
||||
|
||||
def _extract_text_from_csv(file_content: bytes) -> str:
|
||||
try:
|
||||
# Detect encoding using chardet
|
||||
result = chardet.detect(file_content)
|
||||
encoding = result["encoding"]
|
||||
# Detect encoding using charset_normalizer
|
||||
result = charset_normalizer.from_bytes(file_content).best()
|
||||
if result:
|
||||
encoding = result.encoding
|
||||
else:
|
||||
encoding = "utf-8"
|
||||
|
||||
# Fallback to utf-8 if detection fails
|
||||
if not encoding:
|
||||
|
||||
@ -64,7 +64,10 @@ class DifyNodeFactory(NodeFactory):
|
||||
if not node_mapping:
|
||||
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||
|
||||
node_class = node_mapping.get(LATEST_VERSION)
|
||||
latest_node_class = node_mapping.get(LATEST_VERSION)
|
||||
node_version = str(node_data.get("version", "1"))
|
||||
matched_node_class = node_mapping.get(node_version)
|
||||
node_class = matched_node_class or latest_node_class
|
||||
if not node_class:
|
||||
raise ValueError(f"No latest version class found for node type: {node_type}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user