Mirror of https://github.com/langgenius/dify.git, synced 2026-03-12 10:38:54 +08:00
fix(telemetry): ensure CE safety for enterprise-only imports and DB lookups
- Move enqueue_draft_node_execution_trace import inside call site in workflow_service.py
- Make gateway.py enterprise type imports lazy (routing dicts built on first call)
- Restore typed ModelConfig in llm_generator method signatures (revert dict regression)
- Fix generate_structured_output using wrong key model_parameters -> completion_params
- Replace unsafe cast(str, msg.content) with get_text_content() across llm_generator
- Remove duplicated payload classes from generator.py, import from core.llm_generator.entities
- Gate _lookup_app_and_workspace_names and credential lookups in ops_trace_manager behind is_enterprise_telemetry_enabled()
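For context, the CE-safety pattern these changes share, as a minimal self-contained sketch. Everything below except is_enterprise_telemetry_enabled and the enterprise.telemetry.contracts import path is illustrative, not Dify's actual API: enterprise-only imports are deferred into function bodies behind a capability check, so a Community Edition process never imports the enterprise package at module load, and EE-only DB lookups are skipped outright.

```python
import importlib.util

_case_routing: dict | None = None  # built lazily, never at import time


def is_enterprise_telemetry_enabled() -> bool:
    # Assumption: package presence stands in for the real enablement check,
    # which may also consult configuration flags.
    return importlib.util.find_spec("enterprise") is not None


def _get_case_routing() -> dict:
    """Import enterprise types and build the routing dict on first call only."""
    global _case_routing
    if _case_routing is None:
        # Deferred import: CE installs without the `enterprise` package
        # never execute this line unless an EE-only path is reached.
        from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase

        _case_routing = {
            TelemetryCase.WORKFLOW_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
        }
    return _case_routing


def _lookup_app_and_workspace_names(app_id: str, tenant_id: str) -> tuple[str, str]:
    return "demo-app", "demo-workspace"  # stand-in for the real DB queries


def trace_metadata(app_id: str, tenant_id: str) -> dict:
    # EE-only lookups are gated, so CE skips the extra queries entirely.
    if is_enterprise_telemetry_enabled():
        app_name, workspace_name = _lookup_app_and_workspace_names(app_id, tenant_id)
    else:
        app_name, workspace_name = "", ""
    return {"app_name": app_name, "workspace_name": workspace_name}


if __name__ == "__main__":
    print(trace_metadata("app-1", "tenant-1"))
```

The memoized dict makes the lazy import a one-time cost; CE pays only a cheap find_spec check per call instead of a failing import at startup.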
@@ -1,5 +1,4 @@
 from collections.abc import Sequence
-from typing import Any

 from flask_restx import Resource
 from pydantic import BaseModel, Field
@@ -12,10 +11,12 @@ from controllers.console.app.error import (
     ProviderQuotaExceededError,
 )
 from controllers.console.wraps import account_initialization_required, setup_required
+from core.app.app_config.entities import ModelConfig
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.helper.code_executor.code_node_provider import CodeNodeProvider
 from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
 from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
+from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
 from core.llm_generator.llm_generator import LLMGenerator
 from core.model_runtime.errors.invoke import InvokeError
 from extensions.ext_database import db
@@ -26,30 +27,13 @@ from services.workflow_service import WorkflowService
 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


-class RuleGeneratePayload(BaseModel):
-    instruction: str = Field(..., description="Rule generation instruction")
-    model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
-    no_variable: bool = Field(default=False, description="Whether to exclude variables")
-    app_id: str | None = Field(default=None, description="App ID for prompt generation tracing")
-
-
-class RuleCodeGeneratePayload(RuleGeneratePayload):
-    code_language: str = Field(default="javascript", description="Programming language for code generation")
-
-
-class RuleStructuredOutputPayload(BaseModel):
-    instruction: str = Field(..., description="Structured output generation instruction")
-    model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
-    app_id: str | None = Field(default=None, description="App ID for prompt generation tracing")
-
-
 class InstructionGeneratePayload(BaseModel):
     flow_id: str = Field(..., description="Workflow/Flow ID")
     node_id: str = Field(default="", description="Node ID for workflow context")
     current: str = Field(default="", description="Current instruction text")
     language: str = Field(default="javascript", description="Programming language (javascript/python)")
     instruction: str = Field(..., description="Instruction for generation")
-    model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
+    model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
     ideal_output: str = Field(default="", description="Expected ideal output")
     app_id: str | None = Field(default=None, description="App ID for prompt generation tracing")
@@ -9,6 +9,7 @@ class RuleGeneratePayload(BaseModel):
     instruction: str = Field(..., description="Rule generation instruction")
     model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
     no_variable: bool = Field(default=False, description="Whether to exclude variables")
+    app_id: str | None = Field(default=None, description="App ID for prompt generation tracing")


 class RuleCodeGeneratePayload(RuleGeneratePayload):
@@ -18,3 +19,4 @@ class RuleCodeGeneratePayload(RuleGeneratePayload):
 class RuleStructuredOutputPayload(BaseModel):
     instruction: str = Field(..., description="Structured output generation instruction")
     model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
+    app_id: str | None = Field(default=None, description="App ID for prompt generation tracing")
@@ -6,6 +6,7 @@ from typing import Protocol, cast

 import json_repair

+from core.app.app_config.entities import ModelConfig
 from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
 from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
 from core.llm_generator.prompts import (
@@ -72,7 +73,7 @@ class LLMGenerator:
         response: LLMResult = model_instance.invoke_llm(
             prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
         )
-        answer = cast(str, response.message.content)
+        answer = response.message.get_text_content()
         if answer is None:
             return ""
         try:
@@ -158,7 +159,7 @@ class LLMGenerator:
         cls,
         tenant_id: str,
         instruction: str,
-        model_config: dict,
+        model_config: ModelConfig,
         no_variable: bool,
         user_id: str | None = None,
         app_id: str | None = None,
@@ -168,7 +169,7 @@ class LLMGenerator:
         error = ""
         error_step = ""
         rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""}
-        model_parameters = model_config.get("completion_params", {})
+        model_parameters = model_config.completion_params
         if no_variable:
             prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE)

@@ -186,8 +187,8 @@ class LLMGenerator:
         model_instance = model_manager.get_model_instance(
             tenant_id=tenant_id,
             model_type=ModelType.LLM,
-            provider=model_config.get("provider", ""),
-            model=model_config.get("name", ""),
+            provider=model_config.provider,
+            model=model_config.name,
         )

         llm_result = None
@@ -197,13 +198,13 @@ class LLMGenerator:
                 prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
             )

-            rule_config["prompt"] = cast(str, llm_result.message.content)
+            rule_config["prompt"] = llm_result.message.get_text_content() or ""

         except InvokeError as e:
             error = str(e)
             error_step = "generate rule config"
         except Exception as e:
-            logger.exception("Failed to generate rule config, model: %s", model_config.get("name"))
+            logger.exception("Failed to generate rule config, model: %s", model_config.name)
             rule_config["error"] = str(e)
             error = str(e)

@@ -250,8 +251,8 @@ class LLMGenerator:
         model_instance = model_manager.get_model_instance(
             tenant_id=tenant_id,
             model_type=ModelType.LLM,
-            provider=model_config.get("provider", ""),
-            model=model_config.get("name", ""),
+            provider=model_config.provider,
+            model=model_config.name,
         )

         llm_result = None
@@ -284,7 +285,7 @@ class LLMGenerator:

             return rule_config

-        rule_config["prompt"] = cast(str, prompt_content.message.content)
+        rule_config["prompt"] = prompt_content.message.get_text_content() or ""

         if not isinstance(prompt_content.message.content, str):
             raise NotImplementedError("prompt content is not a string")
@@ -311,7 +312,7 @@ class LLMGenerator:
                 prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
             )
             rule_config["variables"] = re.findall(
-                r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content)
+                r'"\s*([^"]+)\s*"', prompt_content.message.get_text_content() or ""
             )
         except InvokeError as e:
             error = str(e)
@@ -321,13 +322,13 @@ class LLMGenerator:
             statement_content: LLMResult = model_instance.invoke_llm(
                 prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
             )
-            rule_config["opening_statement"] = cast(str, statement_content.message.content)
+            rule_config["opening_statement"] = statement_content.message.get_text_content() or ""
         except InvokeError as e:
             error = str(e)
             error_step = "generate conversation opener"

     except Exception as e:
-        logger.exception("Failed to generate rule config, model: %s", model_config.get("name"))
+        logger.exception("Failed to generate rule config, model: %s", model_config.name)
         rule_config["error"] = str(e)
         error = str(e)

@@ -355,7 +356,7 @@ class LLMGenerator:
         cls,
         tenant_id: str,
         instruction: str,
-        model_config: dict,
+        model_config: ModelConfig,
         code_language: str = "javascript",
         user_id: str | None = None,
         app_id: str | None = None,
@@ -377,12 +378,12 @@ class LLMGenerator:
         model_instance = model_manager.get_model_instance(
             tenant_id=tenant_id,
             model_type=ModelType.LLM,
-            provider=model_config.get("provider", ""),
-            model=model_config.get("name", ""),
+            provider=model_config.provider,
+            model=model_config.name,
         )

         prompt_messages = [UserPromptMessage(content=prompt)]
-        model_parameters = model_config.get("completion_params", {})
+        model_parameters = model_config.completion_params

         llm_result = None
         error = None
@@ -392,7 +393,7 @@ class LLMGenerator:
                 prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
             )

-            generated_code = cast(str, llm_result.message.content)
+            generated_code = llm_result.message.get_text_content() or ""
             result = {"code": generated_code, "language": code_language, "error": ""}

         except InvokeError as e:
@@ -400,7 +401,7 @@ class LLMGenerator:
             result = {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"}
         except Exception as e:
             logger.exception(
-                "Failed to invoke LLM model, model: %s, language: %s", model_config.get("name"), code_language
+                "Failed to invoke LLM model, model: %s, language: %s", model_config.name, code_language
             )
             error = str(e)
             result = {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"}
@@ -445,26 +446,31 @@ class LLMGenerator:
             raise TypeError("Expected LLMResult when stream=False")
         response = result

-        answer = cast(str, response.message.content)
+        answer = response.message.get_text_content() or ""
         return answer.strip()

     @classmethod
     def generate_structured_output(
-        cls, tenant_id: str, instruction: str, model_config: dict, user_id: str | None = None, app_id: str | None = None
+        cls,
+        tenant_id: str,
+        instruction: str,
+        model_config: ModelConfig,
+        user_id: str | None = None,
+        app_id: str | None = None,
     ):
         model_manager = ModelManager()
         model_instance = model_manager.get_model_instance(
             tenant_id=tenant_id,
             model_type=ModelType.LLM,
-            provider=model_config.get("provider", ""),
-            model=model_config.get("name", ""),
+            provider=model_config.provider,
+            model=model_config.name,
         )

         prompt_messages = [
             SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE),
             UserPromptMessage(content=instruction),
         ]
-        model_parameters = model_config.get("model_parameters", {})
+        model_parameters = model_config.completion_params

         llm_result = None
         error = None
@@ -496,7 +502,7 @@ class LLMGenerator:
             error = str(e)
             result = {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
         except Exception as e:
-            logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
+            logger.exception("Failed to invoke LLM model, model: %s", model_config.name)
             error = str(e)
             result = {"output": "", "error": f"An unexpected error occurred: {str(e)}"}

@@ -522,7 +528,7 @@ class LLMGenerator:
         flow_id: str,
         current: str,
         instruction: str,
-        model_config: dict,
+        model_config: ModelConfig,
         ideal_output: str | None,
         user_id: str | None = None,
         app_id: str | None = None,
@@ -570,7 +576,7 @@ class LLMGenerator:
         node_id: str,
         current: str,
         instruction: str,
-        model_config: dict,
+        model_config: ModelConfig,
         ideal_output: str | None,
         workflow_service: WorkflowServiceInterface,
         user_id: str | None = None,
@@ -647,7 +653,7 @@ class LLMGenerator:
     @staticmethod
     def __instruction_modify_common(
         tenant_id: str,
-        model_config: dict,
+        model_config: ModelConfig,
         last_run: dict | None,
         current: str | None,
         error_message: str | None,
@@ -670,8 +676,8 @@ class LLMGenerator:
         model_instance = ModelManager().get_model_instance(
             tenant_id=tenant_id,
             model_type=ModelType.LLM,
-            provider=model_config.get("provider", ""),
-            model=model_config.get("name", ""),
+            provider=model_config.provider,
+            model=model_config.name,
         )
         match node_type:
             case "llm" | "agent":
@@ -719,9 +725,7 @@ class LLMGenerator:
             error = str(e)
             result = {"error": f"Failed to generate code. Error: {error}"}
         except Exception as e:
-            logger.exception(
-                "Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=True
-            )
+            logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.name), exc_info=True)
             error = str(e)
             result = {"error": f"An unexpected error occurred: {str(e)}"}

@@ -758,7 +762,7 @@ class LLMGenerator:
         instruction: str,
         generated_output: str,
         llm_result: LLMResult | None,
-        model_config: dict | None = None,
+        model_config: ModelConfig | None = None,
         timer=None,
         error: str | None = None,
     ):
@@ -768,8 +772,8 @@ class LLMGenerator:
             total_tokens = llm_result.usage.total_tokens
             model_name = llm_result.model
             # Extract provider from model_config if available, otherwise fall back to parsing model name
-            if model_config and model_config.get("provider"):
-                model_provider = model_config.get("provider", "")
+            if model_config and model_config.provider:
+                model_provider = model_config.provider
             else:
                 model_provider = model_name.split("/")[0] if "/" in model_name else ""
             latency = llm_result.usage.latency
@@ -779,8 +783,8 @@ class LLMGenerator:
             prompt_tokens = 0
             completion_tokens = 0
             total_tokens = 0
-            model_provider = model_config.get("provider", "") if model_config else ""
-            model_name = model_config.get("name", "") if model_config else ""
+            model_provider = model_config.provider if model_config else ""
+            model_name = model_config.name if model_config else ""
             latency = 0.0
         if timer:
             start_time = timer.get("start")
@@ -39,8 +39,8 @@ from extensions.ext_storage import storage
 from models.account import Tenant
 from models.dataset import Dataset
 from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
-from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
 from models.provider import Provider, ProviderCredential, ProviderModel, ProviderModelCredential, ProviderType
+from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
 from models.workflow import WorkflowAppLog
 from tasks.ops_trace_task import process_trace_tasks

@@ -94,13 +94,13 @@ def _lookup_llm_credential_info(
     """
     Lookup LLM credential ID and name for the given provider and model.
     Returns (credential_id, credential_name).
-
+
     Handles async timing issues gracefully - if credential is deleted between lookups,
     returns the ID but empty name rather than failing.
     """
     if not tenant_id or not provider:
         return None, ""
-
+
     try:
         with Session(db.engine) as session:
             # Try to find provider-level or model-level configuration
@@ -111,15 +111,15 @@ def _lookup_llm_credential_info(
                     Provider.provider_type == ProviderType.CUSTOM,
                 )
             )
-
+
             if not provider_record:
                 return None, ""
-
+
             # Check if there's a model-specific config
             credential_id = None
             credential_name = ""
             is_model_level = False
-
+
             if model and provider_record.credential_id:
                 # Try model-level first
                 model_record = session.scalar(
@@ -130,16 +130,16 @@ def _lookup_llm_credential_info(
                         ProviderModel.model_type == model_type,
                    )
                )
-
+
                if model_record and model_record.credential_id:
                    credential_id = model_record.credential_id
                    is_model_level = True
-
+
            if not credential_id and provider_record.credential_id:
                # Fall back to provider-level credential
                credential_id = provider_record.credential_id
                is_model_level = False
-
+
            # Lookup credential_name if we have credential_id
            if credential_id:
                try:
@@ -153,11 +153,9 @@ def _lookup_llm_credential_info(
                    else:
                        # Query ProviderCredential
                        cred_name = session.scalar(
-                            select(ProviderCredential.credential_name).where(
-                                ProviderCredential.id == credential_id
-                            )
+                            select(ProviderCredential.credential_name).where(ProviderCredential.id == credential_id)
                        )
-
+
                        if cred_name:
                            credential_name = str(cred_name)
                except Exception as e:
@@ -170,7 +168,7 @@ def _lookup_llm_credential_info(
                        model,
                        str(e),
                    )
-
+
            return credential_id, credential_name
    except Exception as e:
        # Database query failed or other unexpected error
@@ -788,7 +786,12 @@ class TraceTask:
            )
            message_id = session.scalar(message_data_stmt)

-        app_name, workspace_name = _lookup_app_and_workspace_names(workflow_run.app_id, tenant_id)
+        from core.telemetry.gateway import is_enterprise_telemetry_enabled
+
+        if is_enterprise_telemetry_enabled():
+            app_name, workspace_name = _lookup_app_and_workspace_names(workflow_run.app_id, tenant_id)
+        else:
+            app_name, workspace_name = "", ""

        metadata: dict[str, Any] = {
            "workflow_id": workflow_id,
@@ -867,7 +870,12 @@ class TraceTask:
            if tid:
                tenant_id = str(tid)

-        app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id)
+        from core.telemetry.gateway import is_enterprise_telemetry_enabled
+
+        if is_enterprise_telemetry_enabled():
+            app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id)
+        else:
+            app_name, workspace_name = "", ""

        metadata = {
            "conversation_id": message_data.conversation_id,
@@ -904,7 +912,9 @@ class TraceTask:
                outputs=message_data.answer,
                file_list=file_list,
                start_time=created_at,
-                end_time=message_data.updated_at if message_data.updated_at and message_data.updated_at > created_at else created_at + timedelta(seconds=message_data.provider_response_latency),
+                end_time=message_data.updated_at
+                if message_data.updated_at and message_data.updated_at > created_at
+                else created_at + timedelta(seconds=message_data.provider_response_latency),
                metadata=metadata,
                message_file_data=message_file_data,
                conversation_mode=conversation_mode,
@@ -1019,7 +1029,12 @@ class TraceTask:
            if tid:
                tenant_id = str(tid)

-        app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id)
+        from core.telemetry.gateway import is_enterprise_telemetry_enabled
+
+        if is_enterprise_telemetry_enabled():
+            app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id)
+        else:
+            app_name, workspace_name = "", ""

        doc_list = [doc.model_dump() for doc in documents] if documents else []
        dataset_ids: set[str] = set()
@@ -1269,24 +1284,32 @@ class TraceTask:
        if not node_data:
            return {}

-        app_name, workspace_name = _lookup_app_and_workspace_names(node_data.get("app_id"), node_data.get("tenant_id"))
+        from core.telemetry.gateway import is_enterprise_telemetry_enabled
+
+        if is_enterprise_telemetry_enabled():
+            app_name, workspace_name = _lookup_app_and_workspace_names(
+                node_data.get("app_id"), node_data.get("tenant_id")
+            )
+        else:
+            app_name, workspace_name = "", ""

        # Try tool credential lookup first
        credential_id = node_data.get("credential_id")
-        credential_name = _lookup_credential_name(credential_id, node_data.get("credential_provider_type"))
-
-        # If no credential_id found (e.g., LLM nodes), try LLM credential lookup
-        if not credential_id:
-            llm_cred_id, llm_cred_name = _lookup_llm_credential_info(
-                tenant_id=node_data.get("tenant_id"),
-                provider=node_data.get("model_provider"),
-                model=node_data.get("model_name"),
-                model_type="llm",
-            )
-            if llm_cred_id:
-                credential_id = llm_cred_id
-                credential_name = llm_cred_name
+        if is_enterprise_telemetry_enabled():
+            credential_name = _lookup_credential_name(credential_id, node_data.get("credential_provider_type"))
+            # If no credential_id found (e.g., LLM nodes), try LLM credential lookup
+            if not credential_id:
+                llm_cred_id, llm_cred_name = _lookup_llm_credential_info(
+                    tenant_id=node_data.get("tenant_id"),
+                    provider=node_data.get("model_provider"),
+                    model=node_data.get("model_name"),
+                    model_type="llm",
+                )
+                if llm_cred_id:
+                    credential_id = llm_cred_id
+                    credential_name = llm_cred_name
+        else:
+            credential_name = ""
        metadata: dict[str, Any] = {
            "tenant_id": node_data.get("tenant_id"),
            "app_id": node_data.get("app_id"),
@@ -18,11 +18,11 @@ import uuid
 from typing import TYPE_CHECKING, Any

 from core.ops.entities.trace_entity import TraceTaskName
-from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase, TelemetryEnvelope
 from extensions.ext_storage import storage

 if TYPE_CHECKING:
     from core.ops.ops_trace_manager import TraceQueueManager
+    from enterprise.telemetry.contracts import TelemetryCase

 logger = logging.getLogger(__name__)

@@ -32,40 +32,56 @@ PAYLOAD_SIZE_THRESHOLD_BYTES = 1 * 1024 * 1024
 # Routing table — authoritative mapping for all editions
 # ---------------------------------------------------------------------------

-CASE_TO_TRACE_TASK: dict[TelemetryCase, TraceTaskName] = {
-    TelemetryCase.WORKFLOW_RUN: TraceTaskName.WORKFLOW_TRACE,
-    TelemetryCase.MESSAGE_RUN: TraceTaskName.MESSAGE_TRACE,
-    TelemetryCase.NODE_EXECUTION: TraceTaskName.NODE_EXECUTION_TRACE,
-    TelemetryCase.DRAFT_NODE_EXECUTION: TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
-    TelemetryCase.PROMPT_GENERATION: TraceTaskName.PROMPT_GENERATION_TRACE,
-    TelemetryCase.TOOL_EXECUTION: TraceTaskName.TOOL_TRACE,
-    TelemetryCase.MODERATION_CHECK: TraceTaskName.MODERATION_TRACE,
-    TelemetryCase.SUGGESTED_QUESTION: TraceTaskName.SUGGESTED_QUESTION_TRACE,
-    TelemetryCase.DATASET_RETRIEVAL: TraceTaskName.DATASET_RETRIEVAL_TRACE,
-    TelemetryCase.GENERATE_NAME: TraceTaskName.GENERATE_NAME_TRACE,
-}
+_case_to_trace_task: dict | None = None
+_case_routing: dict | None = None

-TRACE_TASK_TO_CASE: dict[TraceTaskName, TelemetryCase] = {v: k for k, v in CASE_TO_TRACE_TASK.items()}
-
-CASE_ROUTING: dict[TelemetryCase, CaseRoute] = {
-    # TRACE — CE-eligible (flow in both CE and EE)
-    TelemetryCase.WORKFLOW_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
-    TelemetryCase.MESSAGE_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
-    TelemetryCase.TOOL_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
-    TelemetryCase.MODERATION_CHECK: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
-    TelemetryCase.SUGGESTED_QUESTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
-    TelemetryCase.DATASET_RETRIEVAL: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
-    TelemetryCase.GENERATE_NAME: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
-    # TRACE — enterprise-only
-    TelemetryCase.NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
-    TelemetryCase.DRAFT_NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
-    TelemetryCase.PROMPT_GENERATION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
-    # METRIC_LOG — enterprise-only (signal-driven, not trace)
-    TelemetryCase.APP_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
-    TelemetryCase.APP_UPDATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
-    TelemetryCase.APP_DELETED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
-    TelemetryCase.FEEDBACK_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
-}
+
+def _get_case_to_trace_task() -> dict:
+    global _case_to_trace_task
+    if _case_to_trace_task is None:
+        from enterprise.telemetry.contracts import TelemetryCase
+
+        _case_to_trace_task = {
+            TelemetryCase.WORKFLOW_RUN: TraceTaskName.WORKFLOW_TRACE,
+            TelemetryCase.MESSAGE_RUN: TraceTaskName.MESSAGE_TRACE,
+            TelemetryCase.NODE_EXECUTION: TraceTaskName.NODE_EXECUTION_TRACE,
+            TelemetryCase.DRAFT_NODE_EXECUTION: TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
+            TelemetryCase.PROMPT_GENERATION: TraceTaskName.PROMPT_GENERATION_TRACE,
+            TelemetryCase.TOOL_EXECUTION: TraceTaskName.TOOL_TRACE,
+            TelemetryCase.MODERATION_CHECK: TraceTaskName.MODERATION_TRACE,
+            TelemetryCase.SUGGESTED_QUESTION: TraceTaskName.SUGGESTED_QUESTION_TRACE,
+            TelemetryCase.DATASET_RETRIEVAL: TraceTaskName.DATASET_RETRIEVAL_TRACE,
+            TelemetryCase.GENERATE_NAME: TraceTaskName.GENERATE_NAME_TRACE,
+        }
+    return _case_to_trace_task
+
+
+def _get_case_routing() -> dict:
+    global _case_routing
+    if _case_routing is None:
+        from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase
+
+        _case_routing = {
+            # TRACE — CE-eligible (flow in both CE and EE)
+            TelemetryCase.WORKFLOW_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
+            TelemetryCase.MESSAGE_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
+            TelemetryCase.TOOL_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
+            TelemetryCase.MODERATION_CHECK: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
+            TelemetryCase.SUGGESTED_QUESTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
+            TelemetryCase.DATASET_RETRIEVAL: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
+            TelemetryCase.GENERATE_NAME: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
+            # TRACE — enterprise-only
+            TelemetryCase.NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
+            TelemetryCase.DRAFT_NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
+            TelemetryCase.PROMPT_GENERATION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
+            # METRIC_LOG — enterprise-only (signal-driven, not trace)
+            TelemetryCase.APP_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
+            TelemetryCase.APP_UPDATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
+            TelemetryCase.APP_DELETED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
+            TelemetryCase.FEEDBACK_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
+        }
+    return _case_routing


 # ---------------------------------------------------------------------------
 # Helpers
@@ -132,7 +148,7 @@ def emit(
     METRIC_LOG events are dispatched to the enterprise Celery queue;
     silently dropped when enterprise telemetry is unavailable.
     """
-    route = CASE_ROUTING.get(case)
+    route = _get_case_routing().get(case)
     if route is None:
         logger.warning("Unknown telemetry case: %s, dropping event", case)
         return
@@ -141,7 +157,7 @@
         logger.debug("Dropping EE-only event: case=%s (EE disabled)", case)
         return

-    if route.signal_type is SignalType.TRACE:
+    if route.signal_type == "trace":
         _emit_trace(case, context, payload, trace_manager)
     else:
         _emit_metric_log(case, context, payload)
@@ -156,7 +172,7 @@ def _emit_trace(
     from core.ops.ops_trace_manager import TraceQueueManager as LocalTraceQueueManager
     from core.ops.ops_trace_manager import TraceTask

-    trace_task_name = CASE_TO_TRACE_TASK.get(case)
+    trace_task_name = _get_case_to_trace_task().get(case)
     if trace_task_name is None:
         logger.warning("No TraceTaskName mapping for case: %s", case)
         return
@@ -189,6 +205,8 @@ def _emit_metric_log(

     payload_for_envelope, payload_ref = _handle_payload_sizing(payload, tenant_id, event_id)

+    from enterprise.telemetry.contracts import TelemetryEnvelope
+
     envelope = TelemetryEnvelope(
         case=case,
         tenant_id=tenant_id,
@@ -14,14 +14,13 @@ from opentelemetry.trace import Span
 from opentelemetry.trace.status import Status, StatusCode
 from pydantic import BaseModel

+from configs import dify_config
 from core.file.models import File
 from core.variables import Segment
 from core.workflow.enums import NodeType
 from core.workflow.graph_events import GraphNodeEventBase
 from core.workflow.nodes.base.node import Node
 from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes
-
-from configs import dify_config


 def should_include_content() -> bool:
@@ -34,6 +33,7 @@ def should_include_content() -> bool:
         return True
     return dify_config.ENTERPRISE_INCLUDE_CONTENT

+
 def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str:
     """
     Safely serialize objects to JSON, handling non-serializable types.
@@ -43,7 +43,11 @@ class ToolNodeOTelParser:
         # Tool call arguments and result — gated by content policy
         if should_include_content():
             if result_event and result_event.node_run_result and result_event.node_run_result.inputs:
-                span.set_attribute(ToolAttributes.TOOL_CALL_ARGUMENTS, safe_json_dumps(result_event.node_run_result.inputs))
+                span.set_attribute(
+                    ToolAttributes.TOOL_CALL_ARGUMENTS, safe_json_dumps(result_event.node_run_result.inputs)
+                )

             if result_event and result_event.node_run_result and result_event.node_run_result.outputs:
-                span.set_attribute(ToolAttributes.TOOL_CALL_RESULT, safe_json_dumps(result_event.node_run_result.outputs))
+                span.set_attribute(
+                    ToolAttributes.TOOL_CALL_RESULT, safe_json_dumps(result_event.node_run_result.outputs)
+                )
@@ -27,7 +27,6 @@ from core.workflow.nodes.start.entities import StartNodeData
 from core.workflow.runtime import VariablePool
 from core.workflow.system_variable import SystemVariable
 from core.workflow.workflow_entry import WorkflowEntry
-from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace
 from enums.cloud_plan import CloudPlan
 from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
 from extensions.ext_database import db
@@ -734,6 +733,8 @@ class WorkflowService:
         with Session(db.engine) as session:
             outputs = workflow_node_execution.load_full_outputs(session, storage)

+        from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace
+
         enqueue_draft_node_execution_trace(
             execution=workflow_node_execution,
             outputs=outputs,