mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 01:18:05 +08:00
Merge branch 'main' into feat/rag-2
This commit is contained in:
@ -90,23 +90,11 @@ class ChatMessageTextApi(Resource):
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
if (
|
||||
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict
|
||||
):
|
||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||
if text_to_speech is None:
|
||||
raise ValueError("TTS is not enabled")
|
||||
voice = args.get("voice") or text_to_speech.get("voice")
|
||||
else:
|
||||
try:
|
||||
if app_model.app_model_config is None:
|
||||
raise ValueError("AppModelConfig not found")
|
||||
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
except Exception:
|
||||
voice = None
|
||||
response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)
|
||||
voice = args.get("voice", None)
|
||||
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model, text=text, voice=voice, message_id=message_id, is_draft=True
|
||||
)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
|
||||
@ -18,7 +18,6 @@ from controllers.console.app.error import (
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import AppMode
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
@ -79,19 +78,9 @@ class ChatTextApi(InstalledAppResource):
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
if (
|
||||
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict
|
||||
):
|
||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||
voice = args.get("voice") or text_to_speech.get("voice")
|
||||
else:
|
||||
try:
|
||||
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
except Exception:
|
||||
voice = None
|
||||
response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)
|
||||
voice = args.get("voice", None)
|
||||
|
||||
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
|
||||
@ -87,7 +87,5 @@ class PluginUploadFileApi(Resource):
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return tool_file, 201
|
||||
|
||||
|
||||
api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin")
|
||||
|
||||
@ -20,7 +20,7 @@ from controllers.service_api.app.error import (
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.model import App, EndUser
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
@ -78,20 +78,9 @@ class TextApi(Resource):
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
if (
|
||||
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict
|
||||
):
|
||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {})
|
||||
voice = args.get("voice") or text_to_speech.get("voice")
|
||||
else:
|
||||
try:
|
||||
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
except Exception:
|
||||
voice = None
|
||||
voice = args.get("voice", None)
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text
|
||||
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@ -211,6 +211,9 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
if dataset.provider == "external":
|
||||
raise ValueError("External datasets are not supported.")
|
||||
|
||||
indexing_technique = args.get("indexing_technique") or dataset.indexing_technique
|
||||
if not indexing_technique:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
@ -301,6 +304,9 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
if dataset.provider == "external":
|
||||
raise ValueError("External datasets are not supported.")
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@ from controllers.web.error import (
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import App, AppMode
|
||||
from models.model import App
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
@ -77,21 +77,9 @@ class TextApi(WebApiResource):
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
if (
|
||||
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict
|
||||
):
|
||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {})
|
||||
voice = args.get("voice") or text_to_speech.get("voice")
|
||||
else:
|
||||
try:
|
||||
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
except Exception:
|
||||
voice = None
|
||||
|
||||
voice = args.get("voice", None)
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text
|
||||
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@ -19,6 +19,7 @@ from core.app.entities.task_entities import (
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.moderation.output_moderation import ModerationRule, OutputModeration
|
||||
from models.enums import MessageStatus
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -62,7 +63,7 @@ class BasedGenerateTaskPipeline:
|
||||
return err
|
||||
|
||||
err_desc = self._error_to_desc(err)
|
||||
message.status = "error"
|
||||
message.status = MessageStatus.ERROR
|
||||
message.error = err_desc
|
||||
return err
|
||||
|
||||
|
||||
@ -51,7 +51,7 @@ class File(BaseModel):
|
||||
# It should be set to `ToolFile.id` when `transfer_method` is `tool_file`.
|
||||
related_id: Optional[str] = None
|
||||
filename: Optional[str] = None
|
||||
extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
|
||||
extension: Optional[str] = Field(default=None, description="File extension, should contain dot")
|
||||
mime_type: Optional[str] = None
|
||||
size: int = -1
|
||||
|
||||
|
||||
@ -1,67 +0,0 @@
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from configs import dify_config
|
||||
from constants import IMAGE_EXTENSIONS
|
||||
from core.helper.url_signer import UrlSigner
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
|
||||
class UploadFileParser:
|
||||
@classmethod
|
||||
def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
|
||||
if not upload_file:
|
||||
return None
|
||||
|
||||
if upload_file.extension not in IMAGE_EXTENSIONS:
|
||||
return None
|
||||
|
||||
if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url:
|
||||
return cls.get_signed_temp_image_url(upload_file.id)
|
||||
else:
|
||||
# get image file base64
|
||||
try:
|
||||
data = storage.load(upload_file.key)
|
||||
except FileNotFoundError:
|
||||
logging.exception(f"File not found: {upload_file.key}")
|
||||
return None
|
||||
|
||||
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||
return f"data:{upload_file.mime_type};base64,{encoded_string}"
|
||||
|
||||
@classmethod
|
||||
def get_signed_temp_image_url(cls, upload_file_id) -> str:
|
||||
"""
|
||||
get signed url from upload file
|
||||
|
||||
:param upload_file_id: the id of UploadFile object
|
||||
:return:
|
||||
"""
|
||||
base_url = dify_config.FILES_URL
|
||||
image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
|
||||
|
||||
return UrlSigner.get_signed_url(url=image_preview_url, sign_key=upload_file_id, prefix="image-preview")
|
||||
|
||||
@classmethod
|
||||
def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
"""
|
||||
verify signature
|
||||
|
||||
:param upload_file_id: file id
|
||||
:param timestamp: timestamp
|
||||
:param nonce: nonce
|
||||
:param sign: signature
|
||||
:return:
|
||||
"""
|
||||
result = UrlSigner.verify(
|
||||
sign_key=upload_file_id, timestamp=timestamp, nonce=nonce, sign=sign, prefix="image-preview"
|
||||
)
|
||||
|
||||
# verify signature
|
||||
if not result:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
@ -28,7 +28,7 @@ class TemplateTransformer(ABC):
|
||||
def extract_result_str_from_response(cls, response: str):
|
||||
result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL)
|
||||
if not result:
|
||||
raise ValueError("Failed to parse result")
|
||||
raise ValueError(f"Failed to parse result: no result tag found in response. Response: {response[:200]}...")
|
||||
return result.group(1)
|
||||
|
||||
@classmethod
|
||||
@ -38,16 +38,53 @@ class TemplateTransformer(ABC):
|
||||
:param response: response
|
||||
:return:
|
||||
"""
|
||||
|
||||
try:
|
||||
result = json.loads(cls.extract_result_str_from_response(response))
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("failed to parse response")
|
||||
result_str = cls.extract_result_str_from_response(response)
|
||||
result = json.loads(result_str)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to parse JSON response: {str(e)}. Response content: {result_str[:200]}...")
|
||||
except ValueError as e:
|
||||
# Re-raise ValueError from extract_result_str_from_response
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise ValueError(f"Unexpected error during response transformation: {str(e)}")
|
||||
|
||||
# Check if the result contains an error
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
raise ValueError(f"JavaScript execution error: {result['error']}")
|
||||
|
||||
if not isinstance(result, dict):
|
||||
raise ValueError("result must be a dict")
|
||||
raise ValueError(f"Result must be a dict, got {type(result).__name__}")
|
||||
if not all(isinstance(k, str) for k in result):
|
||||
raise ValueError("result keys must be strings")
|
||||
raise ValueError("Result keys must be strings")
|
||||
|
||||
# Post-process the result to convert scientific notation strings back to numbers
|
||||
result = cls._post_process_result(result)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _post_process_result(cls, result: dict[Any, Any]) -> dict[Any, Any]:
|
||||
"""
|
||||
Post-process the result to convert scientific notation strings back to numbers
|
||||
"""
|
||||
|
||||
def convert_scientific_notation(value):
|
||||
if isinstance(value, str):
|
||||
# Check if the string looks like scientific notation
|
||||
if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
pass
|
||||
elif isinstance(value, dict):
|
||||
return {k: convert_scientific_notation(v) for k, v in value.items()}
|
||||
elif isinstance(value, list):
|
||||
return [convert_scientific_notation(v) for v in value]
|
||||
return value
|
||||
|
||||
return convert_scientific_notation(result) # type: ignore[no-any-return]
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_runner_script(cls) -> str:
|
||||
|
||||
@ -1,22 +0,0 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
|
||||
class LRUCache:
|
||||
def __init__(self, capacity: int):
|
||||
self.cache: OrderedDict[Any, Any] = OrderedDict()
|
||||
self.capacity = capacity
|
||||
|
||||
def get(self, key: Any) -> Any:
|
||||
if key not in self.cache:
|
||||
return None
|
||||
else:
|
||||
self.cache.move_to_end(key) # move the key to the end of the OrderedDict
|
||||
return self.cache[key]
|
||||
|
||||
def put(self, key: Any, value: Any) -> None:
|
||||
if key in self.cache:
|
||||
self.cache.move_to_end(key)
|
||||
self.cache[key] = value
|
||||
if len(self.cache) > self.capacity:
|
||||
self.cache.popitem(last=False) # pop the first item
|
||||
@ -317,9 +317,10 @@ class IndexingRunner:
|
||||
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
|
||||
for upload_file_id in image_upload_file_ids:
|
||||
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
|
||||
if image_file is None:
|
||||
continue
|
||||
try:
|
||||
if image_file:
|
||||
storage.delete(image_file.key)
|
||||
storage.delete(image_file.key)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"Delete image_files failed while indexing_estimate, \
|
||||
|
||||
@ -23,6 +23,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule
|
||||
|
||||
@ -170,10 +171,15 @@ def invoke_llm_with_structured_output(
|
||||
system_fingerprint: Optional[str] = None
|
||||
for event in llm_result:
|
||||
if isinstance(event, LLMResultChunk):
|
||||
prompt_messages = event.prompt_messages
|
||||
system_fingerprint = event.system_fingerprint
|
||||
|
||||
if isinstance(event.delta.message.content, str):
|
||||
result_text += event.delta.message.content
|
||||
prompt_messages = event.prompt_messages
|
||||
system_fingerprint = event.system_fingerprint
|
||||
elif isinstance(event.delta.message.content, list):
|
||||
for item in event.delta.message.content:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
result_text += item.data
|
||||
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
model=model_schema.model,
|
||||
|
||||
@ -53,6 +53,37 @@ class LLMUsage(ModelUsage):
|
||||
latency=0.0,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_metadata(cls, metadata: dict) -> "LLMUsage":
|
||||
"""
|
||||
Create LLMUsage instance from metadata dictionary with default values.
|
||||
|
||||
Args:
|
||||
metadata: Dictionary containing usage metadata
|
||||
|
||||
Returns:
|
||||
LLMUsage instance with values from metadata or defaults
|
||||
"""
|
||||
total_tokens = metadata.get("total_tokens", 0)
|
||||
completion_tokens = metadata.get("completion_tokens", 0)
|
||||
if total_tokens > 0 and completion_tokens == 0:
|
||||
completion_tokens = total_tokens
|
||||
|
||||
return cls(
|
||||
prompt_tokens=metadata.get("prompt_tokens", 0),
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))),
|
||||
completion_unit_price=Decimal(str(metadata.get("completion_unit_price", 0))),
|
||||
total_price=Decimal(str(metadata.get("total_price", 0))),
|
||||
currency=metadata.get("currency", "USD"),
|
||||
prompt_price_unit=Decimal(str(metadata.get("prompt_price_unit", 0))),
|
||||
completion_price_unit=Decimal(str(metadata.get("completion_price_unit", 0))),
|
||||
prompt_price=Decimal(str(metadata.get("prompt_price", 0))),
|
||||
completion_price=Decimal(str(metadata.get("completion_price", 0))),
|
||||
latency=metadata.get("latency", 0.0),
|
||||
)
|
||||
|
||||
def plus(self, other: "LLMUsage") -> "LLMUsage":
|
||||
"""
|
||||
Add two LLMUsage instances together.
|
||||
|
||||
0
api/core/ops/aliyun_trace/__init__.py
Normal file
0
api/core/ops/aliyun_trace/__init__.py
Normal file
487
api/core/ops/aliyun_trace/aliyun_trace.py
Normal file
487
api/core/ops/aliyun_trace/aliyun_trace.py
Normal file
@ -0,0 +1,487 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.ops.aliyun_trace.data_exporter.traceclient import (
|
||||
TraceClient,
|
||||
convert_datetime_to_nanoseconds,
|
||||
convert_to_span_id,
|
||||
convert_to_trace_id,
|
||||
generate_span_id,
|
||||
)
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
GEN_AI_COMPLETION,
|
||||
GEN_AI_FRAMEWORK,
|
||||
GEN_AI_MODEL_NAME,
|
||||
GEN_AI_PROMPT,
|
||||
GEN_AI_PROMPT_TEMPLATE_TEMPLATE,
|
||||
GEN_AI_PROMPT_TEMPLATE_VARIABLE,
|
||||
GEN_AI_RESPONSE_FINISH_REASON,
|
||||
GEN_AI_SESSION_ID,
|
||||
GEN_AI_SPAN_KIND,
|
||||
GEN_AI_SYSTEM,
|
||||
GEN_AI_USAGE_INPUT_TOKENS,
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS,
|
||||
GEN_AI_USAGE_TOTAL_TOKENS,
|
||||
GEN_AI_USER_ID,
|
||||
INPUT_VALUE,
|
||||
OUTPUT_VALUE,
|
||||
RETRIEVAL_DOCUMENT,
|
||||
RETRIEVAL_QUERY,
|
||||
TOOL_DESCRIPTION,
|
||||
TOOL_NAME,
|
||||
TOOL_PARAMETERS,
|
||||
GenAISpanKind,
|
||||
)
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import AliyunConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.rag.models.document import Document
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.nodes import NodeType
|
||||
from models import Account, App, EndUser, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom, db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AliyunDataTrace(BaseTraceInstance):
|
||||
def __init__(
|
||||
self,
|
||||
aliyun_config: AliyunConfig,
|
||||
):
|
||||
super().__init__(aliyun_config)
|
||||
base_url = aliyun_config.endpoint.rstrip("/")
|
||||
endpoint = urljoin(base_url, f"adapt_{aliyun_config.license_key}/api/otlp/traces")
|
||||
self.trace_client = TraceClient(service_name=aliyun_config.app_name, endpoint=endpoint)
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
if isinstance(trace_info, MessageTraceInfo):
|
||||
self.message_trace(trace_info)
|
||||
if isinstance(trace_info, ModerationTraceInfo):
|
||||
pass
|
||||
if isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||
self.suggested_question_trace(trace_info)
|
||||
if isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||
self.dataset_retrieval_trace(trace_info)
|
||||
if isinstance(trace_info, ToolTraceInfo):
|
||||
self.tool_trace(trace_info)
|
||||
if isinstance(trace_info, GenerateNameTraceInfo):
|
||||
pass
|
||||
|
||||
def api_check(self):
|
||||
return self.trace_client.api_check()
|
||||
|
||||
def get_project_url(self):
|
||||
try:
|
||||
return self.trace_client.get_project_url()
|
||||
except Exception as e:
|
||||
logger.info(f"Aliyun get run url failed: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"Aliyun get run url failed: {str(e)}")
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
trace_id = convert_to_trace_id(trace_info.workflow_run_id)
|
||||
workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow")
|
||||
self.add_workflow_span(trace_id, workflow_span_id, trace_info)
|
||||
|
||||
workflow_node_executions = self.get_workflow_node_executions(trace_info)
|
||||
for node_execution in workflow_node_executions:
|
||||
node_span = self.build_workflow_node_span(node_execution, trace_id, trace_info, workflow_span_id)
|
||||
self.trace_client.add_span(node_span)
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
message_data = trace_info.message_data
|
||||
if message_data is None:
|
||||
return
|
||||
message_id = trace_info.message_id
|
||||
|
||||
user_id = message_data.from_account_id
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: Optional[EndUser] = (
|
||||
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
user_id = end_user_data.session_id
|
||||
|
||||
status: Status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
|
||||
trace_id = convert_to_trace_id(message_id)
|
||||
message_span_id = convert_to_span_id(message_id, "message")
|
||||
message_span = SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=None,
|
||||
span_id=message_span_id,
|
||||
name="message",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: str(trace_info.outputs),
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
app_model_config = getattr(trace_info.message_data, "app_model_config", {})
|
||||
pre_prompt = getattr(app_model_config, "pre_prompt", "")
|
||||
inputs_data = getattr(trace_info.message_data, "inputs", {})
|
||||
llm_span = SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=message_span_id,
|
||||
span_id=convert_to_span_id(message_id, "llm"),
|
||||
name="llm",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name", ""),
|
||||
GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider", ""),
|
||||
GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens),
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens),
|
||||
GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens),
|
||||
GEN_AI_PROMPT_TEMPLATE_VARIABLE: json.dumps(inputs_data, ensure_ascii=False),
|
||||
GEN_AI_PROMPT_TEMPLATE_TEMPLATE: pre_prompt,
|
||||
GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
GEN_AI_COMPLETION: str(trace_info.outputs),
|
||||
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: str(trace_info.outputs),
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
self.trace_client.add_span(llm_span)
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
message_id = trace_info.message_id
|
||||
|
||||
documents_data = extract_retrieval_documents(trace_info.documents)
|
||||
dataset_retrieval_span = SpanData(
|
||||
trace_id=convert_to_trace_id(message_id),
|
||||
parent_span_id=convert_to_span_id(message_id, "message"),
|
||||
span_id=generate_span_id(),
|
||||
name="dataset_retrieval",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
RETRIEVAL_QUERY: str(trace_info.inputs),
|
||||
RETRIEVAL_DOCUMENT: json.dumps(documents_data, ensure_ascii=False),
|
||||
INPUT_VALUE: str(trace_info.inputs),
|
||||
OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False),
|
||||
},
|
||||
)
|
||||
self.trace_client.add_span(dataset_retrieval_span)
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
message_id = trace_info.message_id
|
||||
|
||||
status: Status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
|
||||
tool_span = SpanData(
|
||||
trace_id=convert_to_trace_id(message_id),
|
||||
parent_span_id=convert_to_span_id(message_id, "message"),
|
||||
span_id=generate_span_id(),
|
||||
name=trace_info.tool_name,
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
TOOL_NAME: trace_info.tool_name,
|
||||
TOOL_DESCRIPTION: json.dumps(trace_info.tool_config, ensure_ascii=False),
|
||||
TOOL_PARAMETERS: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
|
||||
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: str(trace_info.tool_outputs),
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
self.trace_client.add_span(tool_span)
|
||||
|
||||
def get_workflow_node_executions(self, trace_info: WorkflowTraceInfo) -> Sequence[WorkflowNodeExecution]:
|
||||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
current_tenant = (
|
||||
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
|
||||
)
|
||||
if not current_tenant:
|
||||
raise ValueError(f"Current tenant not found for account {service_account.id}")
|
||||
service_account.set_tenant_id(current_tenant.tenant_id)
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=trace_info.metadata.get("app_id"),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
# Get all executions for this workflow run
|
||||
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
|
||||
workflow_run_id=trace_info.workflow_run_id
|
||||
)
|
||||
return workflow_node_executions
|
||||
|
||||
def build_workflow_node_span(
|
||||
self, node_execution: WorkflowNodeExecution, trace_id: int, trace_info: WorkflowTraceInfo, workflow_span_id: int
|
||||
):
|
||||
try:
|
||||
if node_execution.node_type == NodeType.LLM:
|
||||
node_span = self.build_workflow_llm_span(trace_id, workflow_span_id, trace_info, node_execution)
|
||||
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
node_span = self.build_workflow_retrieval_span(trace_id, workflow_span_id, trace_info, node_execution)
|
||||
elif node_execution.node_type == NodeType.TOOL:
|
||||
node_span = self.build_workflow_tool_span(trace_id, workflow_span_id, trace_info, node_execution)
|
||||
else:
|
||||
node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution)
|
||||
return node_span
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status:
|
||||
span_status: Status = Status(StatusCode.UNSET)
|
||||
if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
span_status = Status(StatusCode.OK)
|
||||
elif node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]:
|
||||
span_status = Status(StatusCode.ERROR, str(node_execution.error))
|
||||
return span_status
|
||||
|
||||
def build_workflow_task_span(
|
||||
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
) -> SpanData:
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
span_id=convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
|
||||
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
|
||||
},
|
||||
status=self.get_workflow_node_status(node_execution),
|
||||
)
|
||||
|
||||
def build_workflow_tool_span(
|
||||
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
) -> SpanData:
|
||||
tool_des = {}
|
||||
if node_execution.metadata:
|
||||
tool_des = node_execution.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
span_id=convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
|
||||
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
TOOL_NAME: node_execution.title,
|
||||
TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False),
|
||||
TOOL_PARAMETERS: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
|
||||
INPUT_VALUE: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
|
||||
},
|
||||
status=self.get_workflow_node_status(node_execution),
|
||||
)
|
||||
|
||||
def build_workflow_retrieval_span(
|
||||
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
) -> SpanData:
|
||||
input_value = ""
|
||||
if node_execution.inputs:
|
||||
input_value = str(node_execution.inputs.get("query", ""))
|
||||
output_value = ""
|
||||
if node_execution.outputs:
|
||||
output_value = json.dumps(node_execution.outputs.get("result", []), ensure_ascii=False)
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
span_id=convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
|
||||
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
RETRIEVAL_QUERY: input_value,
|
||||
RETRIEVAL_DOCUMENT: output_value,
|
||||
INPUT_VALUE: input_value,
|
||||
OUTPUT_VALUE: output_value,
|
||||
},
|
||||
status=self.get_workflow_node_status(node_execution),
|
||||
)
|
||||
|
||||
def build_workflow_llm_span(
|
||||
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
) -> SpanData:
|
||||
process_data = node_execution.process_data or {}
|
||||
outputs = node_execution.outputs or {}
|
||||
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
span_id=convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
|
||||
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
|
||||
GEN_AI_SYSTEM: process_data.get("model_provider", ""),
|
||||
GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
|
||||
GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
|
||||
GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
|
||||
GEN_AI_COMPLETION: str(outputs.get("text", "")),
|
||||
GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""),
|
||||
INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
|
||||
OUTPUT_VALUE: str(outputs.get("text", "")),
|
||||
},
|
||||
status=self.get_workflow_node_status(node_execution),
|
||||
)
|
||||
|
||||
def add_workflow_span(self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo):
|
||||
message_span_id = None
|
||||
if trace_info.message_id:
|
||||
message_span_id = convert_to_span_id(trace_info.message_id, "message")
|
||||
user_id = trace_info.metadata.get("user_id")
|
||||
status: Status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
if message_span_id: # chatflow
|
||||
message_span = SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=None,
|
||||
span_id=message_span_id,
|
||||
name="message",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
INPUT_VALUE: trace_info.workflow_run_inputs.get("sys.query", ""),
|
||||
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
workflow_span = SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=message_span_id,
|
||||
span_id=workflow_span_id,
|
||||
name="workflow",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
self.trace_client.add_span(workflow_span)
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||
message_id = trace_info.message_id
|
||||
status: Status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
suggested_question_span = SpanData(
|
||||
trace_id=convert_to_trace_id(message_id),
|
||||
parent_span_id=convert_to_span_id(message_id, "message"),
|
||||
span_id=convert_to_span_id(message_id, "suggested_question"),
|
||||
name="suggested_question",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name", ""),
|
||||
GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider", ""),
|
||||
GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
GEN_AI_COMPLETION: json.dumps(trace_info.suggested_question, ensure_ascii=False),
|
||||
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
self.trace_client.add_span(suggested_question_span)
|
||||
|
||||
|
||||
def extract_retrieval_documents(documents: list[Document]):
|
||||
documents_data = []
|
||||
for document in documents:
|
||||
document_data = {
|
||||
"content": document.page_content,
|
||||
"metadata": {
|
||||
"dataset_id": document.metadata.get("dataset_id"),
|
||||
"doc_id": document.metadata.get("doc_id"),
|
||||
"document_id": document.metadata.get("document_id"),
|
||||
},
|
||||
"score": document.metadata.get("score"),
|
||||
}
|
||||
documents_data.append(document_data)
|
||||
return documents_data
|
||||
0
api/core/ops/aliyun_trace/data_exporter/__init__.py
Normal file
0
api/core/ops/aliyun_trace/data_exporter/__init__.py
Normal file
200
api/core/ops/aliyun_trace/data_exporter/traceclient.py
Normal file
200
api/core/ops/aliyun_trace/data_exporter/traceclient.py
Normal file
@ -0,0 +1,200 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import random
|
||||
import socket
|
||||
import threading
|
||||
import uuid
|
||||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.sdk.util.instrumentation import InstrumentationScope
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
|
||||
from configs import dify_config
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
|
||||
|
||||
INVALID_SPAN_ID = 0x0000000000000000
|
||||
INVALID_TRACE_ID = 0x00000000000000000000000000000000
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TraceClient:
|
||||
def __init__(
|
||||
self,
|
||||
service_name: str,
|
||||
endpoint: str,
|
||||
max_queue_size: int = 1000,
|
||||
schedule_delay_sec: int = 5,
|
||||
max_export_batch_size: int = 50,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.resource = Resource(
|
||||
attributes={
|
||||
ResourceAttributes.SERVICE_NAME: service_name,
|
||||
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
|
||||
ResourceAttributes.HOST_NAME: socket.gethostname(),
|
||||
}
|
||||
)
|
||||
self.span_builder = SpanBuilder(self.resource)
|
||||
self.exporter = OTLPSpanExporter(endpoint=endpoint)
|
||||
|
||||
self.max_queue_size = max_queue_size
|
||||
self.schedule_delay_sec = schedule_delay_sec
|
||||
self.max_export_batch_size = max_export_batch_size
|
||||
|
||||
self.queue: deque = deque(maxlen=max_queue_size)
|
||||
self.condition = threading.Condition(threading.Lock())
|
||||
self.done = False
|
||||
|
||||
self.worker_thread = threading.Thread(target=self._worker, daemon=True)
|
||||
self.worker_thread.start()
|
||||
|
||||
self._spans_dropped = False
|
||||
|
||||
def export(self, spans: Sequence[ReadableSpan]):
|
||||
self.exporter.export(spans)
|
||||
|
||||
def api_check(self):
|
||||
try:
|
||||
response = requests.head(self.endpoint, timeout=5)
|
||||
if response.status_code == 405:
|
||||
return True
|
||||
else:
|
||||
logger.debug(f"AliyunTrace API check failed: Unexpected status code: {response.status_code}")
|
||||
return False
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.debug(f"AliyunTrace API check failed: {str(e)}")
|
||||
raise ValueError(f"AliyunTrace API check failed: {str(e)}")
|
||||
|
||||
def get_project_url(self):
|
||||
return "https://arms.console.aliyun.com/#/llm"
|
||||
|
||||
def add_span(self, span_data: SpanData):
|
||||
if span_data is None:
|
||||
return
|
||||
span: ReadableSpan = self.span_builder.build_span(span_data)
|
||||
with self.condition:
|
||||
if len(self.queue) == self.max_queue_size:
|
||||
if not self._spans_dropped:
|
||||
logger.warning("Queue is full, likely spans will be dropped.")
|
||||
self._spans_dropped = True
|
||||
|
||||
self.queue.appendleft(span)
|
||||
if len(self.queue) >= self.max_export_batch_size:
|
||||
self.condition.notify()
|
||||
|
||||
def _worker(self):
|
||||
while not self.done:
|
||||
with self.condition:
|
||||
if len(self.queue) < self.max_export_batch_size and not self.done:
|
||||
self.condition.wait(timeout=self.schedule_delay_sec)
|
||||
self._export_batch()
|
||||
|
||||
def _export_batch(self):
|
||||
spans_to_export: list[ReadableSpan] = []
|
||||
with self.condition:
|
||||
while len(spans_to_export) < self.max_export_batch_size and self.queue:
|
||||
spans_to_export.append(self.queue.pop())
|
||||
|
||||
if spans_to_export:
|
||||
try:
|
||||
self.exporter.export(spans_to_export)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error exporting spans: {e}")
|
||||
|
||||
def shutdown(self):
|
||||
with self.condition:
|
||||
self.done = True
|
||||
self.condition.notify_all()
|
||||
self.worker_thread.join()
|
||||
self._export_batch()
|
||||
self.exporter.shutdown()
|
||||
|
||||
|
||||
class SpanBuilder:
|
||||
def __init__(self, resource):
|
||||
self.resource = resource
|
||||
self.instrumentation_scope = InstrumentationScope(
|
||||
__name__,
|
||||
"",
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
def build_span(self, span_data: SpanData) -> ReadableSpan:
|
||||
span_context = trace_api.SpanContext(
|
||||
trace_id=span_data.trace_id,
|
||||
span_id=span_data.span_id,
|
||||
is_remote=False,
|
||||
trace_flags=trace_api.TraceFlags(trace_api.TraceFlags.SAMPLED),
|
||||
trace_state=None,
|
||||
)
|
||||
|
||||
parent_span_context = None
|
||||
if span_data.parent_span_id is not None:
|
||||
parent_span_context = trace_api.SpanContext(
|
||||
trace_id=span_data.trace_id,
|
||||
span_id=span_data.parent_span_id,
|
||||
is_remote=False,
|
||||
trace_flags=trace_api.TraceFlags(trace_api.TraceFlags.SAMPLED),
|
||||
trace_state=None,
|
||||
)
|
||||
|
||||
span = ReadableSpan(
|
||||
name=span_data.name,
|
||||
context=span_context,
|
||||
parent=parent_span_context,
|
||||
resource=self.resource,
|
||||
attributes=span_data.attributes,
|
||||
events=span_data.events,
|
||||
links=span_data.links,
|
||||
kind=trace_api.SpanKind.INTERNAL,
|
||||
status=span_data.status,
|
||||
start_time=span_data.start_time,
|
||||
end_time=span_data.end_time,
|
||||
instrumentation_scope=self.instrumentation_scope,
|
||||
)
|
||||
return span
|
||||
|
||||
|
||||
def generate_span_id() -> int:
|
||||
span_id = random.getrandbits(64)
|
||||
while span_id == INVALID_SPAN_ID:
|
||||
span_id = random.getrandbits(64)
|
||||
return span_id
|
||||
|
||||
|
||||
def convert_to_trace_id(uuid_v4: Optional[str]) -> int:
|
||||
try:
|
||||
uuid_obj = uuid.UUID(uuid_v4)
|
||||
return uuid_obj.int
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid UUID input: {e}")
|
||||
|
||||
|
||||
def convert_to_span_id(uuid_v4: Optional[str], span_type: str) -> int:
|
||||
try:
|
||||
uuid_obj = uuid.UUID(uuid_v4)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid UUID input: {e}")
|
||||
combined_key = f"{uuid_obj.hex}-{span_type}"
|
||||
hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest()
|
||||
span_id = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
|
||||
return span_id
|
||||
|
||||
|
||||
def convert_datetime_to_nanoseconds(start_time_a: Optional[datetime]) -> Optional[int]:
|
||||
if start_time_a is None:
|
||||
return None
|
||||
timestamp_in_seconds = start_time_a.timestamp()
|
||||
timestamp_in_nanoseconds = int(timestamp_in_seconds * 1e9)
|
||||
return timestamp_in_nanoseconds
|
||||
0
api/core/ops/aliyun_trace/entities/__init__.py
Normal file
0
api/core/ops/aliyun_trace/entities/__init__.py
Normal file
21
api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py
Normal file
21
api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py
Normal file
@ -0,0 +1,21 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.sdk.trace import Event, Status, StatusCode
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SpanData(BaseModel):
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
trace_id: int = Field(..., description="The unique identifier for the trace.")
|
||||
parent_span_id: Optional[int] = Field(None, description="The ID of the parent span, if any.")
|
||||
span_id: int = Field(..., description="The unique identifier for this span.")
|
||||
name: str = Field(..., description="The name of the span.")
|
||||
attributes: dict[str, str] = Field(default_factory=dict, description="Attributes associated with the span.")
|
||||
events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.")
|
||||
links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.")
|
||||
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
|
||||
start_time: Optional[int] = Field(..., description="The start time of the span in nanoseconds.")
|
||||
end_time: Optional[int] = Field(..., description="The end time of the span in nanoseconds.")
|
||||
64
api/core/ops/aliyun_trace/entities/semconv.py
Normal file
64
api/core/ops/aliyun_trace/entities/semconv.py
Normal file
@ -0,0 +1,64 @@
|
||||
from enum import Enum
|
||||
|
||||
# public
|
||||
GEN_AI_SESSION_ID = "gen_ai.session.id"
|
||||
|
||||
GEN_AI_USER_ID = "gen_ai.user.id"
|
||||
|
||||
GEN_AI_USER_NAME = "gen_ai.user.name"
|
||||
|
||||
GEN_AI_SPAN_KIND = "gen_ai.span.kind"
|
||||
|
||||
GEN_AI_FRAMEWORK = "gen_ai.framework"
|
||||
|
||||
|
||||
# Chain
|
||||
INPUT_VALUE = "input.value"
|
||||
|
||||
OUTPUT_VALUE = "output.value"
|
||||
|
||||
|
||||
# Retriever
|
||||
RETRIEVAL_QUERY = "retrieval.query"
|
||||
|
||||
RETRIEVAL_DOCUMENT = "retrieval.document"
|
||||
|
||||
|
||||
# LLM
|
||||
GEN_AI_MODEL_NAME = "gen_ai.model_name"
|
||||
|
||||
GEN_AI_SYSTEM = "gen_ai.system"
|
||||
|
||||
GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
|
||||
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
|
||||
|
||||
GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
|
||||
|
||||
GEN_AI_PROMPT_TEMPLATE_TEMPLATE = "gen_ai.prompt_template.template"
|
||||
|
||||
GEN_AI_PROMPT_TEMPLATE_VARIABLE = "gen_ai.prompt_template.variable"
|
||||
|
||||
GEN_AI_PROMPT = "gen_ai.prompt"
|
||||
|
||||
GEN_AI_COMPLETION = "gen_ai.completion"
|
||||
|
||||
GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
|
||||
|
||||
# Tool
|
||||
TOOL_NAME = "tool.name"
|
||||
|
||||
TOOL_DESCRIPTION = "tool.description"
|
||||
|
||||
TOOL_PARAMETERS = "tool.parameters"
|
||||
|
||||
|
||||
class GenAISpanKind(Enum):
|
||||
CHAIN = "CHAIN"
|
||||
RETRIEVER = "RETRIEVER"
|
||||
RERANKER = "RERANKER"
|
||||
LLM = "LLM"
|
||||
EMBEDDING = "EMBEDDING"
|
||||
TOOL = "TOOL"
|
||||
AGENT = "AGENT"
|
||||
TASK = "TASK"
|
||||
0
api/core/ops/arize_phoenix_trace/__init__.py
Normal file
0
api/core/ops/arize_phoenix_trace/__init__.py
Normal file
726
api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
Normal file
726
api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
Normal file
@ -0,0 +1,726 @@
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
|
||||
from opentelemetry.sdk import trace as trace_sdk
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
||||
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
|
||||
from opentelemetry.trace import SpanContext, TraceFlags, TraceState
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser, MessageFile
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[trace_sdk.Tracer, SimpleSpanProcessor]:
|
||||
"""Configure OpenTelemetry tracer with OTLP exporter for Arize/Phoenix."""
|
||||
try:
|
||||
# Choose the appropriate exporter based on config type
|
||||
exporter: Union[GrpcOTLPSpanExporter, HttpOTLPSpanExporter]
|
||||
if isinstance(arize_phoenix_config, ArizeConfig):
|
||||
arize_endpoint = f"{arize_phoenix_config.endpoint}/v1"
|
||||
arize_headers = {
|
||||
"api_key": arize_phoenix_config.api_key or "",
|
||||
"space_id": arize_phoenix_config.space_id or "",
|
||||
"authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
|
||||
}
|
||||
exporter = GrpcOTLPSpanExporter(
|
||||
endpoint=arize_endpoint,
|
||||
headers=arize_headers,
|
||||
timeout=30,
|
||||
)
|
||||
else:
|
||||
phoenix_endpoint = f"{arize_phoenix_config.endpoint}/v1/traces"
|
||||
phoenix_headers = {
|
||||
"api_key": arize_phoenix_config.api_key or "",
|
||||
"authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
|
||||
}
|
||||
exporter = HttpOTLPSpanExporter(
|
||||
endpoint=phoenix_endpoint,
|
||||
headers=phoenix_headers,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
attributes = {
|
||||
"openinference.project.name": arize_phoenix_config.project or "",
|
||||
"model_id": arize_phoenix_config.project or "",
|
||||
}
|
||||
resource = Resource(attributes=attributes)
|
||||
provider = trace_sdk.TracerProvider(resource=resource)
|
||||
processor = SimpleSpanProcessor(
|
||||
exporter,
|
||||
)
|
||||
provider.add_span_processor(processor)
|
||||
|
||||
# Create a named tracer instead of setting the global provider
|
||||
tracer_name = f"arize_phoenix_tracer_{arize_phoenix_config.project}"
|
||||
logger.info(f"[Arize/Phoenix] Created tracer with name: {tracer_name}")
|
||||
return cast(trace_sdk.Tracer, provider.get_tracer(tracer_name)), processor
|
||||
except Exception as e:
|
||||
logger.error(f"[Arize/Phoenix] Failed to setup the tracer: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def datetime_to_nanos(dt: Optional[datetime]) -> int:
|
||||
"""Convert datetime to nanoseconds since epoch. If None, use current time."""
|
||||
if dt is None:
|
||||
dt = datetime.now()
|
||||
return int(dt.timestamp() * 1_000_000_000)
|
||||
|
||||
|
||||
def uuid_to_trace_id(string: Optional[str]) -> int:
|
||||
"""Convert UUID string to a valid trace ID (16-byte integer)."""
|
||||
if string is None:
|
||||
string = ""
|
||||
hash_object = hashlib.sha256(string.encode())
|
||||
|
||||
# Take the first 16 bytes (128 bits) of the hash
|
||||
digest = hash_object.digest()[:16]
|
||||
|
||||
# Convert to integer (128 bits)
|
||||
return int.from_bytes(digest, byteorder="big")
|
||||
|
||||
|
||||
class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
def __init__(
|
||||
self,
|
||||
arize_phoenix_config: ArizeConfig | PhoenixConfig,
|
||||
):
|
||||
super().__init__(arize_phoenix_config)
|
||||
import logging
|
||||
|
||||
logging.basicConfig()
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
self.arize_phoenix_config = arize_phoenix_config
|
||||
self.tracer, self.processor = setup_tracer(arize_phoenix_config)
|
||||
self.project = arize_phoenix_config.project
|
||||
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
logger.info(f"[Arize/Phoenix] Trace: {trace_info}")
|
||||
try:
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
if isinstance(trace_info, MessageTraceInfo):
|
||||
self.message_trace(trace_info)
|
||||
if isinstance(trace_info, ModerationTraceInfo):
|
||||
self.moderation_trace(trace_info)
|
||||
if isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||
self.suggested_question_trace(trace_info)
|
||||
if isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||
self.dataset_retrieval_trace(trace_info)
|
||||
if isinstance(trace_info, ToolTraceInfo):
|
||||
self.tool_trace(trace_info)
|
||||
if isinstance(trace_info, GenerateNameTraceInfo):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Arize/Phoenix] Error in the trace: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
workflow_metadata = {
|
||||
"workflow_id": trace_info.workflow_run_id or "",
|
||||
"message_id": trace_info.message_id or "",
|
||||
"workflow_app_log_id": trace_info.workflow_app_log_id or "",
|
||||
"status": trace_info.workflow_run_status or "",
|
||||
"status_message": trace_info.error or "",
|
||||
"level": "ERROR" if trace_info.error else "DEFAULT",
|
||||
"total_tokens": trace_info.total_tokens or 0,
|
||||
}
|
||||
workflow_metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
workflow_span = self.tracer.start_span(
|
||||
name=TraceTaskName.WORKFLOW_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(workflow_metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
# Process workflow nodes
|
||||
for node_execution in self._get_workflow_nodes(trace_info.workflow_run_id):
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
|
||||
|
||||
node_metadata = {
|
||||
"node_id": node_execution.id,
|
||||
"node_type": node_execution.node_type,
|
||||
"node_status": node_execution.status,
|
||||
"tenant_id": node_execution.tenant_id,
|
||||
"app_id": node_execution.app_id,
|
||||
"app_name": node_execution.title,
|
||||
"status": node_execution.status,
|
||||
"level": "ERROR" if node_execution.status != "succeeded" else "DEFAULT",
|
||||
}
|
||||
|
||||
if node_execution.execution_metadata:
|
||||
node_metadata.update(json.loads(node_execution.execution_metadata))
|
||||
|
||||
# Determine the correct span kind based on node type
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN.value
|
||||
if node_execution.node_type == "llm":
|
||||
span_kind = OpenInferenceSpanKindValues.LLM.value
|
||||
provider = process_data.get("model_provider")
|
||||
model = process_data.get("model_name")
|
||||
if provider:
|
||||
node_metadata["ls_provider"] = provider
|
||||
if model:
|
||||
node_metadata["ls_model_name"] = model
|
||||
|
||||
outputs = json.loads(node_execution.outputs).get("usage", {})
|
||||
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
if usage_data:
|
||||
node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
|
||||
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
|
||||
node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)
|
||||
elif node_execution.node_type == "dataset_retrieval":
|
||||
span_kind = OpenInferenceSpanKindValues.RETRIEVER.value
|
||||
elif node_execution.node_type == "tool":
|
||||
span_kind = OpenInferenceSpanKindValues.TOOL.value
|
||||
else:
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN.value
|
||||
|
||||
node_span = self.tracer.start_span(
|
||||
name=node_execution.node_type,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}",
|
||||
SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}",
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind,
|
||||
SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
start_time=datetime_to_nanos(created_at),
|
||||
)
|
||||
|
||||
try:
|
||||
if node_execution.node_type == "llm":
|
||||
provider = process_data.get("model_provider")
|
||||
model = process_data.get("model_name")
|
||||
if provider:
|
||||
node_span.set_attribute(SpanAttributes.LLM_PROVIDER, provider)
|
||||
if model:
|
||||
node_span.set_attribute(SpanAttributes.LLM_MODEL_NAME, model)
|
||||
|
||||
outputs = json.loads(node_execution.outputs).get("usage", {})
|
||||
usage_data = (
|
||||
process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
)
|
||||
if usage_data:
|
||||
node_span.set_attribute(
|
||||
SpanAttributes.LLM_TOKEN_COUNT_TOTAL, usage_data.get("total_tokens", 0)
|
||||
)
|
||||
node_span.set_attribute(
|
||||
SpanAttributes.LLM_TOKEN_COUNT_PROMPT, usage_data.get("prompt_tokens", 0)
|
||||
)
|
||||
node_span.set_attribute(
|
||||
SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, usage_data.get("completion_tokens", 0)
|
||||
)
|
||||
finally:
|
||||
node_span.end(end_time=datetime_to_nanos(finished_at))
|
||||
finally:
|
||||
workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
file_list = cast(list[str], trace_info.file_list) or []
|
||||
message_file_data: Optional[MessageFile] = trace_info.message_file_data
|
||||
|
||||
if message_file_data is not None:
|
||||
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
||||
file_list.append(file_url)
|
||||
|
||||
message_metadata = {
|
||||
"message_id": trace_info.message_id or "",
|
||||
"conversation_mode": str(trace_info.conversation_mode or ""),
|
||||
"user_id": trace_info.message_data.from_account_id or "",
|
||||
"file_list": json.dumps(file_list),
|
||||
"status": trace_info.message_data.status or "",
|
||||
"status_message": trace_info.error or "",
|
||||
"level": "ERROR" if trace_info.error else "DEFAULT",
|
||||
"total_tokens": trace_info.total_tokens or 0,
|
||||
"prompt_tokens": trace_info.message_tokens or 0,
|
||||
"completion_tokens": trace_info.answer_tokens or 0,
|
||||
"ls_provider": trace_info.message_data.model_provider or "",
|
||||
"ls_model_name": trace_info.message_data.model_id or "",
|
||||
}
|
||||
message_metadata.update(trace_info.metadata)
|
||||
|
||||
# Add end user data if available
|
||||
if trace_info.message_data.from_end_user_id:
|
||||
end_user_data: Optional[EndUser] = (
|
||||
db.session.query(EndUser).filter(EndUser.id == trace_info.message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
message_metadata["end_user_id"] = end_user_data.session_id
|
||||
|
||||
attributes = {
|
||||
SpanAttributes.INPUT_VALUE: trace_info.message_data.query,
|
||||
SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer,
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
|
||||
}
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
message_span_id = RandomIdGenerator().generate_span_id()
|
||||
span_context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=message_span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
message_span = self.tracer.start_span(
|
||||
name=TraceTaskName.MESSAGE_TRACE.value,
|
||||
attributes=attributes,
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
message_span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
|
||||
# Convert outputs to string based on type
|
||||
if isinstance(trace_info.outputs, dict | list):
|
||||
outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False)
|
||||
elif isinstance(trace_info.outputs, str):
|
||||
outputs_str = trace_info.outputs
|
||||
else:
|
||||
outputs_str = str(trace_info.outputs)
|
||||
|
||||
llm_attributes = {
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value,
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: outputs_str,
|
||||
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
|
||||
}
|
||||
|
||||
if isinstance(trace_info.inputs, list):
|
||||
for i, msg in enumerate(trace_info.inputs):
|
||||
if isinstance(msg, dict):
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get(
|
||||
"role", "user"
|
||||
)
|
||||
# todo: handle assistant and tool role messages, as they don't always
|
||||
# have a text field, but may have a tool_calls field instead
|
||||
# e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
|
||||
# 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
|
||||
elif isinstance(trace_info.inputs, dict):
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(trace_info.inputs)
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
|
||||
elif isinstance(trace_info.inputs, str):
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = trace_info.inputs
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
|
||||
|
||||
if trace_info.total_tokens is not None and trace_info.total_tokens > 0:
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = trace_info.total_tokens
|
||||
if trace_info.message_tokens is not None and trace_info.message_tokens > 0:
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = trace_info.message_tokens
|
||||
if trace_info.answer_tokens is not None and trace_info.answer_tokens > 0:
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = trace_info.answer_tokens
|
||||
|
||||
if trace_info.message_data.model_id is not None:
|
||||
llm_attributes[SpanAttributes.LLM_MODEL_NAME] = trace_info.message_data.model_id
|
||||
if trace_info.message_data.model_provider is not None:
|
||||
llm_attributes[SpanAttributes.LLM_PROVIDER] = trace_info.message_data.model_provider
|
||||
|
||||
if trace_info.message_data and trace_info.message_data.message_metadata:
|
||||
metadata_dict = json.loads(trace_info.message_data.message_metadata)
|
||||
if model_params := metadata_dict.get("model_parameters"):
|
||||
llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)
|
||||
|
||||
llm_span = self.tracer.start_span(
|
||||
name="llm",
|
||||
attributes=llm_attributes,
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
llm_span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
llm_span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
finally:
|
||||
message_span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def moderation_trace(self, trace_info: ModerationTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
metadata = {
|
||||
"message_id": trace_info.message_id,
|
||||
"tool_name": "moderation",
|
||||
"status": trace_info.message_data.status,
|
||||
"status_message": trace_info.message_data.error or "",
|
||||
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.MODERATION_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps(
|
||||
{
|
||||
"action": trace_info.action,
|
||||
"flagged": trace_info.flagged,
|
||||
"preset_response": trace_info.preset_response,
|
||||
"inputs": trace_info.inputs,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.message_data.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.message_data.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
end_time = trace_info.end_time or trace_info.message_data.updated_at
|
||||
|
||||
metadata = {
|
||||
"message_id": trace_info.message_id,
|
||||
"tool_name": "suggested_question",
|
||||
"status": trace_info.status,
|
||||
"status_message": trace_info.error or "",
|
||||
"level": "ERROR" if trace_info.error else "DEFAULT",
|
||||
"total_tokens": trace_info.total_tokens,
|
||||
"ls_provider": trace_info.model_provider or "",
|
||||
"ls_model_name": trace_info.model_id or "",
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
},
|
||||
start_time=datetime_to_nanos(start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(end_time))
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
end_time = trace_info.end_time or trace_info.message_data.updated_at
|
||||
|
||||
metadata = {
|
||||
"message_id": trace_info.message_id,
|
||||
"tool_name": "dataset_retrieval",
|
||||
"status": trace_info.message_data.status,
|
||||
"status_message": trace_info.message_data.error or "",
|
||||
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
|
||||
"ls_provider": trace_info.message_data.model_provider or "",
|
||||
"ls_model_name": trace_info.message_data.model_id or "",
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps({"documents": trace_info.documents}, ensure_ascii=False),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.RETRIEVER.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
"start_time": start_time.isoformat() if start_time else "",
|
||||
"end_time": end_time.isoformat() if end_time else "",
|
||||
},
|
||||
start_time=datetime_to_nanos(start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.message_data.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.message_data.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(end_time))
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
logger.warning("[Arize/Phoenix] Message data is None, skipping tool trace.")
|
||||
return
|
||||
|
||||
metadata = {
|
||||
"message_id": trace_info.message_id,
|
||||
"tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
|
||||
}
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
tool_span_id = RandomIdGenerator().generate_span_id()
|
||||
logger.info(f"[Arize/Phoenix] Creating tool trace with trace_id: {trace_id}, span_id: {tool_span_id}")
|
||||
|
||||
# Create span context with the same trace_id as the parent
|
||||
# todo: Create with the appropriate parent span context, so that the tool span is
|
||||
# a child of the appropriate span (e.g. message span)
|
||||
span_context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=tool_span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
tool_params_str = (
|
||||
json.dumps(trace_info.tool_parameters, ensure_ascii=False)
|
||||
if isinstance(trace_info.tool_parameters, dict)
|
||||
else str(trace_info.tool_parameters)
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=trace_info.tool_name,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs,
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
SpanAttributes.TOOL_NAME: trace_info.tool_name,
|
||||
SpanAttributes.TOOL_PARAMETERS: tool_params_str,
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
metadata = {
|
||||
"project_name": self.project,
|
||||
"message_id": trace_info.message_id,
|
||||
"status": trace_info.message_data.status,
|
||||
"status_message": trace_info.message_data.error or "",
|
||||
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.outputs, ensure_ascii=False),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
|
||||
"start_time": trace_info.start_time.isoformat() if trace_info.start_time else "",
|
||||
"end_time": trace_info.end_time.isoformat() if trace_info.end_time else "",
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.message_data.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.message_data.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def api_check(self):
|
||||
try:
|
||||
with self.tracer.start_span("api_check") as span:
|
||||
span.set_attribute("test", "true")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.info(f"[Arize/Phoenix] API check failed: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}")
|
||||
|
||||
def get_project_url(self):
|
||||
try:
|
||||
if self.arize_phoenix_config.endpoint == "https://otlp.arize.com":
|
||||
return "https://app.arize.com/"
|
||||
else:
|
||||
return f"{self.arize_phoenix_config.endpoint}/projects/"
|
||||
except Exception as e:
|
||||
logger.info(f"[Arize/Phoenix] Get run url failed: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")
|
||||
|
||||
def _get_workflow_nodes(self, workflow_run_id: str):
|
||||
"""Helper method to get workflow nodes"""
|
||||
workflow_nodes = (
|
||||
db.session.query(
|
||||
WorkflowNodeExecutionModel.id,
|
||||
WorkflowNodeExecutionModel.tenant_id,
|
||||
WorkflowNodeExecutionModel.app_id,
|
||||
WorkflowNodeExecutionModel.title,
|
||||
WorkflowNodeExecutionModel.node_type,
|
||||
WorkflowNodeExecutionModel.status,
|
||||
WorkflowNodeExecutionModel.inputs,
|
||||
WorkflowNodeExecutionModel.outputs,
|
||||
WorkflowNodeExecutionModel.created_at,
|
||||
WorkflowNodeExecutionModel.elapsed_time,
|
||||
WorkflowNodeExecutionModel.process_data,
|
||||
WorkflowNodeExecutionModel.execution_metadata,
|
||||
)
|
||||
.filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
|
||||
.all()
|
||||
)
|
||||
return workflow_nodes
|
||||
@ -2,20 +2,92 @@ from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, ValidationInfo, field_validator
|
||||
|
||||
from core.ops.utils import validate_project_name, validate_url, validate_url_with_path
|
||||
|
||||
|
||||
class TracingProviderEnum(StrEnum):
|
||||
ARIZE = "arize"
|
||||
PHOENIX = "phoenix"
|
||||
LANGFUSE = "langfuse"
|
||||
LANGSMITH = "langsmith"
|
||||
OPIK = "opik"
|
||||
WEAVE = "weave"
|
||||
ALIYUN = "aliyun"
|
||||
|
||||
|
||||
class BaseTracingConfig(BaseModel):
|
||||
"""
|
||||
Base model class for tracing
|
||||
Base model class for tracing configurations
|
||||
"""
|
||||
|
||||
...
|
||||
@classmethod
|
||||
def validate_endpoint_url(cls, v: str, default_url: str) -> str:
|
||||
"""
|
||||
Common endpoint URL validation logic
|
||||
|
||||
Args:
|
||||
v: URL value to validate
|
||||
default_url: Default URL to use if input is None or empty
|
||||
|
||||
Returns:
|
||||
Validated and normalized URL
|
||||
"""
|
||||
return validate_url(v, default_url)
|
||||
|
||||
@classmethod
|
||||
def validate_project_field(cls, v: str, default_name: str) -> str:
|
||||
"""
|
||||
Common project name validation logic
|
||||
|
||||
Args:
|
||||
v: Project name to validate
|
||||
default_name: Default name to use if input is None or empty
|
||||
|
||||
Returns:
|
||||
Validated project name
|
||||
"""
|
||||
return validate_project_name(v, default_name)
|
||||
|
||||
|
||||
class ArizeConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Arize tracing config.
|
||||
"""
|
||||
|
||||
api_key: str | None = None
|
||||
space_id: str | None = None
|
||||
project: str | None = None
|
||||
endpoint: str = "https://otlp.arize.com"
|
||||
|
||||
@field_validator("project")
|
||||
@classmethod
|
||||
def project_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "default")
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_endpoint_url(v, "https://otlp.arize.com")
|
||||
|
||||
|
||||
class PhoenixConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Phoenix tracing config.
|
||||
"""
|
||||
|
||||
api_key: str | None = None
|
||||
project: str | None = None
|
||||
endpoint: str = "https://app.phoenix.arize.com"
|
||||
|
||||
@field_validator("project")
|
||||
@classmethod
|
||||
def project_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "default")
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_endpoint_url(v, "https://app.phoenix.arize.com")
|
||||
|
||||
|
||||
class LangfuseConfig(BaseTracingConfig):
|
||||
@ -29,13 +101,8 @@ class LangfuseConfig(BaseTracingConfig):
|
||||
|
||||
@field_validator("host")
|
||||
@classmethod
|
||||
def set_value(cls, v, info: ValidationInfo):
|
||||
if v is None or v == "":
|
||||
v = "https://api.langfuse.com"
|
||||
if not v.startswith("https://") and not v.startswith("http://"):
|
||||
raise ValueError("host must start with https:// or http://")
|
||||
|
||||
return v
|
||||
def host_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_endpoint_url(v, "https://api.langfuse.com")
|
||||
|
||||
|
||||
class LangSmithConfig(BaseTracingConfig):
|
||||
@ -49,13 +116,9 @@ class LangSmithConfig(BaseTracingConfig):
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def set_value(cls, v, info: ValidationInfo):
|
||||
if v is None or v == "":
|
||||
v = "https://api.smith.langchain.com"
|
||||
if not v.startswith("https://"):
|
||||
raise ValueError("endpoint must start with https://")
|
||||
|
||||
return v
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
# LangSmith only allows HTTPS
|
||||
return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",))
|
||||
|
||||
|
||||
class OpikConfig(BaseTracingConfig):
|
||||
@ -71,22 +134,12 @@ class OpikConfig(BaseTracingConfig):
|
||||
@field_validator("project")
|
||||
@classmethod
|
||||
def project_validator(cls, v, info: ValidationInfo):
|
||||
if v is None or v == "":
|
||||
v = "Default Project"
|
||||
|
||||
return v
|
||||
return cls.validate_project_field(v, "Default Project")
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def url_validator(cls, v, info: ValidationInfo):
|
||||
if v is None or v == "":
|
||||
v = "https://www.comet.com/opik/api/"
|
||||
if not v.startswith(("https://", "http://")):
|
||||
raise ValueError("url must start with https:// or http://")
|
||||
if not v.endswith("/api/"):
|
||||
raise ValueError("url should ends with /api/")
|
||||
|
||||
return v
|
||||
return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/")
|
||||
|
||||
|
||||
class WeaveConfig(BaseTracingConfig):
|
||||
@ -102,22 +155,44 @@ class WeaveConfig(BaseTracingConfig):
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def set_value(cls, v, info: ValidationInfo):
|
||||
if v is None or v == "":
|
||||
v = "https://trace.wandb.ai"
|
||||
if not v.startswith("https://"):
|
||||
raise ValueError("endpoint must start with https://")
|
||||
|
||||
return v
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
# Weave only allows HTTPS for endpoint
|
||||
return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))
|
||||
|
||||
@field_validator("host")
|
||||
@classmethod
|
||||
def validate_host(cls, v, info: ValidationInfo):
|
||||
if v is not None and v != "":
|
||||
if not v.startswith(("https://", "http://")):
|
||||
raise ValueError("host must start with https:// or http://")
|
||||
def host_validator(cls, v, info: ValidationInfo):
|
||||
if v is not None and v.strip() != "":
|
||||
return validate_url(v, v, allowed_schemes=("https", "http"))
|
||||
return v
|
||||
|
||||
|
||||
class AliyunConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Aliyun tracing config.
|
||||
"""
|
||||
|
||||
app_name: str = "dify_app"
|
||||
license_key: str
|
||||
endpoint: str
|
||||
|
||||
@field_validator("app_name")
|
||||
@classmethod
|
||||
def app_name_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "dify_app")
|
||||
|
||||
@field_validator("license_key")
|
||||
@classmethod
|
||||
def license_key_validator(cls, v, info: ValidationInfo):
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError("License key cannot be empty")
|
||||
return v
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_endpoint_url(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
|
||||
OPS_FILE_PATH = "ops_trace/"
|
||||
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
|
||||
|
||||
@ -32,6 +32,7 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import MessageStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -180,12 +181,9 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
try:
|
||||
if outputs.get("usage"):
|
||||
prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 0)
|
||||
completion_tokens = outputs.get("usage", {}).get("completion_tokens", 0)
|
||||
else:
|
||||
prompt_tokens = process_data.get("usage", {}).get("prompt_tokens", 0)
|
||||
completion_tokens = process_data.get("usage", {}).get("completion_tokens", 0)
|
||||
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
prompt_tokens = usage_data.get("prompt_tokens", 0)
|
||||
completion_tokens = usage_data.get("completion_tokens", 0)
|
||||
except Exception:
|
||||
logger.error("Failed to extract usage", exc_info=True)
|
||||
|
||||
@ -293,7 +291,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
input=trace_info.inputs,
|
||||
output=message_data.answer,
|
||||
metadata=metadata,
|
||||
level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
|
||||
level=(LevelEnum.DEFAULT if message_data.status != MessageStatus.ERROR else LevelEnum.ERROR),
|
||||
status_message=message_data.error or "",
|
||||
usage=generation_usage,
|
||||
)
|
||||
@ -339,7 +337,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
start_time=trace_info.start_time,
|
||||
end_time=trace_info.end_time,
|
||||
metadata=trace_info.metadata,
|
||||
level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
|
||||
level=(LevelEnum.DEFAULT if message_data.status != MessageStatus.ERROR else LevelEnum.ERROR),
|
||||
status_message=message_data.error or "",
|
||||
usage=generation_usage,
|
||||
)
|
||||
|
||||
@ -206,12 +206,9 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
try:
|
||||
if outputs.get("usage"):
|
||||
prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 0)
|
||||
completion_tokens = outputs.get("usage", {}).get("completion_tokens", 0)
|
||||
else:
|
||||
prompt_tokens = process_data.get("usage", {}).get("prompt_tokens", 0)
|
||||
completion_tokens = process_data.get("usage", {}).get("completion_tokens", 0)
|
||||
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
prompt_tokens = usage_data.get("prompt_tokens", 0)
|
||||
completion_tokens = usage_data.get("completion_tokens", 0)
|
||||
except Exception:
|
||||
logger.error("Failed to extract usage", exc_info=True)
|
||||
|
||||
|
||||
@ -222,10 +222,10 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
)
|
||||
|
||||
try:
|
||||
if outputs.get("usage"):
|
||||
total_tokens = outputs["usage"].get("total_tokens", 0)
|
||||
prompt_tokens = outputs["usage"].get("prompt_tokens", 0)
|
||||
completion_tokens = outputs["usage"].get("completion_tokens", 0)
|
||||
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
total_tokens = usage_data.get("total_tokens", 0)
|
||||
prompt_tokens = usage_data.get("prompt_tokens", 0)
|
||||
completion_tokens = usage_data.get("completion_tokens", 0)
|
||||
except Exception:
|
||||
logger.error("Failed to extract usage", exc_info=True)
|
||||
|
||||
|
||||
@ -84,6 +84,36 @@ class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
|
||||
"other_keys": ["project", "entity", "endpoint", "host"],
|
||||
"trace_instance": WeaveDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.ARIZE:
|
||||
from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
|
||||
from core.ops.entities.config_entity import ArizeConfig
|
||||
|
||||
return {
|
||||
"config_class": ArizeConfig,
|
||||
"secret_keys": ["api_key", "space_id"],
|
||||
"other_keys": ["project", "endpoint"],
|
||||
"trace_instance": ArizePhoenixDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.PHOENIX:
|
||||
from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
|
||||
from core.ops.entities.config_entity import PhoenixConfig
|
||||
|
||||
return {
|
||||
"config_class": PhoenixConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "endpoint"],
|
||||
"trace_instance": ArizePhoenixDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.ALIYUN:
|
||||
from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
|
||||
from core.ops.entities.config_entity import AliyunConfig
|
||||
|
||||
return {
|
||||
"config_class": AliyunConfig,
|
||||
"secret_keys": ["license_key"],
|
||||
"other_keys": ["endpoint", "app_name"],
|
||||
"trace_instance": AliyunDataTrace,
|
||||
}
|
||||
|
||||
case _:
|
||||
raise KeyError(f"Unsupported tracing provider: {provider}")
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message
|
||||
@ -60,3 +61,83 @@ def generate_dotted_order(
|
||||
return current_segment
|
||||
|
||||
return f"{parent_dotted_order}.{current_segment}"
|
||||
|
||||
|
||||
def validate_url(url: str, default_url: str, allowed_schemes: tuple = ("https", "http")) -> str:
|
||||
"""
|
||||
Validate and normalize URL with proper error handling
|
||||
|
||||
Args:
|
||||
url: The URL to validate
|
||||
default_url: Default URL to use if input is None or empty
|
||||
allowed_schemes: Tuple of allowed URL schemes (default: https, http)
|
||||
|
||||
Returns:
|
||||
Normalized URL string
|
||||
|
||||
Raises:
|
||||
ValueError: If URL format is invalid or scheme not allowed
|
||||
"""
|
||||
if not url or url.strip() == "":
|
||||
return default_url
|
||||
|
||||
# Parse URL to validate format
|
||||
parsed = urlparse(url)
|
||||
|
||||
# Check if scheme is allowed
|
||||
if parsed.scheme not in allowed_schemes:
|
||||
raise ValueError(f"URL scheme must be one of: {', '.join(allowed_schemes)}")
|
||||
|
||||
# Reconstruct URL with only scheme, netloc (removing path, query, fragment)
|
||||
normalized_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
return normalized_url
|
||||
|
||||
|
||||
def validate_url_with_path(url: str, default_url: str, required_suffix: str | None = None) -> str:
|
||||
"""
|
||||
Validate URL that may include path components
|
||||
|
||||
Args:
|
||||
url: The URL to validate
|
||||
default_url: Default URL to use if input is None or empty
|
||||
required_suffix: Optional suffix that URL must end with
|
||||
|
||||
Returns:
|
||||
Validated URL string
|
||||
|
||||
Raises:
|
||||
ValueError: If URL format is invalid or doesn't match required suffix
|
||||
"""
|
||||
if not url or url.strip() == "":
|
||||
return default_url
|
||||
|
||||
# Parse URL to validate format
|
||||
parsed = urlparse(url)
|
||||
|
||||
# Check if scheme is allowed
|
||||
if parsed.scheme not in ("https", "http"):
|
||||
raise ValueError("URL must start with https:// or http://")
|
||||
|
||||
# Check required suffix if specified
|
||||
if required_suffix and not url.endswith(required_suffix):
|
||||
raise ValueError(f"URL should end with {required_suffix}")
|
||||
|
||||
return url
|
||||
|
||||
|
||||
def validate_project_name(project: str, default_name: str) -> str:
|
||||
"""
|
||||
Validate and normalize project name
|
||||
|
||||
Args:
|
||||
project: Project name to validate
|
||||
default_name: Default name to use if input is None or empty
|
||||
|
||||
Returns:
|
||||
Normalized project name
|
||||
"""
|
||||
if not project or project.strip() == "":
|
||||
return default_name
|
||||
|
||||
return project.strip()
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
"""Document loader helpers."""
|
||||
|
||||
import concurrent.futures
|
||||
from pathlib import Path
|
||||
from typing import NamedTuple, Optional, cast
|
||||
|
||||
|
||||
@ -16,7 +15,7 @@ class FileEncoding(NamedTuple):
|
||||
"""The language of the file."""
|
||||
|
||||
|
||||
def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding]:
|
||||
def detect_file_encodings(file_path: str, timeout: int = 5, sample_size: int = 1024 * 1024) -> list[FileEncoding]:
|
||||
"""Try to detect the file encoding.
|
||||
|
||||
Returns a list of `FileEncoding` tuples with the detected encodings ordered
|
||||
@ -25,11 +24,16 @@ def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding
|
||||
Args:
|
||||
file_path: The path to the file to detect the encoding for.
|
||||
timeout: The timeout in seconds for the encoding detection.
|
||||
sample_size: The number of bytes to read for encoding detection. Default is 1MB.
|
||||
For large files, reading only a sample is sufficient and prevents timeout.
|
||||
"""
|
||||
import chardet
|
||||
|
||||
def read_and_detect(file_path: str) -> list[dict]:
|
||||
rawdata = Path(file_path).read_bytes()
|
||||
with open(file_path, "rb") as f:
|
||||
# Read only a sample of the file for encoding detection
|
||||
# This prevents timeout on large files while still providing accurate encoding detection
|
||||
rawdata = f.read(sample_size)
|
||||
return cast(list[dict], chardet.detect_all(rawdata))
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
|
||||
@ -36,8 +36,12 @@ class TextExtractor(BaseExtractor):
|
||||
break
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Decode failed: {self._file_path}, all detected encodings failed. Original error: {e}"
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Error loading {self._file_path}") from e
|
||||
raise RuntimeError(f"Decode failed: {self._file_path}, specified encoding failed. Original error: {e}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error loading {self._file_path}") from e
|
||||
|
||||
|
||||
@ -31,6 +31,14 @@ class TTSTool(BuiltinTool):
|
||||
model_type=ModelType.TTS,
|
||||
model=model,
|
||||
)
|
||||
if not voice:
|
||||
voices = model_instance.get_tts_voices()
|
||||
if voices:
|
||||
voice = voices[0].get("value")
|
||||
if not voice:
|
||||
raise ValueError("Sorry, no voice available.")
|
||||
else:
|
||||
raise ValueError("Sorry, no voice available.")
|
||||
tts = model_instance.invoke_tts(
|
||||
content_text=tool_parameters.get("text"), # type: ignore
|
||||
user=user_id,
|
||||
|
||||
@ -103,7 +103,7 @@ class GraphEngine:
|
||||
call_depth: int,
|
||||
graph: Graph,
|
||||
graph_config: Mapping[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
max_execution_steps: int,
|
||||
max_execution_time: int,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
@ -140,7 +140,7 @@ class GraphEngine:
|
||||
call_depth=call_depth,
|
||||
)
|
||||
|
||||
self.graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
|
||||
self.max_execution_steps = max_execution_steps
|
||||
self.max_execution_time = max_execution_time
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
@ -13,7 +14,7 @@ from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.variables.segments import StringSegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
@ -102,6 +103,32 @@ class AgentNode(ToolNode):
|
||||
|
||||
try:
|
||||
# convert tool messages
|
||||
agent_thoughts: list = []
|
||||
|
||||
thought_log_message = ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LOG,
|
||||
message=ToolInvokeMessage.LogMessage(
|
||||
id=str(uuid.uuid4()),
|
||||
label=f"Agent Strategy: {cast(AgentNodeData, self.node_data).agent_strategy_name}",
|
||||
parent_id=None,
|
||||
error=None,
|
||||
status=ToolInvokeMessage.LogMessage.LogStatus.START,
|
||||
data={
|
||||
"strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
|
||||
"parameters": parameters_for_log,
|
||||
"thought_process": "Agent strategy execution started",
|
||||
},
|
||||
metadata={
|
||||
"icon": self.agent_strategy_icon,
|
||||
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
def enhanced_message_stream():
|
||||
yield thought_log_message
|
||||
|
||||
yield from message_stream
|
||||
|
||||
yield from self._transform_message(
|
||||
message_stream,
|
||||
@ -110,6 +137,7 @@ class AgentNode(ToolNode):
|
||||
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
|
||||
},
|
||||
parameters_for_log,
|
||||
agent_thoughts,
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
yield RunCompletedEvent(
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import Any, Literal
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
import httpx
|
||||
from json_repair import repair_json
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import file_manager
|
||||
@ -178,7 +179,8 @@ class Executor:
|
||||
raise RequestBodyError("json body type should have exactly one item")
|
||||
json_string = self.variable_pool.convert_template(data[0].value).text
|
||||
try:
|
||||
json_object = json.loads(json_string, strict=False)
|
||||
repaired = repair_json(json_string)
|
||||
json_object = json.loads(repaired, strict=False)
|
||||
except json.JSONDecodeError as e:
|
||||
raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e
|
||||
self.json = json_object
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import contextvars
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from concurrent.futures import Future, wait
|
||||
@ -133,8 +134,11 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[0])
|
||||
|
||||
# init graph engine
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
@ -146,7 +150,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
call_depth=self.workflow_call_depth,
|
||||
graph=iteration_graph,
|
||||
graph_config=graph_config,
|
||||
variable_pool=variable_pool,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
|
||||
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
||||
thread_pool_id=self.thread_pool_id,
|
||||
|
||||
@ -221,15 +221,6 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
||||
)
|
||||
|
||||
process_data = {
|
||||
"model_mode": model_config.mode,
|
||||
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
model_mode=model_config.mode, prompt_messages=prompt_messages
|
||||
),
|
||||
"model_provider": model_config.provider,
|
||||
"model_name": model_config.model,
|
||||
}
|
||||
|
||||
# handle invoke result
|
||||
generator = self._invoke_llm(
|
||||
node_data_model=self.node_data.model,
|
||||
@ -253,6 +244,17 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
elif isinstance(event, LLMStructuredOutput):
|
||||
structured_output = event
|
||||
|
||||
process_data = {
|
||||
"model_mode": model_config.mode,
|
||||
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
model_mode=model_config.mode, prompt_messages=prompt_messages
|
||||
),
|
||||
"usage": jsonable_encoder(usage),
|
||||
"finish_reason": finish_reason,
|
||||
"model_provider": model_config.provider,
|
||||
"model_name": model_config.model,
|
||||
}
|
||||
|
||||
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
|
||||
if structured_output:
|
||||
outputs["structured_output"] = structured_output.structured_output
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
@ -101,8 +102,11 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
loop_variable_selectors[loop_variable.label] = variable_selector
|
||||
inputs[loop_variable.label] = processed_segment.value
|
||||
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
@ -114,7 +118,7 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
call_depth=self.workflow_call_depth,
|
||||
graph=loop_graph,
|
||||
graph_config=self.graph_config,
|
||||
variable_pool=variable_pool,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
|
||||
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
||||
thread_pool_id=self.thread_pool_id,
|
||||
|
||||
@ -253,7 +253,12 @@ class ParameterExtractorNode(BaseNode):
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs={"__is_success": 1 if not error else 0, "__reason": error, **result},
|
||||
outputs={
|
||||
"__is_success": 1 if not error else 0,
|
||||
"__reason": error,
|
||||
"__usage": jsonable_encoder(usage),
|
||||
**result,
|
||||
},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
|
||||
@ -145,7 +145,11 @@ class QuestionClassifierNode(LLMNode):
|
||||
"model_provider": model_config.provider,
|
||||
"model_name": model_config.model,
|
||||
}
|
||||
outputs = {"class_name": category_name, "class_id": category_id}
|
||||
outputs = {
|
||||
"class_name": category_name,
|
||||
"class_id": category_id,
|
||||
"usage": jsonable_encoder(usage),
|
||||
}
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
@ -190,6 +191,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
tool_info: Mapping[str, Any],
|
||||
parameters_for_log: dict[str, Any],
|
||||
agent_thoughts: Optional[list] = None,
|
||||
) -> Generator:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
@ -208,7 +210,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
|
||||
agent_logs: list[AgentLogEvent] = []
|
||||
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
|
||||
|
||||
llm_usage: LLMUsage | None = None
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
@ -276,9 +278,10 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
if self.node_type == NodeType.AGENT:
|
||||
msg_metadata = message.message.json_object.pop("execution_metadata", {})
|
||||
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
|
||||
llm_usage = LLMUsage.from_metadata(msg_metadata)
|
||||
agent_execution_metadata = {
|
||||
key: value
|
||||
WorkflowNodeExecutionMetadataKey(key): value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
@ -366,17 +369,42 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
agent_logs.append(agent_log)
|
||||
|
||||
yield agent_log
|
||||
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
|
||||
json_output: dict[str, Any] = {}
|
||||
if json:
|
||||
if isinstance(json, list) and len(json) == 1:
|
||||
# If json is a list with only one element, convert it to a dictionary
|
||||
json_output = json[0] if isinstance(json[0], dict) else {"data": json[0]}
|
||||
elif isinstance(json, list):
|
||||
# If json is a list with multiple elements, create a dictionary containing all data
|
||||
json_output = {"data": json}
|
||||
|
||||
if agent_logs:
|
||||
# Add agent_logs to json output
|
||||
json_output["agent_logs"] = [
|
||||
{
|
||||
"id": log.id,
|
||||
"parent_id": log.parent_id,
|
||||
"error": log.error,
|
||||
"status": log.status,
|
||||
"data": log.data,
|
||||
"label": log.label,
|
||||
"metadata": log.metadata,
|
||||
"node_id": log.node_id,
|
||||
}
|
||||
for log in agent_logs
|
||||
]
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json, **variables},
|
||||
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
|
||||
metadata={
|
||||
**agent_execution_metadata,
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
llm_usage=llm_usage,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -69,6 +69,7 @@ class WorkflowEntry:
|
||||
raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth))
|
||||
|
||||
# init workflow run state
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
self.graph_engine = GraphEngine(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
@ -80,7 +81,7 @@ class WorkflowEntry:
|
||||
call_depth=call_depth,
|
||||
graph=graph,
|
||||
graph_config=graph_config,
|
||||
variable_pool=variable_pool,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
|
||||
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
||||
thread_pool_id=thread_pool_id,
|
||||
|
||||
@ -17,6 +17,7 @@ class EnvironmentVariableField(fields.Raw):
|
||||
"name": value.name,
|
||||
"value": encrypter.obfuscated_token(value.value),
|
||||
"value_type": value.value_type.value,
|
||||
"description": value.description,
|
||||
}
|
||||
if isinstance(value, Variable):
|
||||
return {
|
||||
@ -24,6 +25,7 @@ class EnvironmentVariableField(fields.Raw):
|
||||
"name": value.name,
|
||||
"value": value.value,
|
||||
"value_type": value.value_type.value,
|
||||
"description": value.description,
|
||||
}
|
||||
if isinstance(value, dict):
|
||||
value_type = value.get("value_type")
|
||||
|
||||
@ -23,3 +23,12 @@ class DraftVariableType(StrEnum):
|
||||
NODE = "node"
|
||||
SYS = "sys"
|
||||
CONVERSATION = "conversation"
|
||||
|
||||
|
||||
class MessageStatus(StrEnum):
|
||||
"""
|
||||
Message Status Enum
|
||||
"""
|
||||
|
||||
NORMAL = "normal"
|
||||
ERROR = "error"
|
||||
|
||||
@ -4,6 +4,7 @@ version = "1.5.1"
|
||||
requires-python = ">=3.11,<3.13"
|
||||
|
||||
dependencies = [
|
||||
"arize-phoenix-otel~=0.9.2",
|
||||
"authlib==1.3.1",
|
||||
"azure-identity==1.16.1",
|
||||
"beautifulsoup4==4.12.2",
|
||||
|
||||
@ -1,13 +1,17 @@
|
||||
import io
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from constants import AUDIO_EXTENSIONS
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from models.enums import MessageStatus
|
||||
from models.model import App, AppMode, AppModelConfig, Message
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
@ -16,6 +20,7 @@ from services.errors.audio import (
|
||||
ProviderNotSupportTextToSpeechServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
FILE_SIZE = 30
|
||||
FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024
|
||||
@ -74,35 +79,36 @@ class AudioService:
|
||||
voice: Optional[str] = None,
|
||||
end_user: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
is_draft: bool = False,
|
||||
):
|
||||
from collections.abc import Generator
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
|
||||
from app import app
|
||||
from extensions.ext_database import db
|
||||
|
||||
def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None):
|
||||
def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None, is_draft: bool = False):
|
||||
with app.app_context():
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise ValueError("TTS is not enabled")
|
||||
if voice is None:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
if is_draft:
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app_model)
|
||||
else:
|
||||
workflow = app_model.workflow
|
||||
if (
|
||||
workflow is None
|
||||
or "text_to_speech" not in workflow.features_dict
|
||||
or not workflow.features_dict["text_to_speech"].get("enabled")
|
||||
):
|
||||
raise ValueError("TTS is not enabled")
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
if "text_to_speech" not in features_dict or not features_dict["text_to_speech"].get("enabled"):
|
||||
raise ValueError("TTS is not enabled")
|
||||
voice = workflow.features_dict["text_to_speech"].get("voice")
|
||||
else:
|
||||
if not is_draft:
|
||||
if app_model.app_model_config is None:
|
||||
raise ValueError("AppModelConfig not found")
|
||||
text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
|
||||
|
||||
voice = features_dict["text_to_speech"].get("voice") if voice is None else voice
|
||||
else:
|
||||
if app_model.app_model_config is None:
|
||||
raise ValueError("AppModelConfig not found")
|
||||
text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
|
||||
if not text_to_speech_dict.get("enabled"):
|
||||
raise ValueError("TTS is not enabled")
|
||||
|
||||
if not text_to_speech_dict.get("enabled"):
|
||||
raise ValueError("TTS is not enabled")
|
||||
|
||||
voice = text_to_speech_dict.get("voice") if voice is None else voice
|
||||
voice = text_to_speech_dict.get("voice")
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
@ -132,18 +138,18 @@ class AudioService:
|
||||
message = db.session.query(Message).filter(Message.id == message_id).first()
|
||||
if message is None:
|
||||
return None
|
||||
if message.answer == "" and message.status == "normal":
|
||||
if message.answer == "" and message.status == MessageStatus.NORMAL:
|
||||
return None
|
||||
|
||||
else:
|
||||
response = invoke_tts(message.answer, app_model=app_model, voice=voice)
|
||||
response = invoke_tts(text_content=message.answer, app_model=app_model, voice=voice, is_draft=is_draft)
|
||||
if isinstance(response, Generator):
|
||||
return Response(stream_with_context(response), content_type="audio/mpeg")
|
||||
return response
|
||||
else:
|
||||
if text is None:
|
||||
raise ValueError("Text is required")
|
||||
response = invoke_tts(text, app_model, voice)
|
||||
response = invoke_tts(text_content=text, app_model=app_model, voice=voice, is_draft=is_draft)
|
||||
if isinstance(response, Generator):
|
||||
return Response(stream_with_context(response), content_type="audio/mpeg")
|
||||
return response
|
||||
|
||||
@ -34,6 +34,24 @@ class OpsService:
|
||||
)
|
||||
new_decrypt_tracing_config = OpsTraceManager.obfuscated_decrypt_token(tracing_provider, decrypt_tracing_config)
|
||||
|
||||
if tracing_provider == "arize" and (
|
||||
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
|
||||
):
|
||||
try:
|
||||
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
|
||||
new_decrypt_tracing_config.update({"project_url": project_url})
|
||||
except Exception:
|
||||
new_decrypt_tracing_config.update({"project_url": "https://app.arize.com/"})
|
||||
|
||||
if tracing_provider == "phoenix" and (
|
||||
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
|
||||
):
|
||||
try:
|
||||
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
|
||||
new_decrypt_tracing_config.update({"project_url": project_url})
|
||||
except Exception:
|
||||
new_decrypt_tracing_config.update({"project_url": "https://app.phoenix.arize.com/projects/"})
|
||||
|
||||
if tracing_provider == "langfuse" and (
|
||||
"project_key" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_key")
|
||||
):
|
||||
@ -76,6 +94,16 @@ class OpsService:
|
||||
new_decrypt_tracing_config.update({"project_url": project_url})
|
||||
except Exception:
|
||||
new_decrypt_tracing_config.update({"project_url": "https://wandb.ai/"})
|
||||
|
||||
if tracing_provider == "aliyun" and (
|
||||
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
|
||||
):
|
||||
try:
|
||||
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
|
||||
new_decrypt_tracing_config.update({"project_url": project_url})
|
||||
except Exception:
|
||||
new_decrypt_tracing_config.update({"project_url": "https://arms.console.aliyun.com/"})
|
||||
|
||||
trace_config_data.tracing_config = new_decrypt_tracing_config
|
||||
return trace_config_data.to_dict()
|
||||
|
||||
@ -107,7 +135,9 @@ class OpsService:
|
||||
return {"error": "Invalid Credentials"}
|
||||
|
||||
# get project url
|
||||
if tracing_provider == "langfuse":
|
||||
if tracing_provider in ("arize", "phoenix"):
|
||||
project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
|
||||
elif tracing_provider == "langfuse":
|
||||
project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider)
|
||||
project_url = "{host}/project/{key}".format(host=tracing_config.get("host"), key=project_key)
|
||||
elif tracing_provider in ("langsmith", "opik"):
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import os
|
||||
|
||||
from flask import Flask
|
||||
from packaging.version import Version
|
||||
from yarl import URL
|
||||
|
||||
from configs.app_config import DifyConfig
|
||||
@ -40,6 +41,9 @@ def test_dify_config(monkeypatch):
|
||||
|
||||
assert config.WORKFLOW_PARALLEL_DEPTH_LIMIT == 3
|
||||
|
||||
# values from pyproject.toml
|
||||
assert Version(config.project.version) >= Version("1.0.0")
|
||||
|
||||
|
||||
# NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected.
|
||||
# This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`.
|
||||
|
||||
194
api/tests/unit_tests/core/helper/test_url_signer.py
Normal file
194
api/tests/unit_tests/core/helper/test_url_signer.py
Normal file
@ -0,0 +1,194 @@
|
||||
from unittest.mock import patch
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pytest
|
||||
|
||||
from core.helper.url_signer import SignedUrlParams, UrlSigner
|
||||
|
||||
|
||||
class TestUrlSigner:
|
||||
"""Test cases for UrlSigner class"""
|
||||
|
||||
@patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
|
||||
def test_should_generate_signed_url_params(self):
|
||||
"""Test generation of signed URL parameters with all required fields"""
|
||||
sign_key = "test-sign-key"
|
||||
prefix = "test-prefix"
|
||||
|
||||
params = UrlSigner.get_signed_url_params(sign_key, prefix)
|
||||
|
||||
# Verify the returned object and required fields
|
||||
assert isinstance(params, SignedUrlParams)
|
||||
assert params.sign_key == sign_key
|
||||
assert params.timestamp is not None
|
||||
assert params.nonce is not None
|
||||
assert params.sign is not None
|
||||
|
||||
# Verify nonce format (32 character hex string)
|
||||
assert len(params.nonce) == 32
|
||||
assert all(c in "0123456789abcdef" for c in params.nonce)
|
||||
|
||||
@patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
|
||||
def test_should_generate_complete_signed_url(self):
|
||||
"""Test generation of complete signed URL with query parameters"""
|
||||
base_url = "https://example.com/api/test"
|
||||
sign_key = "test-sign-key"
|
||||
prefix = "test-prefix"
|
||||
|
||||
signed_url = UrlSigner.get_signed_url(base_url, sign_key, prefix)
|
||||
|
||||
# Parse URL and verify structure
|
||||
parsed = urlparse(signed_url)
|
||||
assert f"{parsed.scheme}://{parsed.netloc}{parsed.path}" == base_url
|
||||
|
||||
# Verify query parameters
|
||||
query_params = parse_qs(parsed.query)
|
||||
assert "timestamp" in query_params
|
||||
assert "nonce" in query_params
|
||||
assert "sign" in query_params
|
||||
|
||||
# Verify each parameter has exactly one value
|
||||
assert len(query_params["timestamp"]) == 1
|
||||
assert len(query_params["nonce"]) == 1
|
||||
assert len(query_params["sign"]) == 1
|
||||
|
||||
# Verify parameter values are not empty
|
||||
assert query_params["timestamp"][0]
|
||||
assert query_params["nonce"][0]
|
||||
assert query_params["sign"][0]
|
||||
|
||||
@patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
|
||||
def test_should_verify_valid_signature(self):
|
||||
"""Test verification of valid signature"""
|
||||
sign_key = "test-sign-key"
|
||||
prefix = "test-prefix"
|
||||
|
||||
# Generate and verify signature
|
||||
params = UrlSigner.get_signed_url_params(sign_key, prefix)
|
||||
|
||||
is_valid = UrlSigner.verify(
|
||||
sign_key=sign_key, timestamp=params.timestamp, nonce=params.nonce, sign=params.sign, prefix=prefix
|
||||
)
|
||||
|
||||
assert is_valid is True
|
||||
|
||||
@patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
|
||||
@pytest.mark.parametrize(
|
||||
("field", "modifier"),
|
||||
[
|
||||
("sign_key", lambda _: "wrong-sign-key"),
|
||||
("timestamp", lambda t: str(int(t) + 1000)),
|
||||
("nonce", lambda _: "different-nonce-123456789012345"),
|
||||
("prefix", lambda _: "wrong-prefix"),
|
||||
("sign", lambda s: s + "tampered"),
|
||||
],
|
||||
)
|
||||
def test_should_reject_invalid_signature_params(self, field, modifier):
|
||||
"""Test signature verification rejects invalid parameters"""
|
||||
sign_key = "test-sign-key"
|
||||
prefix = "test-prefix"
|
||||
|
||||
# Generate valid signed parameters
|
||||
params = UrlSigner.get_signed_url_params(sign_key, prefix)
|
||||
|
||||
# Prepare verification parameters
|
||||
verify_params = {
|
||||
"sign_key": sign_key,
|
||||
"timestamp": params.timestamp,
|
||||
"nonce": params.nonce,
|
||||
"sign": params.sign,
|
||||
"prefix": prefix,
|
||||
}
|
||||
|
||||
# Modify the specific field
|
||||
verify_params[field] = modifier(verify_params[field])
|
||||
|
||||
# Verify should fail
|
||||
is_valid = UrlSigner.verify(**verify_params)
|
||||
assert is_valid is False
|
||||
|
||||
@patch("configs.dify_config.SECRET_KEY", None)
|
||||
def test_should_raise_error_without_secret_key(self):
|
||||
"""Test that signing fails when SECRET_KEY is not configured"""
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
UrlSigner.get_signed_url_params("key", "prefix")
|
||||
|
||||
assert "SECRET_KEY is not set" in str(exc_info.value)
|
||||
|
||||
@patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
|
||||
def test_should_generate_unique_signatures(self):
|
||||
"""Test that different inputs produce different signatures"""
|
||||
params1 = UrlSigner.get_signed_url_params("key1", "prefix1")
|
||||
params2 = UrlSigner.get_signed_url_params("key2", "prefix2")
|
||||
|
||||
# Different inputs should produce different signatures
|
||||
assert params1.sign != params2.sign
|
||||
assert params1.nonce != params2.nonce
|
||||
|
||||
@patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
|
||||
def test_should_handle_special_characters(self):
|
||||
"""Test handling of special characters in parameters"""
|
||||
special_cases = [
|
||||
"test with spaces",
|
||||
"test/with/slashes",
|
||||
"test中文字符",
|
||||
]
|
||||
|
||||
for sign_key in special_cases:
|
||||
params = UrlSigner.get_signed_url_params(sign_key, "prefix")
|
||||
|
||||
# Should generate valid signature and verify correctly
|
||||
is_valid = UrlSigner.verify(
|
||||
sign_key=sign_key, timestamp=params.timestamp, nonce=params.nonce, sign=params.sign, prefix="prefix"
|
||||
)
|
||||
assert is_valid is True
|
||||
|
||||
@patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
|
||||
def test_should_ensure_nonce_randomness(self):
|
||||
"""Test that nonce is random for each generation - critical for security"""
|
||||
sign_key = "test-sign-key"
|
||||
prefix = "test-prefix"
|
||||
|
||||
# Generate multiple nonces
|
||||
nonces = set()
|
||||
for _ in range(5):
|
||||
params = UrlSigner.get_signed_url_params(sign_key, prefix)
|
||||
nonces.add(params.nonce)
|
||||
|
||||
# All nonces should be unique
|
||||
assert len(nonces) == 5
|
||||
|
||||
@patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
|
||||
@patch("time.time", return_value=1234567890)
|
||||
@patch("os.urandom", return_value=b"\xab\xcd\xef\x12\x34\x56\x78\x90\xab\xcd\xef\x12\x34\x56\x78\x90")
|
||||
def test_should_produce_consistent_signatures(self, mock_urandom, mock_time):
|
||||
"""Test that same inputs produce same signature - ensures deterministic behavior"""
|
||||
sign_key = "test-sign-key"
|
||||
prefix = "test-prefix"
|
||||
|
||||
# Generate signature multiple times with same inputs (time and nonce are mocked)
|
||||
params1 = UrlSigner.get_signed_url_params(sign_key, prefix)
|
||||
params2 = UrlSigner.get_signed_url_params(sign_key, prefix)
|
||||
|
||||
# With mocked time and random, should produce identical results
|
||||
assert params1.timestamp == params2.timestamp
|
||||
assert params1.nonce == params2.nonce
|
||||
assert params1.sign == params2.sign
|
||||
|
||||
# Verify the signature is valid
|
||||
assert UrlSigner.verify(
|
||||
sign_key=sign_key, timestamp=params1.timestamp, nonce=params1.nonce, sign=params1.sign, prefix=prefix
|
||||
)
|
||||
|
||||
@patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345")
|
||||
def test_should_handle_empty_strings(self):
|
||||
"""Test handling of empty string parameters - common edge case"""
|
||||
# Empty sign_key and prefix should still work
|
||||
params = UrlSigner.get_signed_url_params("", "")
|
||||
assert params.sign is not None
|
||||
|
||||
# Should verify correctly
|
||||
is_valid = UrlSigner.verify(
|
||||
sign_key="", timestamp=params.timestamp, nonce=params.nonce, sign=params.sign, prefix=""
|
||||
)
|
||||
assert is_valid is True
|
||||
1
api/tests/unit_tests/core/ops/__init__.py
Normal file
1
api/tests/unit_tests/core/ops/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Unit tests for core ops module
|
||||
385
api/tests/unit_tests/core/ops/test_config_entity.py
Normal file
385
api/tests/unit_tests/core/ops/test_config_entity.py
Normal file
@ -0,0 +1,385 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.ops.entities.config_entity import (
|
||||
AliyunConfig,
|
||||
ArizeConfig,
|
||||
LangfuseConfig,
|
||||
LangSmithConfig,
|
||||
OpikConfig,
|
||||
PhoenixConfig,
|
||||
TracingProviderEnum,
|
||||
WeaveConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestTracingProviderEnum:
|
||||
"""Test cases for TracingProviderEnum"""
|
||||
|
||||
def test_enum_values(self):
|
||||
"""Test that all expected enum values are present"""
|
||||
assert TracingProviderEnum.ARIZE == "arize"
|
||||
assert TracingProviderEnum.PHOENIX == "phoenix"
|
||||
assert TracingProviderEnum.LANGFUSE == "langfuse"
|
||||
assert TracingProviderEnum.LANGSMITH == "langsmith"
|
||||
assert TracingProviderEnum.OPIK == "opik"
|
||||
assert TracingProviderEnum.WEAVE == "weave"
|
||||
assert TracingProviderEnum.ALIYUN == "aliyun"
|
||||
|
||||
|
||||
class TestArizeConfig:
|
||||
"""Test cases for ArizeConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Arize configuration"""
|
||||
config = ArizeConfig(
|
||||
api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com"
|
||||
)
|
||||
assert config.api_key == "test_key"
|
||||
assert config.space_id == "test_space"
|
||||
assert config.project == "test_project"
|
||||
assert config.endpoint == "https://custom.arize.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = ArizeConfig()
|
||||
assert config.api_key is None
|
||||
assert config.space_id is None
|
||||
assert config.project is None
|
||||
assert config.endpoint == "https://otlp.arize.com"
|
||||
|
||||
def test_project_validation_empty(self):
|
||||
"""Test project validation with empty value"""
|
||||
config = ArizeConfig(project="")
|
||||
assert config.project == "default"
|
||||
|
||||
def test_project_validation_none(self):
|
||||
"""Test project validation with None value"""
|
||||
config = ArizeConfig(project=None)
|
||||
assert config.project == "default"
|
||||
|
||||
def test_endpoint_validation_empty(self):
|
||||
"""Test endpoint validation with empty value"""
|
||||
config = ArizeConfig(endpoint="")
|
||||
assert config.endpoint == "https://otlp.arize.com"
|
||||
|
||||
def test_endpoint_validation_with_path(self):
|
||||
"""Test endpoint validation normalizes URL by removing path"""
|
||||
config = ArizeConfig(endpoint="https://custom.arize.com/api/v1")
|
||||
assert config.endpoint == "https://custom.arize.com"
|
||||
|
||||
def test_endpoint_validation_invalid_scheme(self):
|
||||
"""Test endpoint validation rejects invalid schemes"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
ArizeConfig(endpoint="ftp://invalid.com")
|
||||
|
||||
def test_endpoint_validation_no_scheme(self):
|
||||
"""Test endpoint validation rejects URLs without scheme"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
ArizeConfig(endpoint="invalid.com")
|
||||
|
||||
|
||||
class TestPhoenixConfig:
|
||||
"""Test cases for PhoenixConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Phoenix configuration"""
|
||||
config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com")
|
||||
assert config.api_key == "test_key"
|
||||
assert config.project == "test_project"
|
||||
assert config.endpoint == "https://custom.phoenix.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = PhoenixConfig()
|
||||
assert config.api_key is None
|
||||
assert config.project is None
|
||||
assert config.endpoint == "https://app.phoenix.arize.com"
|
||||
|
||||
def test_project_validation_empty(self):
|
||||
"""Test project validation with empty value"""
|
||||
config = PhoenixConfig(project="")
|
||||
assert config.project == "default"
|
||||
|
||||
def test_endpoint_validation_with_path(self):
|
||||
"""Test endpoint validation normalizes URL by removing path"""
|
||||
config = PhoenixConfig(endpoint="https://custom.phoenix.com/api/v1")
|
||||
assert config.endpoint == "https://custom.phoenix.com"
|
||||
|
||||
|
||||
class TestLangfuseConfig:
|
||||
"""Test cases for LangfuseConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Langfuse configuration"""
|
||||
config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com")
|
||||
assert config.public_key == "public_key"
|
||||
assert config.secret_key == "secret_key"
|
||||
assert config.host == "https://custom.langfuse.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = LangfuseConfig(public_key="public", secret_key="secret")
|
||||
assert config.host == "https://api.langfuse.com"
|
||||
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
LangfuseConfig()
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangfuseConfig(public_key="public")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangfuseConfig(secret_key="secret")
|
||||
|
||||
def test_host_validation_empty(self):
|
||||
"""Test host validation with empty value"""
|
||||
config = LangfuseConfig(public_key="public", secret_key="secret", host="")
|
||||
assert config.host == "https://api.langfuse.com"
|
||||
|
||||
|
||||
class TestLangSmithConfig:
|
||||
"""Test cases for LangSmithConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid LangSmith configuration"""
|
||||
config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com")
|
||||
assert config.api_key == "test_key"
|
||||
assert config.project == "test_project"
|
||||
assert config.endpoint == "https://custom.smith.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = LangSmithConfig(api_key="key", project="project")
|
||||
assert config.endpoint == "https://api.smith.langchain.com"
|
||||
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
LangSmithConfig()
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangSmithConfig(api_key="key")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
LangSmithConfig(project="project")
|
||||
|
||||
def test_endpoint_validation_https_only(self):
|
||||
"""Test endpoint validation only allows HTTPS"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com")
|
||||
|
||||
|
||||
class TestOpikConfig:
|
||||
"""Test cases for OpikConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Opik configuration"""
|
||||
config = OpikConfig(
|
||||
api_key="test_key",
|
||||
project="test_project",
|
||||
workspace="test_workspace",
|
||||
url="https://custom.comet.com/opik/api/",
|
||||
)
|
||||
assert config.api_key == "test_key"
|
||||
assert config.project == "test_project"
|
||||
assert config.workspace == "test_workspace"
|
||||
assert config.url == "https://custom.comet.com/opik/api/"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = OpikConfig()
|
||||
assert config.api_key is None
|
||||
assert config.project is None
|
||||
assert config.workspace is None
|
||||
assert config.url == "https://www.comet.com/opik/api/"
|
||||
|
||||
def test_project_validation_empty(self):
|
||||
"""Test project validation with empty value"""
|
||||
config = OpikConfig(project="")
|
||||
assert config.project == "Default Project"
|
||||
|
||||
def test_url_validation_empty(self):
|
||||
"""Test URL validation with empty value"""
|
||||
config = OpikConfig(url="")
|
||||
assert config.url == "https://www.comet.com/opik/api/"
|
||||
|
||||
def test_url_validation_missing_suffix(self):
|
||||
"""Test URL validation requires /api/ suffix"""
|
||||
with pytest.raises(ValidationError, match="URL should end with /api/"):
|
||||
OpikConfig(url="https://custom.comet.com/opik/")
|
||||
|
||||
def test_url_validation_invalid_scheme(self):
|
||||
"""Test URL validation rejects invalid schemes"""
|
||||
with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
|
||||
OpikConfig(url="ftp://custom.comet.com/opik/api/")
|
||||
|
||||
|
||||
class TestWeaveConfig:
|
||||
"""Test cases for WeaveConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Weave configuration"""
|
||||
config = WeaveConfig(
|
||||
api_key="test_key",
|
||||
entity="test_entity",
|
||||
project="test_project",
|
||||
endpoint="https://custom.wandb.ai",
|
||||
host="https://custom.host.com",
|
||||
)
|
||||
assert config.api_key == "test_key"
|
||||
assert config.entity == "test_entity"
|
||||
assert config.project == "test_project"
|
||||
assert config.endpoint == "https://custom.wandb.ai"
|
||||
assert config.host == "https://custom.host.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = WeaveConfig(api_key="key", project="project")
|
||||
assert config.entity is None
|
||||
assert config.endpoint == "https://trace.wandb.ai"
|
||||
assert config.host is None
|
||||
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
WeaveConfig()
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
WeaveConfig(api_key="key")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
WeaveConfig(project="project")
|
||||
|
||||
def test_endpoint_validation_https_only(self):
|
||||
"""Test endpoint validation only allows HTTPS"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai")
|
||||
|
||||
def test_host_validation_optional(self):
|
||||
"""Test host validation is optional but validates when provided"""
|
||||
config = WeaveConfig(api_key="key", project="project", host=None)
|
||||
assert config.host is None
|
||||
|
||||
config = WeaveConfig(api_key="key", project="project", host="")
|
||||
assert config.host == ""
|
||||
|
||||
config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com")
|
||||
assert config.host == "https://valid.host.com"
|
||||
|
||||
def test_host_validation_invalid_scheme(self):
|
||||
"""Test host validation rejects invalid schemes when provided"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com")
|
||||
|
||||
|
||||
class TestAliyunConfig:
|
||||
"""Test cases for AliyunConfig"""
|
||||
|
||||
def test_valid_config(self):
|
||||
"""Test valid Aliyun configuration"""
|
||||
config = AliyunConfig(
|
||||
app_name="test_app",
|
||||
license_key="test_license_key",
|
||||
endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com",
|
||||
)
|
||||
assert config.app_name == "test_app"
|
||||
assert config.license_key == "test_license_key"
|
||||
assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com"
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly"""
|
||||
config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
assert config.app_name == "dify_app"
|
||||
|
||||
def test_missing_required_fields(self):
|
||||
"""Test that required fields are enforced"""
|
||||
with pytest.raises(ValidationError):
|
||||
AliyunConfig()
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
AliyunConfig(license_key="test_license")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
def test_app_name_validation_empty(self):
|
||||
"""Test app_name validation with empty value"""
|
||||
config = AliyunConfig(
|
||||
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
|
||||
)
|
||||
assert config.app_name == "dify_app"
|
||||
|
||||
def test_endpoint_validation_empty(self):
|
||||
"""Test endpoint validation with empty value"""
|
||||
config = AliyunConfig(license_key="test_license", endpoint="")
|
||||
assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
|
||||
|
||||
def test_endpoint_validation_with_path(self):
|
||||
"""Test endpoint validation normalizes URL by removing path"""
|
||||
config = AliyunConfig(
|
||||
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
|
||||
)
|
||||
assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
|
||||
|
||||
def test_endpoint_validation_invalid_scheme(self):
|
||||
"""Test endpoint validation rejects invalid schemes"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
def test_endpoint_validation_no_scheme(self):
|
||||
"""Test endpoint validation rejects URLs without scheme"""
|
||||
with pytest.raises(ValidationError, match="URL scheme must be one of"):
|
||||
AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
def test_license_key_required(self):
|
||||
"""Test that license_key is required and cannot be empty"""
|
||||
with pytest.raises(ValidationError):
|
||||
AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
|
||||
class TestConfigIntegration:
|
||||
"""Integration tests for configuration classes"""
|
||||
|
||||
def test_all_configs_can_be_instantiated(self):
|
||||
"""Test that all config classes can be instantiated with valid data"""
|
||||
configs = [
|
||||
ArizeConfig(api_key="key"),
|
||||
PhoenixConfig(api_key="key"),
|
||||
LangfuseConfig(public_key="public", secret_key="secret"),
|
||||
LangSmithConfig(api_key="key", project="project"),
|
||||
OpikConfig(api_key="key"),
|
||||
WeaveConfig(api_key="key", project="project"),
|
||||
AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com"),
|
||||
]
|
||||
|
||||
for config in configs:
|
||||
assert config is not None
|
||||
|
||||
def test_url_normalization_consistency(self):
|
||||
"""Test that URL normalization works consistently across configs"""
|
||||
# Test that paths are removed from endpoints
|
||||
arize_config = ArizeConfig(endpoint="https://arize.com/api/v1/test")
|
||||
phoenix_config = PhoenixConfig(endpoint="https://phoenix.com/api/v2/")
|
||||
aliyun_config = AliyunConfig(
|
||||
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
|
||||
)
|
||||
|
||||
assert arize_config.endpoint == "https://arize.com"
|
||||
assert phoenix_config.endpoint == "https://phoenix.com"
|
||||
assert aliyun_config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
|
||||
|
||||
def test_project_default_values(self):
|
||||
"""Test that project default values are set correctly"""
|
||||
arize_config = ArizeConfig(project="")
|
||||
phoenix_config = PhoenixConfig(project="")
|
||||
opik_config = OpikConfig(project="")
|
||||
aliyun_config = AliyunConfig(
|
||||
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
|
||||
)
|
||||
|
||||
assert arize_config.project == "default"
|
||||
assert phoenix_config.project == "default"
|
||||
assert opik_config.project == "Default Project"
|
||||
assert aliyun_config.app_name == "dify_app"
|
||||
138
api/tests/unit_tests/core/ops/test_utils.py
Normal file
138
api/tests/unit_tests/core/ops/test_utils.py
Normal file
@ -0,0 +1,138 @@
|
||||
import pytest
|
||||
|
||||
from core.ops.utils import validate_project_name, validate_url, validate_url_with_path
|
||||
|
||||
|
||||
class TestValidateUrl:
|
||||
"""Test cases for validate_url function"""
|
||||
|
||||
def test_valid_https_url(self):
|
||||
"""Test valid HTTPS URL"""
|
||||
result = validate_url("https://example.com", "https://default.com")
|
||||
assert result == "https://example.com"
|
||||
|
||||
def test_valid_http_url(self):
|
||||
"""Test valid HTTP URL"""
|
||||
result = validate_url("http://example.com", "https://default.com")
|
||||
assert result == "http://example.com"
|
||||
|
||||
def test_url_with_path_removed(self):
|
||||
"""Test that URL path is removed during normalization"""
|
||||
result = validate_url("https://example.com/api/v1/test", "https://default.com")
|
||||
assert result == "https://example.com"
|
||||
|
||||
def test_url_with_query_removed(self):
|
||||
"""Test that URL query parameters are removed"""
|
||||
result = validate_url("https://example.com?param=value", "https://default.com")
|
||||
assert result == "https://example.com"
|
||||
|
||||
def test_url_with_fragment_removed(self):
|
||||
"""Test that URL fragments are removed"""
|
||||
result = validate_url("https://example.com#section", "https://default.com")
|
||||
assert result == "https://example.com"
|
||||
|
||||
def test_empty_url_returns_default(self):
|
||||
"""Test empty URL returns default"""
|
||||
result = validate_url("", "https://default.com")
|
||||
assert result == "https://default.com"
|
||||
|
||||
def test_none_url_returns_default(self):
|
||||
"""Test None URL returns default"""
|
||||
result = validate_url(None, "https://default.com")
|
||||
assert result == "https://default.com"
|
||||
|
||||
def test_whitespace_url_returns_default(self):
|
||||
"""Test whitespace URL returns default"""
|
||||
result = validate_url(" ", "https://default.com")
|
||||
assert result == "https://default.com"
|
||||
|
||||
def test_invalid_scheme_raises_error(self):
|
||||
"""Test invalid scheme raises ValueError"""
|
||||
with pytest.raises(ValueError, match="URL scheme must be one of"):
|
||||
validate_url("ftp://example.com", "https://default.com")
|
||||
|
||||
def test_no_scheme_raises_error(self):
|
||||
"""Test URL without scheme raises ValueError"""
|
||||
with pytest.raises(ValueError, match="URL scheme must be one of"):
|
||||
validate_url("example.com", "https://default.com")
|
||||
|
||||
def test_custom_allowed_schemes(self):
|
||||
"""Test custom allowed schemes"""
|
||||
result = validate_url("https://example.com", "https://default.com", allowed_schemes=("https",))
|
||||
assert result == "https://example.com"
|
||||
|
||||
with pytest.raises(ValueError, match="URL scheme must be one of"):
|
||||
validate_url("http://example.com", "https://default.com", allowed_schemes=("https",))
|
||||
|
||||
|
||||
class TestValidateUrlWithPath:
|
||||
"""Test cases for validate_url_with_path function"""
|
||||
|
||||
def test_valid_url_with_path(self):
|
||||
"""Test valid URL with path"""
|
||||
result = validate_url_with_path("https://example.com/api/v1", "https://default.com")
|
||||
assert result == "https://example.com/api/v1"
|
||||
|
||||
def test_valid_url_with_required_suffix(self):
|
||||
"""Test valid URL with required suffix"""
|
||||
result = validate_url_with_path("https://example.com/api/", "https://default.com", required_suffix="/api/")
|
||||
assert result == "https://example.com/api/"
|
||||
|
||||
def test_url_without_required_suffix_raises_error(self):
|
||||
"""Test URL without required suffix raises error"""
|
||||
with pytest.raises(ValueError, match="URL should end with /api/"):
|
||||
validate_url_with_path("https://example.com/api", "https://default.com", required_suffix="/api/")
|
||||
|
||||
def test_empty_url_returns_default(self):
|
||||
"""Test empty URL returns default"""
|
||||
result = validate_url_with_path("", "https://default.com")
|
||||
assert result == "https://default.com"
|
||||
|
||||
def test_none_url_returns_default(self):
|
||||
"""Test None URL returns default"""
|
||||
result = validate_url_with_path(None, "https://default.com")
|
||||
assert result == "https://default.com"
|
||||
|
||||
def test_invalid_scheme_raises_error(self):
|
||||
"""Test invalid scheme raises ValueError"""
|
||||
with pytest.raises(ValueError, match="URL must start with https:// or http://"):
|
||||
validate_url_with_path("ftp://example.com", "https://default.com")
|
||||
|
||||
def test_no_scheme_raises_error(self):
|
||||
"""Test URL without scheme raises ValueError"""
|
||||
with pytest.raises(ValueError, match="URL must start with https:// or http://"):
|
||||
validate_url_with_path("example.com", "https://default.com")
|
||||
|
||||
|
||||
class TestValidateProjectName:
|
||||
"""Test cases for validate_project_name function"""
|
||||
|
||||
def test_valid_project_name(self):
|
||||
"""Test valid project name"""
|
||||
result = validate_project_name("my-project", "default")
|
||||
assert result == "my-project"
|
||||
|
||||
def test_empty_project_name_returns_default(self):
|
||||
"""Test empty project name returns default"""
|
||||
result = validate_project_name("", "default")
|
||||
assert result == "default"
|
||||
|
||||
def test_none_project_name_returns_default(self):
|
||||
"""Test None project name returns default"""
|
||||
result = validate_project_name(None, "default")
|
||||
assert result == "default"
|
||||
|
||||
def test_whitespace_project_name_returns_default(self):
|
||||
"""Test whitespace project name returns default"""
|
||||
result = validate_project_name(" ", "default")
|
||||
assert result == "default"
|
||||
|
||||
def test_project_name_with_whitespace_trimmed(self):
|
||||
"""Test project name with whitespace is trimmed"""
|
||||
result = validate_project_name(" my-project ", "default")
|
||||
assert result == "my-project"
|
||||
|
||||
def test_custom_default_name(self):
|
||||
"""Test custom default name"""
|
||||
result = validate_project_name("", "Custom Default")
|
||||
assert result == "Custom Default"
|
||||
@ -1,3 +1,4 @@
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@ -19,6 +20,7 @@ from core.workflow.graph_engine.entities.event import (
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
@ -172,6 +174,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
|
||||
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id="111",
|
||||
app_id="222",
|
||||
@ -183,7 +186,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
variable_pool=variable_pool,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=1200,
|
||||
)
|
||||
@ -299,6 +302,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id="111",
|
||||
app_id="222",
|
||||
@ -310,7 +314,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
variable_pool=variable_pool,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=1200,
|
||||
)
|
||||
@ -479,6 +483,7 @@ def test_run_branch(mock_close, mock_remove):
|
||||
user_inputs={"uid": "takato"},
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id="111",
|
||||
app_id="222",
|
||||
@ -490,7 +495,7 @@ def test_run_branch(mock_close, mock_remove):
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
variable_pool=variable_pool,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=1200,
|
||||
)
|
||||
@ -813,6 +818,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
|
||||
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id="111",
|
||||
app_id="222",
|
||||
@ -824,7 +830,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
variable_pool=variable_pool,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=1200,
|
||||
)
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
@ -11,6 +13,7 @@ from core.workflow.graph_engine.entities.event import (
|
||||
NodeRunStreamChunkEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
@ -163,15 +166,16 @@ class ContinueOnErrorTestHelper:
|
||||
def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None):
|
||||
"""Helper method to create a graph engine instance for testing"""
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
variable_pool = {
|
||||
"system_variables": {
|
||||
variable_pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariableKey.QUERY: "clear",
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: "abababa",
|
||||
SystemVariableKey.USER_ID: "aaa",
|
||||
},
|
||||
"user_inputs": user_inputs or {"uid": "takato"},
|
||||
}
|
||||
user_inputs=user_inputs or {"uid": "takato"},
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
return GraphEngine(
|
||||
tenant_id="111",
|
||||
@ -184,7 +188,7 @@ class ContinueOnErrorTestHelper:
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
variable_pool=variable_pool,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=1200,
|
||||
)
|
||||
|
||||
74
api/tests/unit_tests/libs/test_password.py
Normal file
74
api/tests/unit_tests/libs/test_password.py
Normal file
@ -0,0 +1,74 @@
|
||||
import base64
|
||||
import binascii
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.password import compare_password, hash_password, valid_password
|
||||
|
||||
|
||||
class TestValidPassword:
|
||||
"""Test password format validation"""
|
||||
|
||||
def test_should_accept_valid_passwords(self):
|
||||
"""Test accepting valid password formats"""
|
||||
assert valid_password("password123") == "password123"
|
||||
assert valid_password("test1234") == "test1234"
|
||||
assert valid_password("Test123456") == "Test123456"
|
||||
|
||||
def test_should_reject_invalid_passwords(self):
|
||||
"""Test rejecting invalid password formats"""
|
||||
# Too short
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
valid_password("abc123")
|
||||
assert "Password must contain letters and numbers" in str(exc_info.value)
|
||||
|
||||
# No numbers
|
||||
with pytest.raises(ValueError):
|
||||
valid_password("abcdefgh")
|
||||
|
||||
# No letters
|
||||
with pytest.raises(ValueError):
|
||||
valid_password("12345678")
|
||||
|
||||
# Empty
|
||||
with pytest.raises(ValueError):
|
||||
valid_password("")
|
||||
|
||||
|
||||
class TestPasswordHashing:
|
||||
"""Test password hashing and comparison"""
|
||||
|
||||
def setup_method(self):
|
||||
"""Setup test data"""
|
||||
self.password = "test123password"
|
||||
self.salt = os.urandom(16)
|
||||
self.salt_base64 = base64.b64encode(self.salt).decode()
|
||||
|
||||
password_hash = hash_password(self.password, self.salt)
|
||||
self.password_hash_base64 = base64.b64encode(password_hash).decode()
|
||||
|
||||
def test_should_verify_correct_password(self):
|
||||
"""Test correct password verification"""
|
||||
result = compare_password(self.password, self.password_hash_base64, self.salt_base64)
|
||||
assert result is True
|
||||
|
||||
def test_should_reject_wrong_password(self):
|
||||
"""Test rejection of incorrect passwords"""
|
||||
result = compare_password("wrongpassword", self.password_hash_base64, self.salt_base64)
|
||||
assert result is False
|
||||
|
||||
def test_should_handle_invalid_base64(self):
|
||||
"""Test handling of invalid base64 data"""
|
||||
# Invalid base64 hash
|
||||
with pytest.raises(binascii.Error):
|
||||
compare_password(self.password, "invalid_base64!", self.salt_base64)
|
||||
|
||||
# Invalid base64 salt
|
||||
with pytest.raises(binascii.Error):
|
||||
compare_password(self.password, self.password_hash_base64, "invalid_base64!")
|
||||
|
||||
def test_should_be_case_sensitive(self):
|
||||
"""Test password case sensitivity"""
|
||||
result = compare_password(self.password.upper(), self.password_hash_base64, self.salt_base64)
|
||||
assert result is False
|
||||
@ -0,0 +1,465 @@
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
LLMResultWithStructuredOutput,
|
||||
LLMUsage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
|
||||
|
||||
def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage:
|
||||
"""Create a mock LLMUsage with all required fields"""
|
||||
return LLMUsage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("1"),
|
||||
prompt_price=Decimal(str(prompt_tokens)) * Decimal("0.001"),
|
||||
completion_tokens=completion_tokens,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("1"),
|
||||
completion_price=Decimal(str(completion_tokens)) * Decimal("0.002"),
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
total_price=Decimal(str(prompt_tokens)) * Decimal("0.001") + Decimal(str(completion_tokens)) * Decimal("0.002"),
|
||||
currency="USD",
|
||||
latency=1.5,
|
||||
)
|
||||
|
||||
|
||||
def get_model_entity(provider: str, model_name: str, support_structure_output: bool = False) -> AIModelEntity:
|
||||
"""Create a mock AIModelEntity for testing"""
|
||||
model_schema = MagicMock()
|
||||
model_schema.model = model_name
|
||||
model_schema.provider = provider
|
||||
model_schema.model_type = ModelType.LLM
|
||||
model_schema.model_provider = provider
|
||||
model_schema.model_name = model_name
|
||||
model_schema.support_structure_output = support_structure_output
|
||||
model_schema.parameter_rules = []
|
||||
|
||||
return model_schema
|
||||
|
||||
|
||||
def get_model_instance() -> MagicMock:
|
||||
"""Create a mock ModelInstance for testing"""
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.provider = "openai"
|
||||
mock_instance.credentials = {}
|
||||
return mock_instance
|
||||
|
||||
|
||||
def test_structured_output_parser():
|
||||
"""Test cases for invoke_llm_with_structured_output function"""
|
||||
|
||||
testcases = [
|
||||
# Test case 1: Model with native structured output support, non-streaming
|
||||
{
|
||||
"name": "native_structured_output_non_streaming",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
"expected_llm_response": LLMResult(
|
||||
model="gpt-4o",
|
||||
message=AssistantPromptMessage(content='{"name": "test"}'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 2: Model with native structured output support, streaming
|
||||
{
|
||||
"name": "native_structured_output_streaming",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": True,
|
||||
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
"expected_llm_response": [
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content='{"name":'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
|
||||
),
|
||||
),
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=' "test"}'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
|
||||
),
|
||||
),
|
||||
],
|
||||
"expected_result_type": "generator",
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 3: Model without native structured output support, non-streaming
|
||||
{
|
||||
"name": "prompt_based_structured_output_non_streaming",
|
||||
"provider": "anthropic",
|
||||
"model_name": "claude-3-sonnet",
|
||||
"support_structure_output": False,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
|
||||
"expected_llm_response": LLMResult(
|
||||
model="claude-3-sonnet",
|
||||
message=AssistantPromptMessage(content='{"answer": "test response"}'),
|
||||
usage=create_mock_usage(prompt_tokens=15, completion_tokens=8),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 4: Model without native structured output support, streaming
|
||||
{
|
||||
"name": "prompt_based_structured_output_streaming",
|
||||
"provider": "anthropic",
|
||||
"model_name": "claude-3-sonnet",
|
||||
"support_structure_output": False,
|
||||
"stream": True,
|
||||
"json_schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
|
||||
"expected_llm_response": [
|
||||
LLMResultChunk(
|
||||
model="claude-3-sonnet",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content='{"answer": "test'),
|
||||
usage=create_mock_usage(prompt_tokens=15, completion_tokens=3),
|
||||
),
|
||||
),
|
||||
LLMResultChunk(
|
||||
model="claude-3-sonnet",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=' response"}'),
|
||||
usage=create_mock_usage(prompt_tokens=15, completion_tokens=5),
|
||||
),
|
||||
),
|
||||
],
|
||||
"expected_result_type": "generator",
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 5: Streaming with list content
|
||||
{
|
||||
"name": "streaming_with_list_content",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": True,
|
||||
"json_schema": {"type": "object", "properties": {"data": {"type": "string"}}},
|
||||
"expected_llm_response": [
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data='{"data":'),
|
||||
]
|
||||
),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
|
||||
),
|
||||
),
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data=' "value"}'),
|
||||
]
|
||||
),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
|
||||
),
|
||||
),
|
||||
],
|
||||
"expected_result_type": "generator",
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 6: Error case - non-string LLM response content (non-streaming)
|
||||
{
|
||||
"name": "error_non_string_content_non_streaming",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
"expected_llm_response": LLMResult(
|
||||
model="gpt-4o",
|
||||
message=AssistantPromptMessage(content=None), # Non-string content
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
"expected_result_type": None,
|
||||
"should_raise": True,
|
||||
"expected_error": OutputParserError,
|
||||
},
|
||||
# Test case 7: JSON repair scenario
|
||||
{
|
||||
"name": "json_repair_scenario",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
"expected_llm_response": LLMResult(
|
||||
model="gpt-4o",
|
||||
message=AssistantPromptMessage(content='{"name": "test"'), # Invalid JSON - missing closing brace
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 8: Model with parameter rules for response format
|
||||
{
|
||||
"name": "model_with_parameter_rules",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"result": {"type": "string"}}},
|
||||
"parameter_rules": [
|
||||
MagicMock(name="response_format", options=["json_schema"], required=False),
|
||||
],
|
||||
"expected_llm_response": LLMResult(
|
||||
model="gpt-4o",
|
||||
message=AssistantPromptMessage(content='{"result": "success"}'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 9: Model without native support but with JSON response format rules
|
||||
{
|
||||
"name": "non_native_with_json_rules",
|
||||
"provider": "anthropic",
|
||||
"model_name": "claude-3-sonnet",
|
||||
"support_structure_output": False,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"output": {"type": "string"}}},
|
||||
"parameter_rules": [
|
||||
MagicMock(name="response_format", options=["JSON"], required=False),
|
||||
],
|
||||
"expected_llm_response": LLMResult(
|
||||
model="claude-3-sonnet",
|
||||
message=AssistantPromptMessage(content='{"output": "result"}'),
|
||||
usage=create_mock_usage(prompt_tokens=15, completion_tokens=8),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
]
|
||||
|
||||
for case in testcases:
|
||||
print(f"Running test case: {case['name']}")
|
||||
|
||||
# Setup model entity
|
||||
model_schema = get_model_entity(case["provider"], case["model_name"], case["support_structure_output"])
|
||||
|
||||
# Add parameter rules if specified
|
||||
if "parameter_rules" in case:
|
||||
model_schema.parameter_rules = case["parameter_rules"]
|
||||
|
||||
# Setup model instance
|
||||
model_instance = get_model_instance()
|
||||
model_instance.invoke_llm.return_value = case["expected_llm_response"]
|
||||
|
||||
# Setup prompt messages
|
||||
prompt_messages = [
|
||||
SystemPromptMessage(content="You are a helpful assistant."),
|
||||
UserPromptMessage(content="Generate a response according to the schema."),
|
||||
]
|
||||
|
||||
if case["should_raise"]:
|
||||
# Test error cases
|
||||
with pytest.raises(case["expected_error"]): # noqa: PT012
|
||||
if case["stream"]:
|
||||
result_generator = invoke_llm_with_structured_output(
|
||||
provider=case["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=case["json_schema"],
|
||||
stream=case["stream"],
|
||||
)
|
||||
# Consume the generator to trigger the error
|
||||
list(result_generator)
|
||||
else:
|
||||
invoke_llm_with_structured_output(
|
||||
provider=case["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=case["json_schema"],
|
||||
stream=case["stream"],
|
||||
)
|
||||
else:
|
||||
# Test successful cases
|
||||
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
|
||||
# Configure json_repair mock for cases that need it
|
||||
if case["name"] == "json_repair_scenario":
|
||||
mock_json_repair.return_value = {"name": "test"}
|
||||
|
||||
result = invoke_llm_with_structured_output(
|
||||
provider=case["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=case["json_schema"],
|
||||
stream=case["stream"],
|
||||
model_parameters={"temperature": 0.7, "max_tokens": 100},
|
||||
user="test_user",
|
||||
)
|
||||
|
||||
if case["expected_result_type"] == "generator":
|
||||
# Test streaming results
|
||||
assert hasattr(result, "__iter__")
|
||||
chunks = list(result)
|
||||
assert len(chunks) > 0
|
||||
|
||||
# Verify all chunks are LLMResultChunkWithStructuredOutput
|
||||
for chunk in chunks[:-1]: # All except last
|
||||
assert isinstance(chunk, LLMResultChunkWithStructuredOutput)
|
||||
assert chunk.model == case["model_name"]
|
||||
|
||||
# Last chunk should have structured output
|
||||
last_chunk = chunks[-1]
|
||||
assert isinstance(last_chunk, LLMResultChunkWithStructuredOutput)
|
||||
assert last_chunk.structured_output is not None
|
||||
assert isinstance(last_chunk.structured_output, dict)
|
||||
else:
|
||||
# Test non-streaming results
|
||||
assert isinstance(result, case["expected_result_type"])
|
||||
assert result.model == case["model_name"]
|
||||
assert result.structured_output is not None
|
||||
assert isinstance(result.structured_output, dict)
|
||||
|
||||
# Verify model_instance.invoke_llm was called with correct parameters
|
||||
model_instance.invoke_llm.assert_called_once()
|
||||
call_args = model_instance.invoke_llm.call_args
|
||||
|
||||
assert call_args.kwargs["stream"] == case["stream"]
|
||||
assert call_args.kwargs["user"] == "test_user"
|
||||
assert "temperature" in call_args.kwargs["model_parameters"]
|
||||
assert "max_tokens" in call_args.kwargs["model_parameters"]
|
||||
|
||||
|
||||
def test_parse_structured_output_edge_cases():
|
||||
"""Test edge cases for structured output parsing"""
|
||||
|
||||
# Test case with list that contains dict (reasoning model scenario)
|
||||
testcase_list_with_dict = {
|
||||
"name": "list_with_dict_parsing",
|
||||
"provider": "deepseek",
|
||||
"model_name": "deepseek-r1",
|
||||
"support_structure_output": False,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"thought": {"type": "string"}}},
|
||||
"expected_llm_response": LLMResult(
|
||||
model="deepseek-r1",
|
||||
message=AssistantPromptMessage(content='[{"thought": "reasoning process"}, "other content"]'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
}
|
||||
|
||||
# Setup for list parsing test
|
||||
model_schema = get_model_entity(
|
||||
testcase_list_with_dict["provider"],
|
||||
testcase_list_with_dict["model_name"],
|
||||
testcase_list_with_dict["support_structure_output"],
|
||||
)
|
||||
|
||||
model_instance = get_model_instance()
|
||||
model_instance.invoke_llm.return_value = testcase_list_with_dict["expected_llm_response"]
|
||||
|
||||
prompt_messages = [UserPromptMessage(content="Test reasoning")]
|
||||
|
||||
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
|
||||
# Mock json_repair to return a list with dict
|
||||
mock_json_repair.return_value = [{"thought": "reasoning process"}, "other content"]
|
||||
|
||||
result = invoke_llm_with_structured_output(
|
||||
provider=testcase_list_with_dict["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=testcase_list_with_dict["json_schema"],
|
||||
stream=testcase_list_with_dict["stream"],
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResultWithStructuredOutput)
|
||||
assert result.structured_output == {"thought": "reasoning process"}
|
||||
|
||||
|
||||
def test_model_specific_schema_preparation():
|
||||
"""Test schema preparation for different model types"""
|
||||
|
||||
# Test Gemini model
|
||||
gemini_case = {
|
||||
"provider": "google",
|
||||
"model_name": "gemini-pro",
|
||||
"support_structure_output": True,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"result": {"type": "boolean"}}, "additionalProperties": False},
|
||||
}
|
||||
|
||||
model_schema = get_model_entity(
|
||||
gemini_case["provider"], gemini_case["model_name"], gemini_case["support_structure_output"]
|
||||
)
|
||||
|
||||
model_instance = get_model_instance()
|
||||
model_instance.invoke_llm.return_value = LLMResult(
|
||||
model="gemini-pro",
|
||||
message=AssistantPromptMessage(content='{"result": "true"}'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
)
|
||||
|
||||
prompt_messages = [UserPromptMessage(content="Test")]
|
||||
|
||||
result = invoke_llm_with_structured_output(
|
||||
provider=gemini_case["provider"],
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=gemini_case["json_schema"],
|
||||
stream=gemini_case["stream"],
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResultWithStructuredOutput)
|
||||
|
||||
# Verify model_instance.invoke_llm was called and check the schema preparation
|
||||
model_instance.invoke_llm.assert_called_once()
|
||||
call_args = model_instance.invoke_llm.call_args
|
||||
|
||||
# For Gemini, the schema should not have additionalProperties and boolean should be converted to string
|
||||
assert "json_schema" in call_args.kwargs["model_parameters"]
|
||||
650
api/uv.lock
generated
650
api/uv.lock
generated
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user