Mirror of https://github.com/langgenius/dify.git (synced 2026-05-03 08:58:09 +08:00)

Compare commits: dependabot ... feat/new-b (22 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 662f6e3e68 |  |
|  | ae01a5d137 |  |
|  | ad6670ebcc |  |
|  | 8ca0917044 |  |
|  | b3870524d4 |  |
|  | c543188434 |  |
|  | f319a9e42f |  |
|  | 58241a89a5 |  |
|  | 422bf3506e |  |
|  | 6e745f9e9b |  |
|  | 4e50d55339 |  |
|  | b95cdabe26 |  |
|  | daa47c25bb |  |
|  | f1bcd6d715 |  |
|  | 8643ff43f5 |  |
|  | c5f30a47f0 |  |
|  | 37d438fa19 |  |
|  | 9503803997 |  |
|  | d6476f5434 |  |
|  | 80b4633e8f |  |
|  | 3888969af3 |  |
|  | 658ac15589 |  |
@@ -20,13 +20,10 @@ class TenantUserPayload(BaseModel):
 def get_user(tenant_id: str, user_id: str | None) -> EndUser:
     """
-    Get current user.
+    Get current user
 
     NOTE: user_id is not trusted, it could be maliciously set to any value.
-    As a result, it could only be considered as an end user id. Even when a
-    concrete end-user ID is supplied, lookups must stay tenant-scoped so one
-    tenant cannot bind another tenant's user record into the plugin request
-    context.
+    As a result, it could only be considered as an end user id.
     """
     if not user_id:
         user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
@@ -45,14 +42,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
             .limit(1)
         )
     else:
-        user_model = session.scalar(
-            select(EndUser)
-            .where(
-                EndUser.id == user_id,
-                EndUser.tenant_id == tenant_id,
-            )
-            .limit(1)
-        )
+        user_model = session.get(EndUser, user_id)
 
     if not user_model:
         user_model = EndUser(
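The docstring's point about tenant scoping, as a hedged sketch: a primary-key `Session.get` matches a row from any tenant, while the tenant-scoped `select` (the form being removed above) only matches rows owned by the caller's tenant. Names follow the diff; engine setup and the `EndUser` import are assumed from the surrounding code:

```python
from sqlalchemy import select
from sqlalchemy.orm import Session

from models.model import EndUser  # as imported elsewhere in this diff


def lookup_end_user(session: Session, tenant_id: str, user_id: str) -> EndUser | None:
    # Tenant-scoped lookup: a user_id belonging to another tenant resolves
    # to None here, whereas session.get(EndUser, user_id) would return it.
    return session.scalar(
        select(EndUser)
        .where(EndUser.id == user_id, EndUser.tenant_id == tenant_id)
        .limit(1)
    )
```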
@@ -299,9 +299,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
         # update prompt tool
         for prompt_tool in prompt_messages_tools:
-            tool_instance = tool_instances.get(prompt_tool.name)
-            if tool_instance:
-                self.update_prompt_message_tool(tool_instance, prompt_tool)
+            self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
 
         iteration_step += 1
@@ -318,28 +318,34 @@ class ProviderConfiguration(BaseModel):
             else [],
         )
 
-    def validate_provider_credentials(self, credentials: dict[str, Any], credential_id: str = ""):
+    def validate_provider_credentials(
+        self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
+    ):
         """
         Validate custom credentials.
         :param credentials: provider credentials
         :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
+        :param session: optional database session
         :return:
         """
-        provider_credential_secret_variables = self.extract_secret_variables(
-            self.provider.provider_credential_schema.credential_form_schemas
-            if self.provider.provider_credential_schema
-            else []
-        )
-
-        if credential_id:
-            with Session(db.engine) as session:
+        def _validate(s: Session):
+            # Get provider credential secret variables
+            provider_credential_secret_variables = self.extract_secret_variables(
+                self.provider.provider_credential_schema.credential_form_schemas
+                if self.provider.provider_credential_schema
+                else []
+            )
+
+            if credential_id:
                 try:
                     stmt = select(ProviderCredential).where(
                         ProviderCredential.tenant_id == self.tenant_id,
                         ProviderCredential.provider_name.in_(self._get_provider_names()),
                         ProviderCredential.id == credential_id,
                     )
-                    credential_record = session.execute(stmt).scalar_one_or_none()
+                    credential_record = s.execute(stmt).scalar_one_or_none()
                     # fix origin data
                     if credential_record and credential_record.encrypted_config:
                         if not credential_record.encrypted_config.startswith("{"):
                             original_credentials = {"openai_api_key": credential_record.encrypted_config}
@@ -350,23 +356,31 @@ class ProviderConfiguration(BaseModel):
                 except JSONDecodeError:
                     original_credentials = {}
 
-            for key, value in credentials.items():
-                if key in provider_credential_secret_variables:
-                    # if send [__HIDDEN__] in secret input, it will be same as original value
-                    if value == HIDDEN_VALUE and key in original_credentials:
-                        credentials[key] = encrypter.decrypt_token(
-                            tenant_id=self.tenant_id, token=original_credentials[key]
-                        )
-
-        model_provider_factory = self.get_model_provider_factory()
-        validated_credentials = model_provider_factory.provider_credentials_validate(
-            provider=self.provider.provider, credentials=credentials
-        )
-
-        for key, value in validated_credentials.items():
-            if key in provider_credential_secret_variables:
-                validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
-
-        return validated_credentials
+                # encrypt credentials
+                for key, value in credentials.items():
+                    if key in provider_credential_secret_variables:
+                        # if send [__HIDDEN__] in secret input, it will be same as original value
+                        if value == HIDDEN_VALUE and key in original_credentials:
+                            credentials[key] = encrypter.decrypt_token(
+                                tenant_id=self.tenant_id, token=original_credentials[key]
+                            )
+
+            model_provider_factory = self.get_model_provider_factory()
+            validated_credentials = model_provider_factory.provider_credentials_validate(
+                provider=self.provider.provider, credentials=credentials
+            )
+
+            for key, value in validated_credentials.items():
+                if key in provider_credential_secret_variables:
+                    validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
+
+            return validated_credentials
+
+        if session:
+            return _validate(session)
+        else:
+            with Session(db.engine) as new_session:
+                return _validate(new_session)
 
     def _generate_provider_credential_name(self, session) -> str:
         """
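The refactor above threads an optional `Session` through validation instead of always opening a new one. A minimal sketch of that pattern in isolation, where `work` stands in for the `_validate` closure and `db.engine` is assumed from the code above:

```python
from collections.abc import Callable
from typing import TypeVar

from sqlalchemy.orm import Session

T = TypeVar("T")


def run_with_session(work: Callable[[Session], T], session: Session | None = None) -> T:
    # Reuse the caller's session (and its transaction) when provided,
    # otherwise open a short-lived one against the shared engine.
    if session is not None:
        return work(session)
    with Session(db.engine) as new_session:  # db.engine as in the diff above
        return work(new_session)
```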
@@ -443,16 +457,14 @@ class ProviderConfiguration(BaseModel):
         :param credential_name: credential name
         :return:
         """
-        with Session(db.engine) as pre_session:
+        with Session(db.engine) as session:
             if credential_name:
-                if self._check_provider_credential_name_exists(credential_name=credential_name, session=pre_session):
+                if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
                     raise ValueError(f"Credential with name '{credential_name}' already exists.")
             else:
-                credential_name = self._generate_provider_credential_name(pre_session)
-
-        credentials = self.validate_provider_credentials(credentials=credentials)
+                credential_name = self._generate_provider_credential_name(session)
 
-        with Session(db.engine) as session:
+            credentials = self.validate_provider_credentials(credentials=credentials, session=session)
             provider_record = self._get_provider_record(session)
             try:
                 new_record = ProviderCredential(
@@ -465,6 +477,7 @@ class ProviderConfiguration(BaseModel):
                 session.flush()
 
                 if not provider_record:
+                    # If provider record does not exist, create it
                     provider_record = Provider(
                         tenant_id=self.tenant_id,
                         provider_name=self.provider.provider,
@@ -517,15 +530,15 @@ class ProviderConfiguration(BaseModel):
         :param credential_name: credential name
         :return:
         """
-        with Session(db.engine) as pre_session:
+        with Session(db.engine) as session:
             if credential_name and self._check_provider_credential_name_exists(
-                credential_name=credential_name, session=pre_session, exclude_id=credential_id
+                credential_name=credential_name, session=session, exclude_id=credential_id
             ):
                 raise ValueError(f"Credential with name '{credential_name}' already exists.")
 
-        credentials = self.validate_provider_credentials(credentials=credentials, credential_id=credential_id)
-
-        with Session(db.engine) as session:
+            credentials = self.validate_provider_credentials(
+                credentials=credentials, credential_id=credential_id, session=session
+            )
             provider_record = self._get_provider_record(session)
             stmt = select(ProviderCredential).where(
                 ProviderCredential.id == credential_id,
@@ -533,10 +546,12 @@ class ProviderConfiguration(BaseModel):
                 ProviderCredential.provider_name.in_(self._get_provider_names()),
             )
 
+            # Get the credential record to update
             credential_record = session.execute(stmt).scalar_one_or_none()
             if not credential_record:
                 raise ValueError("Credential record not found.")
+            try:
                 # Update credential
                 credential_record.encrypted_config = json.dumps(credentials)
                 credential_record.updated_at = naive_utc_now()
                 if credential_name:
@@ -864,6 +879,7 @@ class ProviderConfiguration(BaseModel):
         model: str,
         credentials: dict[str, Any],
         credential_id: str = "",
+        session: Session | None = None,
     ):
         """
         Validate custom model credentials.
@@ -874,14 +890,16 @@ class ProviderConfiguration(BaseModel):
         :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
         :return:
         """
-        provider_credential_secret_variables = self.extract_secret_variables(
-            self.provider.model_credential_schema.credential_form_schemas
-            if self.provider.model_credential_schema
-            else []
-        )
-
-        if credential_id:
-            with Session(db.engine) as session:
+        def _validate(s: Session):
+            # Get provider credential secret variables
+            provider_credential_secret_variables = self.extract_secret_variables(
+                self.provider.model_credential_schema.credential_form_schemas
+                if self.provider.model_credential_schema
+                else []
+            )
+
+            if credential_id:
                 try:
                     stmt = select(ProviderModelCredential).where(
                         ProviderModelCredential.id == credential_id,
@@ -890,7 +908,7 @@ class ProviderConfiguration(BaseModel):
                         ProviderModelCredential.model_name == model,
                         ProviderModelCredential.model_type == model_type,
                     )
-                    credential_record = session.execute(stmt).scalar_one_or_none()
+                    credential_record = s.execute(stmt).scalar_one_or_none()
                     original_credentials = (
                         json.loads(credential_record.encrypted_config)
                         if credential_record and credential_record.encrypted_config
@@ -899,23 +917,31 @@ class ProviderConfiguration(BaseModel):
                 except JSONDecodeError:
                     original_credentials = {}
 
-            for key, value in credentials.items():
-                if key in provider_credential_secret_variables:
-                    # if send [__HIDDEN__] in secret input, it will be same as original value
-                    if value == HIDDEN_VALUE and key in original_credentials:
-                        credentials[key] = encrypter.decrypt_token(
-                            tenant_id=self.tenant_id, token=original_credentials[key]
-                        )
-
-        model_provider_factory = self.get_model_provider_factory()
-        validated_credentials = model_provider_factory.model_credentials_validate(
-            provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
-        )
-
-        for key, value in validated_credentials.items():
-            if key in provider_credential_secret_variables:
-                validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
-
-        return validated_credentials
+                # decrypt credentials
+                for key, value in credentials.items():
+                    if key in provider_credential_secret_variables:
+                        # if send [__HIDDEN__] in secret input, it will be same as original value
+                        if value == HIDDEN_VALUE and key in original_credentials:
+                            credentials[key] = encrypter.decrypt_token(
+                                tenant_id=self.tenant_id, token=original_credentials[key]
+                            )
+
+            model_provider_factory = self.get_model_provider_factory()
+            validated_credentials = model_provider_factory.model_credentials_validate(
+                provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
+            )
+
+            for key, value in validated_credentials.items():
+                if key in provider_credential_secret_variables:
+                    validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
+
+            return validated_credentials
+
+        if session:
+            return _validate(session)
+        else:
+            with Session(db.engine) as new_session:
+                return _validate(new_session)
 
     def create_custom_model_credential(
         self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
@@ -928,22 +954,20 @@ class ProviderConfiguration(BaseModel):
         :param credentials: model credentials dict
         :return:
         """
-        with Session(db.engine) as pre_session:
+        with Session(db.engine) as session:
             if credential_name:
                 if self._check_custom_model_credential_name_exists(
-                    model=model, model_type=model_type, credential_name=credential_name, session=pre_session
+                    model=model, model_type=model_type, credential_name=credential_name, session=session
                 ):
                     raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
             else:
                 credential_name = self._generate_custom_model_credential_name(
-                    model=model, model_type=model_type, session=pre_session
+                    model=model, model_type=model_type, session=session
                 )
 
-        credentials = self.validate_custom_model_credentials(
-            model_type=model_type, model=model, credentials=credentials
-        )
-
-        with Session(db.engine) as session:
+            # validate custom model config
+            credentials = self.validate_custom_model_credentials(
+                model_type=model_type, model=model, credentials=credentials, session=session
+            )
             provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
 
             try:
@@ -958,6 +982,7 @@ class ProviderConfiguration(BaseModel):
                 session.add(credential)
                 session.flush()
 
+                # save provider model
                 if not provider_model_record:
                     provider_model_record = ProviderModel(
                         tenant_id=self.tenant_id,
@@ -999,24 +1024,23 @@ class ProviderConfiguration(BaseModel):
         :param credential_id: credential id
         :return:
        """
-        with Session(db.engine) as pre_session:
+        with Session(db.engine) as session:
             if credential_name and self._check_custom_model_credential_name_exists(
                 model=model,
                 model_type=model_type,
                 credential_name=credential_name,
-                session=pre_session,
+                session=session,
                 exclude_id=credential_id,
             ):
                 raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
 
-        credentials = self.validate_custom_model_credentials(
-            model_type=model_type,
-            model=model,
-            credentials=credentials,
-            credential_id=credential_id,
-        )
-
-        with Session(db.engine) as session:
+            # validate custom model config
+            credentials = self.validate_custom_model_credentials(
+                model_type=model_type,
+                model=model,
+                credentials=credentials,
+                credential_id=credential_id,
+                session=session,
+            )
             provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
 
             stmt = select(ProviderModelCredential).where(
@@ -1031,6 +1055,7 @@ class ProviderConfiguration(BaseModel):
                 raise ValueError("Credential record not found.")
 
+            try:
                 # Update credential
                 credential_record.encrypted_config = json.dumps(credentials)
                 credential_record.updated_at = naive_utc_now()
                 if credential_name:
@@ -4,20 +4,7 @@ from collections.abc import Sequence
 from opentelemetry.trace import SpanKind
 from sqlalchemy.orm import sessionmaker
 
-from core.ops.base_trace_instance import BaseTraceInstance
-from core.ops.entities.trace_entity import (
-    BaseTraceInfo,
-    DatasetRetrievalTraceInfo,
-    GenerateNameTraceInfo,
-    MessageTraceInfo,
-    ModerationTraceInfo,
-    SuggestedQuestionTraceInfo,
-    ToolTraceInfo,
-    WorkflowTraceInfo,
-)
-from core.repositories import DifyCoreRepositoryFactory
-from dify_trace_aliyun.config import AliyunConfig
-from dify_trace_aliyun.data_exporter.traceclient import (
+from core.ops.aliyun_trace.data_exporter.traceclient import (
     TraceClient,
     build_endpoint,
     convert_datetime_to_nanoseconds,
@@ -25,8 +12,8 @@ from dify_trace_aliyun.data_exporter.traceclient import (
     convert_to_trace_id,
     generate_span_id,
 )
-from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
-from dify_trace_aliyun.entities.semconv import (
+from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
+from core.ops.aliyun_trace.entities.semconv import (
     DIFY_APP_ID,
     GEN_AI_COMPLETION,
     GEN_AI_INPUT_MESSAGE,
@@ -45,7 +32,7 @@ from dify_trace_aliyun.entities.semconv import (
     TOOL_PARAMETERS,
     GenAISpanKind,
 )
-from dify_trace_aliyun.utils import (
+from core.ops.aliyun_trace.utils import (
     create_common_span_attributes,
     create_links_from_trace_id,
     create_status_from_error,
@@ -57,6 +44,19 @@ from dify_trace_aliyun.utils import (
     get_workflow_node_status,
     serialize_json_data,
 )
+from core.ops.base_trace_instance import BaseTraceInstance
+from core.ops.entities.config_entity import AliyunConfig
+from core.ops.entities.trace_entity import (
+    BaseTraceInfo,
+    DatasetRetrievalTraceInfo,
+    GenerateNameTraceInfo,
+    MessageTraceInfo,
+    ModerationTraceInfo,
+    SuggestedQuestionTraceInfo,
+    ToolTraceInfo,
+    WorkflowTraceInfo,
+)
+from core.repositories import DifyCoreRepositoryFactory
 from extensions.ext_database import db
 from graphon.entities import WorkflowNodeExecution
 from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
@@ -26,8 +26,8 @@ from opentelemetry.semconv.attributes import service_attributes
 from opentelemetry.trace import Link, SpanContext, TraceFlags
 
 from configs import dify_config
-from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData
-from dify_trace_aliyun.entities.semconv import ACS_ARMS_SERVICE_FEATURE
+from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
+from core.ops.aliyun_trace.entities.semconv import ACS_ARMS_SERVICE_FEATURE
 
 INVALID_SPAN_ID: Final[int] = 0x0000000000000000
 INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000
@@ -4,8 +4,7 @@ from typing import Any, TypedDict
 
 from opentelemetry.trace import Link, Status, StatusCode
 
-from core.rag.models.document import Document
-from dify_trace_aliyun.entities.semconv import (
+from core.ops.aliyun_trace.entities.semconv import (
     GEN_AI_FRAMEWORK,
     GEN_AI_SESSION_ID,
     GEN_AI_SPAN_KIND,
@@ -14,6 +13,7 @@ from dify_trace_aliyun.entities.semconv import (
     OUTPUT_VALUE,
     GenAISpanKind,
 )
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from graphon.entities import WorkflowNodeExecution
 from graphon.enums import WorkflowNodeExecutionStatus
@@ -48,7 +48,7 @@ def get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status:
 
 
 def create_links_from_trace_id(trace_id: str | None) -> list[Link]:
-    from dify_trace_aliyun.data_exporter.traceclient import create_link
+    from core.ops.aliyun_trace.data_exporter.traceclient import create_link
 
     links = []
     if trace_id:
@@ -25,6 +25,7 @@ from opentelemetry.util.types import AttributeValue
 from sqlalchemy.orm import sessionmaker
 
 from core.ops.base_trace_instance import BaseTraceInstance
+from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
 from core.ops.entities.trace_entity import (
     BaseTraceInfo,
     DatasetRetrievalTraceInfo,
@@ -38,7 +39,6 @@ from core.ops.entities.trace_entity import (
 )
 from core.ops.utils import JSON_DICT_ADAPTER
 from core.repositories import DifyCoreRepositoryFactory
-from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
 from extensions.ext_database import db
 from graphon.enums import WorkflowNodeExecutionStatus
 from models.model import EndUser, MessageFile
@@ -1,8 +1,8 @@
 from enum import StrEnum
 
-from pydantic import BaseModel
+from pydantic import BaseModel, ValidationInfo, field_validator
 
-from core.ops.utils import validate_project_name, validate_url
+from core.ops.utils import validate_integer_id, validate_project_name, validate_url, validate_url_with_path
 
 
 class TracingProviderEnum(StrEnum):
@@ -52,5 +52,220 @@ class BaseTracingConfig(BaseModel):
         return validate_project_name(v, default_name)
 
 
+class ArizeConfig(BaseTracingConfig):
+    """
+    Model class for Arize tracing config.
+    """
+
+    api_key: str | None = None
+    space_id: str | None = None
+    project: str | None = None
+    endpoint: str = "https://otlp.arize.com"
+
+    @field_validator("project")
+    @classmethod
+    def project_validator(cls, v, info: ValidationInfo):
+        return cls.validate_project_field(v, "default")
+
+    @field_validator("endpoint")
+    @classmethod
+    def endpoint_validator(cls, v, info: ValidationInfo):
+        return cls.validate_endpoint_url(v, "https://otlp.arize.com")
+
+
+class PhoenixConfig(BaseTracingConfig):
+    """
+    Model class for Phoenix tracing config.
+    """
+
+    api_key: str | None = None
+    project: str | None = None
+    endpoint: str = "https://app.phoenix.arize.com"
+
+    @field_validator("project")
+    @classmethod
+    def project_validator(cls, v, info: ValidationInfo):
+        return cls.validate_project_field(v, "default")
+
+    @field_validator("endpoint")
+    @classmethod
+    def endpoint_validator(cls, v, info: ValidationInfo):
+        return validate_url_with_path(v, "https://app.phoenix.arize.com")
+
+
+class LangfuseConfig(BaseTracingConfig):
+    """
+    Model class for Langfuse tracing config.
+    """
+
+    public_key: str
+    secret_key: str
+    host: str = "https://api.langfuse.com"
+
+    @field_validator("host")
+    @classmethod
+    def host_validator(cls, v, info: ValidationInfo):
+        return validate_url_with_path(v, "https://api.langfuse.com")
+
+
+class LangSmithConfig(BaseTracingConfig):
+    """
+    Model class for Langsmith tracing config.
+    """
+
+    api_key: str
+    project: str
+    endpoint: str = "https://api.smith.langchain.com"
+
+    @field_validator("endpoint")
+    @classmethod
+    def endpoint_validator(cls, v, info: ValidationInfo):
+        # LangSmith only allows HTTPS
+        return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",))
+
+
+class OpikConfig(BaseTracingConfig):
+    """
+    Model class for Opik tracing config.
+    """
+
+    api_key: str | None = None
+    project: str | None = None
+    workspace: str | None = None
+    url: str = "https://www.comet.com/opik/api/"
+
+    @field_validator("project")
+    @classmethod
+    def project_validator(cls, v, info: ValidationInfo):
+        return cls.validate_project_field(v, "Default Project")
+
+    @field_validator("url")
+    @classmethod
+    def url_validator(cls, v, info: ValidationInfo):
+        return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/")
+
+
+class WeaveConfig(BaseTracingConfig):
+    """
+    Model class for Weave tracing config.
+    """
+
+    api_key: str
+    entity: str | None = None
+    project: str
+    endpoint: str = "https://trace.wandb.ai"
+    host: str | None = None
+
+    @field_validator("endpoint")
+    @classmethod
+    def endpoint_validator(cls, v, info: ValidationInfo):
+        # Weave only allows HTTPS for endpoint
+        return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))
+
+    @field_validator("host")
+    @classmethod
+    def host_validator(cls, v, info: ValidationInfo):
+        if v is not None and v.strip() != "":
+            return validate_url(v, v, allowed_schemes=("https", "http"))
+        return v
+
+
+class AliyunConfig(BaseTracingConfig):
+    """
+    Model class for Aliyun tracing config.
+    """
+
+    app_name: str = "dify_app"
+    license_key: str
+    endpoint: str
+
+    @field_validator("app_name")
+    @classmethod
+    def app_name_validator(cls, v, info: ValidationInfo):
+        return cls.validate_project_field(v, "dify_app")
+
+    @field_validator("license_key")
+    @classmethod
+    def license_key_validator(cls, v, info: ValidationInfo):
+        if not v or v.strip() == "":
+            raise ValueError("License key cannot be empty")
+        return v
+
+    @field_validator("endpoint")
+    @classmethod
+    def endpoint_validator(cls, v, info: ValidationInfo):
+        # aliyun uses two URL formats, which may include a URL path
+        return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
+
+
+class TencentConfig(BaseTracingConfig):
+    """
+    Tencent APM tracing config
+    """
+
+    token: str
+    endpoint: str
+    service_name: str
+
+    @field_validator("token")
+    @classmethod
+    def token_validator(cls, v, info: ValidationInfo):
+        if not v or v.strip() == "":
+            raise ValueError("Token cannot be empty")
+        return v
+
+    @field_validator("endpoint")
+    @classmethod
+    def endpoint_validator(cls, v, info: ValidationInfo):
+        return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com")
+
+    @field_validator("service_name")
+    @classmethod
+    def service_name_validator(cls, v, info: ValidationInfo):
+        return cls.validate_project_field(v, "dify_app")
+
+
+class MLflowConfig(BaseTracingConfig):
+    """
+    Model class for MLflow tracing config.
+    """
+
+    tracking_uri: str = "http://localhost:5000"
+    experiment_id: str = "0"  # Default experiment id in MLflow is 0
+    username: str | None = None
+    password: str | None = None
+
+    @field_validator("tracking_uri")
+    @classmethod
+    def tracking_uri_validator(cls, v, info: ValidationInfo):
+        if isinstance(v, str) and v.startswith("databricks"):
+            raise ValueError(
+                "Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances."
+            )
+        return validate_url_with_path(v, "http://localhost:5000")
+
+    @field_validator("experiment_id")
+    @classmethod
+    def experiment_id_validator(cls, v, info: ValidationInfo):
+        return validate_integer_id(v)
+
+
+class DatabricksConfig(BaseTracingConfig):
+    """
+    Model class for Databricks (Databricks-managed MLflow) tracing config.
+    """
+
+    experiment_id: str
+    host: str
+    client_id: str | None = None
+    client_secret: str | None = None
+    personal_access_token: str | None = None
+
+    @field_validator("experiment_id")
+    @classmethod
+    def experiment_id_validator(cls, v, info: ValidationInfo):
+        return validate_integer_id(v)
+
+
 OPS_FILE_PATH = "ops_trace/"
 OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
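The validators above lean on helpers from `core.ops.utils` whose bodies are not part of this diff. A hedged sketch of what a scheme-restricted `validate_url` plausibly does, under those assumptions:

```python
from urllib.parse import urlparse


def validate_url_sketch(
    v: str | None, default_url: str, allowed_schemes: tuple[str, ...] = ("https", "http")
) -> str:
    # Fall back to the default when unset, then enforce the allowed schemes.
    url = (v or "").strip() or default_url
    if urlparse(url).scheme not in allowed_schemes:
        raise ValueError(f"URL scheme must be one of {allowed_schemes}")
    return url
```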
@@ -16,6 +16,7 @@ from langfuse.api.commons.types.usage import Usage
 from sqlalchemy.orm import sessionmaker
 
 from core.ops.base_trace_instance import BaseTraceInstance
+from core.ops.entities.config_entity import LangfuseConfig
 from core.ops.entities.trace_entity import (
     BaseTraceInfo,
     DatasetRetrievalTraceInfo,
@@ -27,10 +28,7 @@ from core.ops.entities.trace_entity import (
     TraceTaskName,
     WorkflowTraceInfo,
 )
-from core.ops.utils import filter_none_values
-from core.repositories import DifyCoreRepositoryFactory
-from dify_trace_langfuse.config import LangfuseConfig
-from dify_trace_langfuse.entities.langfuse_trace_entity import (
+from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
     GenerationUsage,
     LangfuseGeneration,
     LangfuseSpan,
@@ -38,6 +36,8 @@ from dify_trace_langfuse.entities.langfuse_trace_entity import (
     LevelEnum,
     UnitEnum,
 )
+from core.ops.utils import filter_none_values
+from core.repositories import DifyCoreRepositoryFactory
 from extensions.ext_database import db
 from graphon.enums import BuiltinNodeTypes
 from models import EndUser, WorkflowNodeExecutionTriggeredFrom
@@ -9,6 +9,7 @@ from langsmith.schemas import RunBase
 from sqlalchemy.orm import sessionmaker
 
 from core.ops.base_trace_instance import BaseTraceInstance
+from core.ops.entities.config_entity import LangSmithConfig
 from core.ops.entities.trace_entity import (
     BaseTraceInfo,
     DatasetRetrievalTraceInfo,
@@ -20,14 +21,13 @@ from core.ops.entities.trace_entity import (
     TraceTaskName,
     WorkflowTraceInfo,
 )
-from core.ops.utils import filter_none_values, generate_dotted_order
-from core.repositories import DifyCoreRepositoryFactory
-from dify_trace_langsmith.config import LangSmithConfig
-from dify_trace_langsmith.entities.langsmith_trace_entity import (
+from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
     LangSmithRunModel,
     LangSmithRunType,
     LangSmithRunUpdateModel,
 )
+from core.ops.utils import filter_none_values, generate_dotted_order
+from core.repositories import DifyCoreRepositoryFactory
 from extensions.ext_database import db
 from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
 from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
@@ -11,6 +11,7 @@ from mlflow.tracing.provider import detach_span_from_context, set_span_in_contex
 from sqlalchemy import select
 
 from core.ops.base_trace_instance import BaseTraceInstance
+from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig
 from core.ops.entities.trace_entity import (
     BaseTraceInfo,
     DatasetRetrievalTraceInfo,
@@ -23,7 +24,6 @@ from core.ops.entities.trace_entity import (
     WorkflowTraceInfo,
 )
 from core.ops.utils import JSON_DICT_ADAPTER
-from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig
 from extensions.ext_database import db
 from graphon.enums import BuiltinNodeTypes
 from models import EndUser
@@ -10,6 +10,7 @@ from opik.id_helpers import uuid4_to_uuid7
 from sqlalchemy.orm import sessionmaker
 
 from core.ops.base_trace_instance import BaseTraceInstance
+from core.ops.entities.config_entity import OpikConfig
 from core.ops.entities.trace_entity import (
     BaseTraceInfo,
     DatasetRetrievalTraceInfo,
@@ -22,7 +23,6 @@ from core.ops.entities.trace_entity import (
     WorkflowTraceInfo,
 )
 from core.repositories import DifyCoreRepositoryFactory
-from dify_trace_opik.config import OpikConfig
 from extensions.ext_database import db
 from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
 from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
@@ -204,117 +204,114 @@ class TracingProviderConfigEntry(TypedDict):
 
 class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]):
     def __getitem__(self, provider: str) -> TracingProviderConfigEntry:
-        try:
-            match provider:
-                case TracingProviderEnum.LANGFUSE:
-                    from dify_trace_langfuse.config import LangfuseConfig
-                    from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace
-
-                    return {
-                        "config_class": LangfuseConfig,
-                        "secret_keys": ["public_key", "secret_key"],
-                        "other_keys": ["host", "project_key"],
-                        "trace_instance": LangFuseDataTrace,
-                    }
+        match provider:
+            case TracingProviderEnum.LANGFUSE:
+                from core.ops.entities.config_entity import LangfuseConfig
+                from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
+
+                return {
+                    "config_class": LangfuseConfig,
+                    "secret_keys": ["public_key", "secret_key"],
+                    "other_keys": ["host", "project_key"],
+                    "trace_instance": LangFuseDataTrace,
+                }
 
-                case TracingProviderEnum.LANGSMITH:
-                    from dify_trace_langsmith.config import LangSmithConfig
-                    from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace
-
-                    return {
-                        "config_class": LangSmithConfig,
-                        "secret_keys": ["api_key"],
-                        "other_keys": ["project", "endpoint"],
-                        "trace_instance": LangSmithDataTrace,
-                    }
+            case TracingProviderEnum.LANGSMITH:
+                from core.ops.entities.config_entity import LangSmithConfig
+                from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
+
+                return {
+                    "config_class": LangSmithConfig,
+                    "secret_keys": ["api_key"],
+                    "other_keys": ["project", "endpoint"],
+                    "trace_instance": LangSmithDataTrace,
+                }
 
-                case TracingProviderEnum.OPIK:
-                    from dify_trace_opik.config import OpikConfig
-                    from dify_trace_opik.opik_trace import OpikDataTrace
-
-                    return {
-                        "config_class": OpikConfig,
-                        "secret_keys": ["api_key"],
-                        "other_keys": ["project", "url", "workspace"],
-                        "trace_instance": OpikDataTrace,
-                    }
+            case TracingProviderEnum.OPIK:
+                from core.ops.entities.config_entity import OpikConfig
+                from core.ops.opik_trace.opik_trace import OpikDataTrace
+
+                return {
+                    "config_class": OpikConfig,
+                    "secret_keys": ["api_key"],
+                    "other_keys": ["project", "url", "workspace"],
+                    "trace_instance": OpikDataTrace,
+                }
 
-                case TracingProviderEnum.WEAVE:
-                    from dify_trace_weave.config import WeaveConfig
-                    from dify_trace_weave.weave_trace import WeaveDataTrace
-
-                    return {
-                        "config_class": WeaveConfig,
-                        "secret_keys": ["api_key"],
-                        "other_keys": ["project", "entity", "endpoint", "host"],
-                        "trace_instance": WeaveDataTrace,
-                    }
+            case TracingProviderEnum.WEAVE:
+                from core.ops.entities.config_entity import WeaveConfig
+                from core.ops.weave_trace.weave_trace import WeaveDataTrace
+
+                return {
+                    "config_class": WeaveConfig,
+                    "secret_keys": ["api_key"],
+                    "other_keys": ["project", "entity", "endpoint", "host"],
+                    "trace_instance": WeaveDataTrace,
+                }
 
-                case TracingProviderEnum.ARIZE:
-                    from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace
-                    from dify_trace_arize_phoenix.config import ArizeConfig
-
-                    return {
-                        "config_class": ArizeConfig,
-                        "secret_keys": ["api_key", "space_id"],
-                        "other_keys": ["project", "endpoint"],
-                        "trace_instance": ArizePhoenixDataTrace,
-                    }
+            case TracingProviderEnum.ARIZE:
+                from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
+                from core.ops.entities.config_entity import ArizeConfig
+
+                return {
+                    "config_class": ArizeConfig,
+                    "secret_keys": ["api_key", "space_id"],
+                    "other_keys": ["project", "endpoint"],
+                    "trace_instance": ArizePhoenixDataTrace,
+                }
 
-                case TracingProviderEnum.PHOENIX:
-                    from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace
-                    from dify_trace_arize_phoenix.config import PhoenixConfig
-
-                    return {
-                        "config_class": PhoenixConfig,
-                        "secret_keys": ["api_key"],
-                        "other_keys": ["project", "endpoint"],
-                        "trace_instance": ArizePhoenixDataTrace,
-                    }
+            case TracingProviderEnum.PHOENIX:
+                from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
+                from core.ops.entities.config_entity import PhoenixConfig
+
+                return {
+                    "config_class": PhoenixConfig,
+                    "secret_keys": ["api_key"],
+                    "other_keys": ["project", "endpoint"],
+                    "trace_instance": ArizePhoenixDataTrace,
+                }
 
-                case TracingProviderEnum.ALIYUN:
-                    from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
-                    from dify_trace_aliyun.config import AliyunConfig
-
-                    return {
-                        "config_class": AliyunConfig,
-                        "secret_keys": ["license_key"],
-                        "other_keys": ["endpoint", "app_name"],
-                        "trace_instance": AliyunDataTrace,
-                    }
+            case TracingProviderEnum.ALIYUN:
+                from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
+                from core.ops.entities.config_entity import AliyunConfig
+
+                return {
+                    "config_class": AliyunConfig,
+                    "secret_keys": ["license_key"],
+                    "other_keys": ["endpoint", "app_name"],
+                    "trace_instance": AliyunDataTrace,
+                }
 
-                case TracingProviderEnum.MLFLOW:
-                    from dify_trace_mlflow.config import MLflowConfig
-                    from dify_trace_mlflow.mlflow_trace import MLflowDataTrace
-
-                    return {
-                        "config_class": MLflowConfig,
-                        "secret_keys": ["password"],
-                        "other_keys": ["tracking_uri", "experiment_id", "username"],
-                        "trace_instance": MLflowDataTrace,
-                    }
+            case TracingProviderEnum.MLFLOW:
+                from core.ops.entities.config_entity import MLflowConfig
+                from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace
+
+                return {
+                    "config_class": MLflowConfig,
+                    "secret_keys": ["password"],
+                    "other_keys": ["tracking_uri", "experiment_id", "username"],
+                    "trace_instance": MLflowDataTrace,
+                }
 
-                case TracingProviderEnum.DATABRICKS:
-                    from dify_trace_mlflow.config import DatabricksConfig
-                    from dify_trace_mlflow.mlflow_trace import MLflowDataTrace
-
-                    return {
-                        "config_class": DatabricksConfig,
-                        "secret_keys": ["personal_access_token", "client_secret"],
-                        "other_keys": ["host", "client_id", "experiment_id"],
-                        "trace_instance": MLflowDataTrace,
-                    }
+            case TracingProviderEnum.DATABRICKS:
+                from core.ops.entities.config_entity import DatabricksConfig
+                from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace
+
+                return {
+                    "config_class": DatabricksConfig,
+                    "secret_keys": ["personal_access_token", "client_secret"],
+                    "other_keys": ["host", "client_id", "experiment_id"],
+                    "trace_instance": MLflowDataTrace,
+                }
 
-                case TracingProviderEnum.TENCENT:
-                    from dify_trace_tencent.config import TencentConfig
-                    from dify_trace_tencent.tencent_trace import TencentDataTrace
-
-                    return {
-                        "config_class": TencentConfig,
-                        "secret_keys": ["token"],
-                        "other_keys": ["endpoint", "service_name"],
-                        "trace_instance": TencentDataTrace,
-                    }
+            case TracingProviderEnum.TENCENT:
+                from core.ops.entities.config_entity import TencentConfig
+                from core.ops.tencent_trace.tencent_trace import TencentDataTrace
+
+                return {
+                    "config_class": TencentConfig,
+                    "secret_keys": ["token"],
+                    "other_keys": ["endpoint", "service_name"],
+                    "trace_instance": TencentDataTrace,
+                }
 
-                case _:
-                    raise KeyError(f"Unsupported tracing provider: {provider}")
-        except ImportError:
-            raise ImportError(f"Provider {provider} is not installed.")
+            case _:
+                raise KeyError(f"Unsupported tracing provider: {provider}")
 
 
 provider_config_map = OpsTraceProviderConfigMap()
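A usage sketch for the registry above; the credential values are placeholders, and in practice `OpsTraceManager` performs this resolution and instantiation itself:

```python
entry = provider_config_map[TracingProviderEnum.LANGFUSE]
config = entry["config_class"](public_key="pk-...", secret_key="sk-...")  # LangfuseConfig
trace_cls = entry["trace_instance"]  # LangFuseDataTrace
```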
@@ -14,8 +14,7 @@ from core.ops.entities.trace_entity import (
     ToolTraceInfo,
     WorkflowTraceInfo,
 )
-from core.rag.models.document import Document
-from dify_trace_tencent.entities.semconv import (
+from core.ops.tencent_trace.entities.semconv import (
     GEN_AI_COMPLETION,
     GEN_AI_FRAMEWORK,
     GEN_AI_IS_ENTRY,
@@ -39,8 +38,9 @@ from dify_trace_tencent.entities.semconv import (
     TOOL_PARAMETERS,
     GenAISpanKind,
 )
-from dify_trace_tencent.entities.tencent_trace_entity import SpanData
-from dify_trace_tencent.utils import TencentTraceUtils
+from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
+from core.ops.tencent_trace.utils import TencentTraceUtils
+from core.rag.models.document import Document
 from graphon.entities import WorkflowNodeExecution
 from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
 
@@ -8,6 +8,7 @@ from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker
 
 from core.ops.base_trace_instance import BaseTraceInstance
+from core.ops.entities.config_entity import TencentConfig
 from core.ops.entities.trace_entity import (
     BaseTraceInfo,
     DatasetRetrievalTraceInfo,
@@ -18,12 +19,11 @@ from core.ops.entities.trace_entity import (
     ToolTraceInfo,
     WorkflowTraceInfo,
 )
+from core.ops.tencent_trace.client import TencentTraceClient
+from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
+from core.ops.tencent_trace.span_builder import TencentSpanBuilder
+from core.ops.tencent_trace.utils import TencentTraceUtils
 from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
-from dify_trace_tencent.client import TencentTraceClient
-from dify_trace_tencent.config import TencentConfig
-from dify_trace_tencent.entities.tencent_trace_entity import SpanData
-from dify_trace_tencent.span_builder import TencentSpanBuilder
-from dify_trace_tencent.utils import TencentTraceUtils
 from extensions.ext_database import db
 from graphon.entities.workflow_node_execution import (
     WorkflowNodeExecution,
@@ -17,6 +17,7 @@ from weave.trace_server.trace_server_interface import (
 )
 
 from core.ops.base_trace_instance import BaseTraceInstance
+from core.ops.entities.config_entity import WeaveConfig
 from core.ops.entities.trace_entity import (
     BaseTraceInfo,
     DatasetRetrievalTraceInfo,
@@ -28,9 +29,8 @@ from core.ops.entities.trace_entity import (
     TraceTaskName,
     WorkflowTraceInfo,
 )
+from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
 from core.repositories import DifyCoreRepositoryFactory
-from dify_trace_weave.config import WeaveConfig
-from dify_trace_weave.entities.weave_trace_entity import WeaveTraceModel
 from extensions.ext_database import db
 from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
 from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
@@ -1,56 +1,17 @@
 import logging
-from dataclasses import dataclass
 from enum import StrEnum, auto
 
 logger = logging.getLogger(__name__)
 
 
-@dataclass
-class QuotaCharge:
-    """
-    Result of a quota consumption operation.
-
-    Attributes:
-        success: Whether the quota charge succeeded
-        charge_id: UUID for refund, or None if failed/disabled
-    """
-
-    success: bool
-    charge_id: str | None
-    _quota_type: "QuotaType"
-
-    def refund(self) -> None:
-        """
-        Refund this quota charge.
-
-        Safe to call even if charge failed or was disabled.
-        This method guarantees no exceptions will be raised.
-        """
-        if self.charge_id:
-            self._quota_type.refund(self.charge_id)
-            logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id)
-
-
 class QuotaType(StrEnum):
     """
     Supported quota types for tenant feature usage.
 
     Add additional types here whenever new billable features become available.
     """
 
     # Trigger execution quota
     TRIGGER = auto()
 
     # Workflow execution quota
     WORKFLOW = auto()
-
-    UNLIMITED = auto()
 
     @property
     def billing_key(self) -> str:
         """
         Get the billing key for the feature.
         """
         match self:
             case QuotaType.TRIGGER:
                 return "trigger_event"
@@ -58,152 +19,3 @@ class QuotaType(StrEnum):
                 return "api_rate_limit"
             case _:
                 raise ValueError(f"Invalid quota type: {self}")
-
-    def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge:
-        """
-        Consume quota for the feature.
-
-        Args:
-            tenant_id: The tenant identifier
-            amount: Amount to consume (default: 1)
-
-        Returns:
-            QuotaCharge with success status and charge_id for refund
-
-        Raises:
-            QuotaExceededError: When quota is insufficient
-        """
-        from configs import dify_config
-        from services.billing_service import BillingService
-        from services.errors.app import QuotaExceededError
-
-        if not dify_config.BILLING_ENABLED:
-            logger.debug("Billing disabled, allowing request for %s", tenant_id)
-            return QuotaCharge(success=True, charge_id=None, _quota_type=self)
-
-        logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id)
-
-        if amount <= 0:
-            raise ValueError("Amount to consume must be greater than 0")
-
-        try:
-            response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount)
-
-            if response.get("result") != "success":
-                logger.warning(
-                    "Failed to consume quota for %s, feature %s details: %s",
-                    tenant_id,
-                    self.value,
-                    response.get("detail"),
-                )
-                raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount)
-
-            charge_id = response.get("history_id")
-            logger.debug(
-                "Successfully consumed %d %s quota for tenant %s, charge_id: %s",
-                amount,
-                self.value,
-                tenant_id,
-                charge_id,
-            )
-            return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self)
-
-        except QuotaExceededError:
-            raise
-        except Exception:
-            # fail-safe: allow request on billing errors
-            logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value)
-            return unlimited()
-
-    def check(self, tenant_id: str, amount: int = 1) -> bool:
-        """
-        Check if tenant has sufficient quota without consuming.
-
-        Args:
-            tenant_id: The tenant identifier
-            amount: Amount to check (default: 1)
-
-        Returns:
-            True if quota is sufficient, False otherwise
-        """
-        from configs import dify_config
-
-        if not dify_config.BILLING_ENABLED:
-            return True
-
-        if amount <= 0:
-            raise ValueError("Amount to check must be greater than 0")
-
-        try:
-            remaining = self.get_remaining(tenant_id)
-            return remaining >= amount if remaining != -1 else True
-        except Exception:
-            logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value)
-            # fail-safe: allow request on billing errors
-            return True
-
-    def refund(self, charge_id: str) -> None:
-        """
-        Refund quota using charge_id from consume().
-
-        This method guarantees no exceptions will be raised.
-        All errors are logged but silently handled.
-
-        Args:
-            charge_id: The UUID returned from consume()
-        """
-        try:
-            from configs import dify_config
-            from services.billing_service import BillingService
-
-            if not dify_config.BILLING_ENABLED:
-                return
-
-            if not charge_id:
-                logger.warning("Cannot refund: charge_id is empty")
-                return
-
-            logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id)
-
-            response = BillingService.refund_tenant_feature_plan_usage(charge_id)
-            if response.get("result") == "success":
-                logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id)
-            else:
-                logger.warning("Refund failed for charge_id: %s", charge_id)
-
-        except Exception:
-            # Catch ALL exceptions - refund must never fail
-            logger.exception("Failed to refund quota for charge_id: %s", charge_id)
-            # Don't raise - refund is best-effort and must be silent
-
-    def get_remaining(self, tenant_id: str) -> int:
-        """
-        Get remaining quota for the tenant.
-
-        Args:
-            tenant_id: The tenant identifier
-
-        Returns:
-            Remaining quota amount
-        """
-        from services.billing_service import BillingService
-
-        try:
-            usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key)
-            # Assuming the API returns a dict with 'remaining' or 'limit' and 'used'
-            if isinstance(usage_info, dict):
-                return usage_info.get("remaining", 0)
-            # If it returns a simple number, treat it as remaining
-            return int(usage_info) if usage_info else 0
-        except Exception:
-            logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value)
-            return -1
-
-
-def unlimited() -> QuotaCharge:
-    """
-    Return a quota charge for unlimited quota.
-
-    This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type.
-    """
-    return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)
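A usage sketch of the consume/refund flow defined by the removed code above; `run_trigger` is a hypothetical downstream step covered by the charge:

```python
charge = QuotaType.TRIGGER.consume(tenant_id="tenant-123")  # may raise QuotaExceededError
try:
    run_trigger()  # hypothetical work billed against the charge
except Exception:
    charge.refund()  # best-effort: guaranteed not to raise
    raise
```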
@@ -1715,7 +1715,7 @@ class SegmentAttachmentBinding(TypeBase):
     )
 
 
-class DocumentSegmentSummary(TypeBase):
+class DocumentSegmentSummary(Base):
     __tablename__ = "document_segment_summaries"
     __table_args__ = (
         sa.PrimaryKeyConstraint("id", name="document_segment_summaries_pkey"),
@@ -1725,40 +1725,25 @@ class DocumentSegmentSummary(TypeBase):
         sa.Index("document_segment_summaries_status_idx", "status"),
     )
 
-    id: Mapped[str] = mapped_column(
-        StringUUID,
-        nullable=False,
-        insert_default=lambda: str(uuid4()),
-        default_factory=lambda: str(uuid4()),
-        init=False,
-    )
+    id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
     dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     # corresponds to DocumentSegment.id or parent chunk id
     chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
-    summary_content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
-    summary_index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
-    summary_index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
-    tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
-    status: Mapped[SummaryStatus] = mapped_column(
-        EnumText(SummaryStatus, length=32),
-        nullable=False,
-        server_default=sa.text("'generating'"),
-        default=SummaryStatus.GENERATING,
-    )
-    error: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
-    enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True)
-    disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
-    disabled_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
-    created_at: Mapped[datetime] = mapped_column(
-        DateTime, nullable=False, server_default=func.current_timestamp(), init=False
-    )
+    summary_content: Mapped[str] = mapped_column(LongText, nullable=True)
+    summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True)
+    summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True)
+    tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
+    status: Mapped[str] = mapped_column(
+        EnumText(SummaryStatus, length=32), nullable=False, server_default=sa.text("'generating'")
+    )
+    error: Mapped[str] = mapped_column(LongText, nullable=True)
+    enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
+    disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
+    disabled_by = mapped_column(StringUUID, nullable=True)
+    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
     updated_at: Mapped[datetime] = mapped_column(
-        DateTime,
-        nullable=False,
-        server_default=func.current_timestamp(),
-        onupdate=func.current_timestamp(),
-        init=False,
+        DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
     )
 
     def __repr__(self):
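A hedged sketch of the distinction driving the base-class change above: a dataclass-mapped base (what `TypeBase` appears to be) takes dataclass-aware arguments such as `default_factory` and `init=False`, while a plain declarative `Base` takes a scalar or callable `default`. The class and table names here are illustrative:

```python
from uuid import uuid4

from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column


class PlainBase(DeclarativeBase): ...


class PlainRow(PlainBase):
    __tablename__ = "plain_row"
    # Plain declarative mapping: a callable `default` is enough.
    id: Mapped[str] = mapped_column(primary_key=True, default=lambda: str(uuid4()))


class DataclassBase(MappedAsDataclass, DeclarativeBase): ...


class DataclassRow(DataclassBase):
    __tablename__ = "dataclass_row"
    # Dataclass mapping: the generated __init__ needs default_factory/init.
    id: Mapped[str] = mapped_column(
        primary_key=True, default_factory=lambda: str(uuid4()), init=False
    )
```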
@@ -10,6 +10,3 @@ This directory holds **optional workspace packages** that plug into Dify’s API
 
 Provider tests often live next to the package, e.g. `providers/<type>/<backend>/tests/unit_tests/`. Shared fixtures may live under `providers/` (e.g. `conftest.py`).
 
-## Excluding Providers
-
-In order to build with selected providers, use `--no-group vdb-all` and `--no-group trace-all` to disable default ones, then use `--group vdb-<provider>` and `--group trace-<provider>` to enable specific providers.
@ -1,78 +0,0 @@
# Trace providers

This directory holds **optional workspace packages** that send Dify **ops tracing** data (workflows, messages, tools, moderation, etc.) to an external observability backend (Langfuse, LangSmith, OpenTelemetry-style exporters, and others).

Unlike VDB providers, trace plugins are **not** discovered via entry points. The API core imports your package **explicitly** from `core/ops/ops_trace_manager.py` after you register the provider id and mapping.

## Architecture

| Layer | Location | Role |
|--------|----------|------|
| Contracts | `api/core/ops/base_trace_instance.py`, `api/core/ops/entities/trace_entity.py`, `api/core/ops/entities/config_entity.py` | `BaseTraceInstance`, `BaseTracingConfig`, and typed `*TraceInfo` payloads |
| Registry | `api/core/ops/ops_trace_manager.py` | `TracingProviderEnum`, `OpsTraceProviderConfigMap` — maps provider **string** → config class, encrypted keys, and trace class |
| Your package | `api/providers/trace/trace-<name>/` | Pydantic config + subclass of `BaseTraceInstance` |

At runtime, `OpsTraceManager` decrypts stored credentials, builds your config model, caches a trace instance, and calls `trace(trace_info)` with a concrete `BaseTraceInfo` subtype.

## What you implement

### 1. Config model (`BaseTracingConfig`)

Subclass `BaseTracingConfig` from `core.ops.entities.config_entity`. Use Pydantic validators; reuse helpers from `core.ops.utils` (for example `validate_url`, `validate_url_with_path`, `validate_project_name`) where appropriate.

Fields fall into two groups used by the manager:

- **`secret_keys`** — names of fields that are **encrypted at rest** (API keys, tokens, passwords).
- **`other_keys`** — non-secret connection settings (hosts, project names, endpoints).

List these key names in your `OpsTraceProviderConfigMap` entry so encrypt/decrypt and merge logic stay correct.
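For orientation, a minimal sketch of such a config model, patterned on the providers in this tree; the backend name, its fields, and the default endpoint are invented for illustration:

```python
from pydantic import ValidationInfo, field_validator

from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.utils import validate_url_with_path


class MyBackendConfig(BaseTracingConfig):
    """Hypothetical tracing config: `api_key` would be listed in `secret_keys`;
    `project` and `endpoint` in `other_keys`."""

    api_key: str
    project: str | None = None
    endpoint: str = "https://traces.example.com"

    @field_validator("project")
    @classmethod
    def project_validator(cls, v, info: ValidationInfo):
        # Empty or missing project names fall back to a default
        return cls.validate_project_field(v, "default")

    @field_validator("endpoint")
    @classmethod
    def endpoint_validator(cls, v, info: ValidationInfo):
        # Empty input falls back to the default; otherwise the URL scheme is validated
        return validate_url_with_path(v, "https://traces.example.com")
```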
### 2. Trace instance (`BaseTraceInstance`)

Subclass `BaseTraceInstance` and implement:

```python
def trace(self, trace_info: BaseTraceInfo) -> None:
    ...
```

Dispatch on the concrete type with `isinstance` (see `trace_langfuse` or `trace_langsmith` for full patterns). Payload types are defined in `core/ops/entities/trace_entity.py`, including:

- `WorkflowTraceInfo`, `WorkflowNodeTraceInfo`, `DraftNodeExecutionTrace`
- `MessageTraceInfo`, `ToolTraceInfo`, `ModerationTraceInfo`, `SuggestedQuestionTraceInfo`
- `DatasetRetrievalTraceInfo`, `GenerateNameTraceInfo`, `PromptGenerationTraceInfo`

You may ignore categories your backend does not support; existing providers often no-op unhandled types.

Optional: use `get_service_account_with_tenant(app_id)` from the base class when you need tenant-scoped account context.
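A hedged sketch of that dispatch pattern for a hypothetical `MyBackendTrace`; constructor wiring and the vendor SDK calls are deliberately elided:

```python
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.trace_entity import (
    BaseTraceInfo,
    MessageTraceInfo,
    WorkflowTraceInfo,
)


class MyBackendTrace(BaseTraceInstance):
    def trace(self, trace_info: BaseTraceInfo) -> None:
        # Dispatch on the concrete payload type; unhandled categories no-op.
        if isinstance(trace_info, WorkflowTraceInfo):
            self._export_workflow(trace_info)
        elif isinstance(trace_info, MessageTraceInfo):
            self._export_message(trace_info)

    def _export_workflow(self, info: WorkflowTraceInfo) -> None:
        ...  # map workflow fields onto the vendor SDK's span/trace API

    def _export_message(self, info: MessageTraceInfo) -> None:
        ...  # map message payloads similarly
```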
### 3. Register in the API core

Upstream changes are required so Dify knows your provider exists:

1. **`TracingProviderEnum`** (`api/core/ops/entities/config_entity.py`) — add a new member whose **value** is the stable string stored in app tracing config (e.g. `"mybackend"`).
2. **`OpsTraceProviderConfigMap.__getitem__`** (`api/core/ops/ops_trace_manager.py`) — add a `match` case for that enum member returning:
   - `config_class`: your Pydantic config type
   - `secret_keys` / `other_keys`: lists of field names as above
   - `trace_instance`: your `BaseTraceInstance` subclass

   Lazy-import your package inside the case so a missing optional install raises a clear `ImportError`.

If the `match` case is missing, the provider string will not resolve and tracing will be disabled for that app.
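A hedged sketch of such a `match` case; the enum member, module path, and class names are invented, and the mapping-shaped return value is assumed from the key names listed above:

```python
# Inside OpsTraceProviderConfigMap.__getitem__ (provider variable name assumed)
match provider:
    case TracingProviderEnum.MYBACKEND:
        # Lazy import: a missing optional install fails here with a clear
        # ImportError, not at API startup.
        from dify_trace_mybackend.config import MyBackendConfig
        from dify_trace_mybackend.mybackend_trace import MyBackendTrace

        return {
            "config_class": MyBackendConfig,
            "secret_keys": ["api_key"],
            "other_keys": ["project", "endpoint"],
            "trace_instance": MyBackendTrace,
        }
    case _:
        raise KeyError(f"Unknown tracing provider: {provider}")
```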
## Package layout

Each provider is a normal uv workspace member, for example:

- `api/providers/trace/trace-<name>/pyproject.toml` — project name `dify-trace-<name>`, dependencies on vendor SDKs
- `api/providers/trace/trace-<name>/src/dify_trace_<name>/` — `config.py`, `<name>_trace.py`, optional `entities/`, and an empty **`py.typed`** file (PEP 561) so the API type checker treats the package as typed; list `py.typed` under `[tool.setuptools.package-data]` for that import name in `pyproject.toml`.

Reference implementations: `trace-langfuse/`, `trace-langsmith/`, `trace-opik/`.

## Wiring into the `api` workspace

In `api/pyproject.toml`:

1. **`[tool.uv.sources]`** — `dify-trace-<name> = { workspace = true }`
2. **`[dependency-groups]`** — add `trace-<name> = ["dify-trace-<name>"]` and include `dify-trace-<name>` in `trace-all` if it should ship with the default bundle

After changing metadata, run **`uv sync`** from `api/`.
@ -1,14 +0,0 @@
[project]
name = "dify-trace-aliyun"
version = "0.0.1"
dependencies = [
    # versions inherited from parent
    "opentelemetry-api",
    "opentelemetry-exporter-otlp-proto-grpc",
    "opentelemetry-sdk",
    "opentelemetry-semantic-conventions",
]
description = "Dify ops tracing provider (Aliyun)."

[tool.setuptools.packages.find]
where = ["src"]
@ -1,32 +0,0 @@
from pydantic import ValidationInfo, field_validator

from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.utils import validate_url_with_path


class AliyunConfig(BaseTracingConfig):
    """
    Model class for Aliyun tracing config.
    """

    app_name: str = "dify_app"
    license_key: str
    endpoint: str

    @field_validator("app_name")
    @classmethod
    def app_name_validator(cls, v, info: ValidationInfo):
        return cls.validate_project_field(v, "dify_app")

    @field_validator("license_key")
    @classmethod
    def license_key_validator(cls, v, info: ValidationInfo):
        if not v or v.strip() == "":
            raise ValueError("License key cannot be empty")
        return v

    @field_validator("endpoint")
    @classmethod
    def endpoint_validator(cls, v, info: ValidationInfo):
        # aliyun uses two URL formats, which may include a URL path
        return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
@ -1,85 +0,0 @@
import pytest
from dify_trace_aliyun.config import AliyunConfig
from pydantic import ValidationError


class TestAliyunConfig:
    """Test cases for AliyunConfig"""

    def test_valid_config(self):
        """Test valid Aliyun configuration"""
        config = AliyunConfig(
            app_name="test_app",
            license_key="test_license_key",
            endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com",
        )
        assert config.app_name == "test_app"
        assert config.license_key == "test_license_key"
        assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com"

    def test_default_values(self):
        """Test default values are set correctly"""
        config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
        assert config.app_name == "dify_app"

    def test_missing_required_fields(self):
        """Test that required fields are enforced"""
        with pytest.raises(ValidationError):
            AliyunConfig()

        with pytest.raises(ValidationError):
            AliyunConfig(license_key="test_license")

        with pytest.raises(ValidationError):
            AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")

    def test_app_name_validation_empty(self):
        """Test app_name validation with empty value"""
        config = AliyunConfig(
            license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
        )
        assert config.app_name == "dify_app"

    def test_endpoint_validation_empty(self):
        """Test endpoint validation with empty value"""
        config = AliyunConfig(license_key="test_license", endpoint="")
        assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"

    def test_endpoint_validation_with_path(self):
        """Test endpoint validation preserves path for Aliyun endpoints"""
        config = AliyunConfig(
            license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
        )
        assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"

    def test_endpoint_validation_invalid_scheme(self):
        """Test endpoint validation rejects invalid schemes"""
        with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
            AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com")

    def test_endpoint_validation_no_scheme(self):
        """Test endpoint validation rejects URLs without scheme"""
        with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
            AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com")

    def test_license_key_required(self):
        """Test that license_key is required and cannot be empty"""
        with pytest.raises(ValidationError):
            AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")

    def test_valid_endpoint_format_examples(self):
        """Test valid endpoint format examples from comments"""
        valid_endpoints = [
            # cms2.0 public endpoint
            "https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry",
            # cms2.0 intranet endpoint
            "https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry",
            # xtrace public endpoint
            "http://tracing-cn-heyuan.arms.aliyuncs.com",
            # xtrace intranet endpoint
            "http://tracing-cn-heyuan-internal.arms.aliyuncs.com",
        ]

        for endpoint in valid_endpoints:
            config = AliyunConfig(license_key="test_license", endpoint=endpoint)
            assert config.endpoint == endpoint
@ -1,10 +0,0 @@
[project]
name = "dify-trace-arize-phoenix"
version = "0.0.1"
dependencies = [
    "arize-phoenix-otel~=0.15.0",
]
description = "Dify ops tracing provider (Arize / Phoenix)."

[tool.setuptools.packages.find]
where = ["src"]
@ -1,45 +0,0 @@
from pydantic import ValidationInfo, field_validator

from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.utils import validate_url_with_path


class ArizeConfig(BaseTracingConfig):
    """
    Model class for Arize tracing config.
    """

    api_key: str | None = None
    space_id: str | None = None
    project: str | None = None
    endpoint: str = "https://otlp.arize.com"

    @field_validator("project")
    @classmethod
    def project_validator(cls, v, info: ValidationInfo):
        return cls.validate_project_field(v, "default")

    @field_validator("endpoint")
    @classmethod
    def endpoint_validator(cls, v, info: ValidationInfo):
        return cls.validate_endpoint_url(v, "https://otlp.arize.com")


class PhoenixConfig(BaseTracingConfig):
    """
    Model class for Phoenix tracing config.
    """

    api_key: str | None = None
    project: str | None = None
    endpoint: str = "https://app.phoenix.arize.com"

    @field_validator("project")
    @classmethod
    def project_validator(cls, v, info: ValidationInfo):
        return cls.validate_project_field(v, "default")

    @field_validator("endpoint")
    @classmethod
    def endpoint_validator(cls, v, info: ValidationInfo):
        return validate_url_with_path(v, "https://app.phoenix.arize.com")
@ -1,88 +0,0 @@
import pytest
from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
from pydantic import ValidationError


class TestArizeConfig:
    """Test cases for ArizeConfig"""

    def test_valid_config(self):
        """Test valid Arize configuration"""
        config = ArizeConfig(
            api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com"
        )
        assert config.api_key == "test_key"
        assert config.space_id == "test_space"
        assert config.project == "test_project"
        assert config.endpoint == "https://custom.arize.com"

    def test_default_values(self):
        """Test default values are set correctly"""
        config = ArizeConfig()
        assert config.api_key is None
        assert config.space_id is None
        assert config.project is None
        assert config.endpoint == "https://otlp.arize.com"

    def test_project_validation_empty(self):
        """Test project validation with empty value"""
        config = ArizeConfig(project="")
        assert config.project == "default"

    def test_project_validation_none(self):
        """Test project validation with None value"""
        config = ArizeConfig(project=None)
        assert config.project == "default"

    def test_endpoint_validation_empty(self):
        """Test endpoint validation with empty value"""
        config = ArizeConfig(endpoint="")
        assert config.endpoint == "https://otlp.arize.com"

    def test_endpoint_validation_with_path(self):
        """Test endpoint validation normalizes URL by removing path"""
        config = ArizeConfig(endpoint="https://custom.arize.com/api/v1")
        assert config.endpoint == "https://custom.arize.com"

    def test_endpoint_validation_invalid_scheme(self):
        """Test endpoint validation rejects invalid schemes"""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            ArizeConfig(endpoint="ftp://invalid.com")

    def test_endpoint_validation_no_scheme(self):
        """Test endpoint validation rejects URLs without scheme"""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            ArizeConfig(endpoint="invalid.com")


class TestPhoenixConfig:
    """Test cases for PhoenixConfig"""

    def test_valid_config(self):
        """Test valid Phoenix configuration"""
        config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com")
        assert config.api_key == "test_key"
        assert config.project == "test_project"
        assert config.endpoint == "https://custom.phoenix.com"

    def test_default_values(self):
        """Test default values are set correctly"""
        config = PhoenixConfig()
        assert config.api_key is None
        assert config.project is None
        assert config.endpoint == "https://app.phoenix.arize.com"

    def test_project_validation_empty(self):
        """Test project validation with empty value"""
        config = PhoenixConfig(project="")
        assert config.project == "default"

    def test_endpoint_validation_with_path(self):
        """Test endpoint validation with path"""
        config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration")
        assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration"

    def test_endpoint_validation_without_path(self):
        """Test endpoint validation without path"""
        config = PhoenixConfig(endpoint="https://app.phoenix.arize.com")
        assert config.endpoint == "https://app.phoenix.arize.com"
@ -1,10 +0,0 @@
[project]
name = "dify-trace-langfuse"
version = "0.0.1"
dependencies = [
    "langfuse>=4.2.0,<5.0.0",
]
description = "Dify ops tracing provider (Langfuse)."

[tool.setuptools.packages.find]
where = ["src"]
@ -1,19 +0,0 @@
from pydantic import ValidationInfo, field_validator

from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.utils import validate_url_with_path


class LangfuseConfig(BaseTracingConfig):
    """
    Model class for Langfuse tracing config.
    """

    public_key: str
    secret_key: str
    host: str = "https://api.langfuse.com"

    @field_validator("host")
    @classmethod
    def host_validator(cls, v, info: ValidationInfo):
        return validate_url_with_path(v, "https://api.langfuse.com")
@ -1,42 +0,0 @@
import pytest
from dify_trace_langfuse.config import LangfuseConfig
from pydantic import ValidationError


class TestLangfuseConfig:
    """Test cases for LangfuseConfig"""

    def test_valid_config(self):
        """Test valid Langfuse configuration"""
        config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com")
        assert config.public_key == "public_key"
        assert config.secret_key == "secret_key"
        assert config.host == "https://custom.langfuse.com"

    def test_valid_config_with_path(self):
        host = "https://custom.langfuse.com/api/v1"
        config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host)
        assert config.public_key == "public_key"
        assert config.secret_key == "secret_key"
        assert config.host == host

    def test_default_values(self):
        """Test default values are set correctly"""
        config = LangfuseConfig(public_key="public", secret_key="secret")
        assert config.host == "https://api.langfuse.com"

    def test_missing_required_fields(self):
        """Test that required fields are enforced"""
        with pytest.raises(ValidationError):
            LangfuseConfig()

        with pytest.raises(ValidationError):
            LangfuseConfig(public_key="public")

        with pytest.raises(ValidationError):
            LangfuseConfig(secret_key="secret")

    def test_host_validation_empty(self):
        """Test host validation with empty value"""
        config = LangfuseConfig(public_key="public", secret_key="secret", host="")
        assert config.host == "https://api.langfuse.com"
@ -1,10 +0,0 @@
[project]
name = "dify-trace-langsmith"
version = "0.0.1"
dependencies = [
    "langsmith~=0.7.30",
]
description = "Dify ops tracing provider (LangSmith)."

[tool.setuptools.packages.find]
where = ["src"]
@ -1,20 +0,0 @@
from pydantic import ValidationInfo, field_validator

from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.utils import validate_url


class LangSmithConfig(BaseTracingConfig):
    """
    Model class for LangSmith tracing config.
    """

    api_key: str
    project: str
    endpoint: str = "https://api.smith.langchain.com"

    @field_validator("endpoint")
    @classmethod
    def endpoint_validator(cls, v, info: ValidationInfo):
        # LangSmith only allows HTTPS
        return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",))
@ -1,35 +0,0 @@
import pytest
from dify_trace_langsmith.config import LangSmithConfig
from pydantic import ValidationError


class TestLangSmithConfig:
    """Test cases for LangSmithConfig"""

    def test_valid_config(self):
        """Test valid LangSmith configuration"""
        config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com")
        assert config.api_key == "test_key"
        assert config.project == "test_project"
        assert config.endpoint == "https://custom.smith.com"

    def test_default_values(self):
        """Test default values are set correctly"""
        config = LangSmithConfig(api_key="key", project="project")
        assert config.endpoint == "https://api.smith.langchain.com"

    def test_missing_required_fields(self):
        """Test that required fields are enforced"""
        with pytest.raises(ValidationError):
            LangSmithConfig()

        with pytest.raises(ValidationError):
            LangSmithConfig(api_key="key")

        with pytest.raises(ValidationError):
            LangSmithConfig(project="project")

    def test_endpoint_validation_https_only(self):
        """Test endpoint validation only allows HTTPS"""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com")
@ -1,10 +0,0 @@
[project]
name = "dify-trace-mlflow"
version = "0.0.1"
dependencies = [
    "mlflow-skinny>=3.11.1",
]
description = "Dify ops tracing provider (MLflow / Databricks)."

[tool.setuptools.packages.find]
where = ["src"]
@ -1,46 +0,0 @@
from pydantic import ValidationInfo, field_validator

from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.utils import validate_integer_id, validate_url_with_path


class MLflowConfig(BaseTracingConfig):
    """
    Model class for MLflow tracing config.
    """

    tracking_uri: str = "http://localhost:5000"
    experiment_id: str = "0"  # Default experiment id in MLflow is 0
    username: str | None = None
    password: str | None = None

    @field_validator("tracking_uri")
    @classmethod
    def tracking_uri_validator(cls, v, info: ValidationInfo):
        if isinstance(v, str) and v.startswith("databricks"):
            raise ValueError(
                "Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances."
            )
        return validate_url_with_path(v, "http://localhost:5000")

    @field_validator("experiment_id")
    @classmethod
    def experiment_id_validator(cls, v, info: ValidationInfo):
        return validate_integer_id(v)


class DatabricksConfig(BaseTracingConfig):
    """
    Model class for Databricks (Databricks-managed MLflow) tracing config.
    """

    experiment_id: str
    host: str
    client_id: str | None = None
    client_secret: str | None = None
    personal_access_token: str | None = None

    @field_validator("experiment_id")
    @classmethod
    def experiment_id_validator(cls, v, info: ValidationInfo):
        return validate_integer_id(v)
@ -1,10 +0,0 @@
[project]
name = "dify-trace-opik"
version = "0.0.1"
dependencies = [
    "opik~=1.11.2",
]
description = "Dify ops tracing provider (Opik)."

[tool.setuptools.packages.find]
where = ["src"]
@ -1,25 +0,0 @@
from pydantic import ValidationInfo, field_validator

from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.utils import validate_url_with_path


class OpikConfig(BaseTracingConfig):
    """
    Model class for Opik tracing config.
    """

    api_key: str | None = None
    project: str | None = None
    workspace: str | None = None
    url: str = "https://www.comet.com/opik/api/"

    @field_validator("project")
    @classmethod
    def project_validator(cls, v, info: ValidationInfo):
        return cls.validate_project_field(v, "Default Project")

    @field_validator("url")
    @classmethod
    def url_validator(cls, v, info: ValidationInfo):
        return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/")
@ -1,48 +0,0 @@
import pytest
from dify_trace_opik.config import OpikConfig
from pydantic import ValidationError


class TestOpikConfig:
    """Test cases for OpikConfig"""

    def test_valid_config(self):
        """Test valid Opik configuration"""
        config = OpikConfig(
            api_key="test_key",
            project="test_project",
            workspace="test_workspace",
            url="https://custom.comet.com/opik/api/",
        )
        assert config.api_key == "test_key"
        assert config.project == "test_project"
        assert config.workspace == "test_workspace"
        assert config.url == "https://custom.comet.com/opik/api/"

    def test_default_values(self):
        """Test default values are set correctly"""
        config = OpikConfig()
        assert config.api_key is None
        assert config.project is None
        assert config.workspace is None
        assert config.url == "https://www.comet.com/opik/api/"

    def test_project_validation_empty(self):
        """Test project validation with empty value"""
        config = OpikConfig(project="")
        assert config.project == "Default Project"

    def test_url_validation_empty(self):
        """Test URL validation with empty value"""
        config = OpikConfig(url="")
        assert config.url == "https://www.comet.com/opik/api/"

    def test_url_validation_missing_suffix(self):
        """Test URL validation requires /api/ suffix"""
        with pytest.raises(ValidationError, match="URL should end with /api/"):
            OpikConfig(url="https://custom.comet.com/opik/")

    def test_url_validation_invalid_scheme(self):
        """Test URL validation rejects invalid schemes"""
        with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
            OpikConfig(url="ftp://custom.comet.com/opik/api/")
@ -1,14 +0,0 @@
[project]
name = "dify-trace-tencent"
version = "0.0.1"
dependencies = [
    # versions inherited from parent
    "opentelemetry-api",
    "opentelemetry-exporter-otlp-proto-grpc",
    "opentelemetry-sdk",
    "opentelemetry-semantic-conventions",
]
description = "Dify ops tracing provider (Tencent APM)."

[tool.setuptools.packages.find]
where = ["src"]
@ -1,30 +0,0 @@
from pydantic import ValidationInfo, field_validator

from core.ops.entities.config_entity import BaseTracingConfig


class TencentConfig(BaseTracingConfig):
    """
    Tencent APM tracing config
    """

    token: str
    endpoint: str
    service_name: str

    @field_validator("token")
    @classmethod
    def token_validator(cls, v, info: ValidationInfo):
        if not v or v.strip() == "":
            raise ValueError("Token cannot be empty")
        return v

    @field_validator("endpoint")
    @classmethod
    def endpoint_validator(cls, v, info: ValidationInfo):
        return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com")

    @field_validator("service_name")
    @classmethod
    def service_name_validator(cls, v, info: ValidationInfo):
        return cls.validate_project_field(v, "dify_app")
@ -1,10 +0,0 @@
[project]
name = "dify-trace-weave"
version = "0.0.1"
dependencies = [
    "weave>=0.52.36",
]
description = "Dify ops tracing provider (Weave)."

[tool.setuptools.packages.find]
where = ["src"]
@ -1,29 +0,0 @@
from pydantic import ValidationInfo, field_validator

from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.utils import validate_url


class WeaveConfig(BaseTracingConfig):
    """
    Model class for Weave tracing config.
    """

    api_key: str
    entity: str | None = None
    project: str
    endpoint: str = "https://trace.wandb.ai"
    host: str | None = None

    @field_validator("endpoint")
    @classmethod
    def endpoint_validator(cls, v, info: ValidationInfo):
        # Weave only allows HTTPS for endpoint
        return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))

    @field_validator("host")
    @classmethod
    def host_validator(cls, v, info: ValidationInfo):
        if v is not None and v.strip() != "":
            return validate_url(v, v, allowed_schemes=("https", "http"))
        return v
@ -1,61 +0,0 @@
import pytest
from dify_trace_weave.config import WeaveConfig
from pydantic import ValidationError


class TestWeaveConfig:
    """Test cases for WeaveConfig"""

    def test_valid_config(self):
        """Test valid Weave configuration"""
        config = WeaveConfig(
            api_key="test_key",
            entity="test_entity",
            project="test_project",
            endpoint="https://custom.wandb.ai",
            host="https://custom.host.com",
        )
        assert config.api_key == "test_key"
        assert config.entity == "test_entity"
        assert config.project == "test_project"
        assert config.endpoint == "https://custom.wandb.ai"
        assert config.host == "https://custom.host.com"

    def test_default_values(self):
        """Test default values are set correctly"""
        config = WeaveConfig(api_key="key", project="project")
        assert config.entity is None
        assert config.endpoint == "https://trace.wandb.ai"
        assert config.host is None

    def test_missing_required_fields(self):
        """Test that required fields are enforced"""
        with pytest.raises(ValidationError):
            WeaveConfig()

        with pytest.raises(ValidationError):
            WeaveConfig(api_key="key")

        with pytest.raises(ValidationError):
            WeaveConfig(project="project")

    def test_endpoint_validation_https_only(self):
        """Test endpoint validation only allows HTTPS"""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai")

    def test_host_validation_optional(self):
        """Test host validation is optional but validates when provided"""
        config = WeaveConfig(api_key="key", project="project", host=None)
        assert config.host is None

        config = WeaveConfig(api_key="key", project="project", host="")
        assert config.host == ""

        config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com")
        assert config.host == "https://valid.host.com"

    def test_host_validation_invalid_scheme(self):
        """Test host validation rejects invalid schemes when provided"""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com")
@ -32,6 +32,9 @@ dependencies = [
    "flask-restx>=1.3.2,<2.0.0",
    "google-cloud-aiplatform>=1.147.0,<2.0.0",
    "httpx[socks]>=0.28.1,<1.0.0",
    "langfuse>=4.2.0,<5.0.0",
    "langsmith>=0.7.31,<1.0.0",
    "mlflow-skinny>=3.11.1,<4.0.0",
    "opentelemetry-distro>=0.62b0,<1.0.0",
    "opentelemetry-instrumentation-celery>=0.62b0,<1.0.0",
    "opentelemetry-instrumentation-flask>=0.62b0,<1.0.0",
@ -41,12 +44,15 @@ dependencies = [
    "opentelemetry-propagator-b3>=1.41.0,<2.0.0",
    "readabilipy>=0.3.0,<1.0.0",
    "resend>=2.27.0,<3.0.0",
    "weave>=0.52.36,<1.0.0",

    # Emerging: newer and fast-moving, use compatible pins
    "arize-phoenix-otel~=0.15.0",
    "fastopenapi[flask]~=0.7.0",
    "graphon~=0.1.2",
    "httpx-sse~=0.4.0",
    "json-repair~=0.59.2",
    "opik~=1.11.2",
]
# Before adding a new dependency, consider placing it in
# alphabetical order (a-z) and a suitable group.
@ -55,8 +61,8 @@ dependencies = [
packages = []

[tool.uv.workspace]
members = ["providers/vdb/*", "providers/trace/*"]
exclude = ["providers/vdb/__pycache__", "providers/trace/__pycache__"]
members = ["providers/vdb/*"]
exclude = ["providers/vdb/__pycache__"]

[tool.uv.sources]
dify-vdb-alibabacloud-mysql = { workspace = true }
@ -89,17 +95,9 @@ dify-vdb-upstash = { workspace = true }
dify-vdb-vastbase = { workspace = true }
dify-vdb-vikingdb = { workspace = true }
dify-vdb-weaviate = { workspace = true }
dify-trace-aliyun = { workspace = true }
dify-trace-arize-phoenix = { workspace = true }
dify-trace-langfuse = { workspace = true }
dify-trace-langsmith = { workspace = true }
dify-trace-mlflow = { workspace = true }
dify-trace-opik = { workspace = true }
dify-trace-tencent = { workspace = true }
dify-trace-weave = { workspace = true }

[tool.uv]
default-groups = ["storage", "tools", "vdb-all", "trace-all"]
default-groups = ["storage", "tools", "vdb-all"]
package = false
override-dependencies = [
    "pyarrow>=18.0.0",
@ -156,17 +154,17 @@ dev = [
    "types-six>=1.17.0.20260408",
    "types-tensorflow>=2.18.0.20260408",
    "types-tqdm>=4.67.3.20260408",
    "types-ujson>=5.10.0.20250822",
    "types-ujson>=5.10.0",
    "boto3-stubs>=1.42.88",
    "types-jmespath>=1.1.0.20260408",
    "hypothesis>=6.151.12",
    "types_pyOpenSSL>=24.1.0.20240722",
    "types_pyOpenSSL>=24.1.0",
    "types_cffi>=2.0.0.20260408",
    "types_setuptools>=82.0.0.20260408",
    "pandas-stubs>=3.0.0",
    "scipy-stubs>=1.15.3.0",
    "types-python-http-client>=3.3.7.20260408",
    "import-linter>=2.11",
    "import-linter>=2.3",
    "types-redis>=4.6.0.20241004",
    "celery-types>=0.23.0",
    "mypy>=1.20.1",
@ -268,25 +266,6 @@ vdb-weaviate = ["dify-vdb-weaviate"]
# Optional client used by some tests / integrations (not a vector backend plugin)
vdb-xinference = ["xinference-client>=2.4.0"]

trace-all = [
    "dify-trace-aliyun",
    "dify-trace-arize-phoenix",
    "dify-trace-langfuse",
    "dify-trace-langsmith",
    "dify-trace-mlflow",
    "dify-trace-opik",
    "dify-trace-tencent",
    "dify-trace-weave",
]
trace-aliyun = ["dify-trace-aliyun"]
trace-arize-phoenix = ["dify-trace-arize-phoenix"]
trace-langfuse = ["dify-trace-langfuse"]
trace-langsmith = ["dify-trace-langsmith"]
trace-mlflow = ["dify-trace-mlflow"]
trace-opik = ["dify-trace-opik"]
trace-tencent = ["dify-trace-tencent"]
trace-weave = ["dify-trace-weave"]

[tool.pyrefly]
project-includes = ["."]
project-excludes = [".venv", "migrations/"]

@ -34,12 +34,12 @@ core/external_data_tool/api/api.py
core/llm_generator/llm_generator.py
core/llm_generator/output_parser/structured_output.py
core/mcp/mcp_client.py
providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py
providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py
providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py
core/ops/aliyun_trace/data_exporter/traceclient.py
core/ops/arize_phoenix_trace/arize_phoenix_trace.py
core/ops/mlflow_trace/mlflow_trace.py
core/ops/ops_trace_manager.py
providers/trace/trace-tencent/src/dify_trace_tencent/client.py
providers/trace/trace-tencent/src/dify_trace_tencent/utils.py
core/ops/tencent_trace/client.py
core/ops/tencent_trace/utils.py
core/plugin/backwards_invocation/base.py
core/plugin/backwards_invocation/model.py
core/prompt/utils/extract_thread_messages.py

@ -5,8 +5,7 @@
    ".venv",
    "migrations/",
    "core/rag",
    "providers/vdb/",
    "providers/trace/*/tests",
    "providers/",
],
"typeCheckingMode": "strict",
"allowedUntypedLibraries": [

@ -18,12 +18,13 @@ from core.app.features.rate_limiting import RateLimit
from core.app.features.rate_limiting.rate_limit import rate_limit_context
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
from core.db import session_factory
from enums.quota_type import QuotaType, unlimited
from enums.quota_type import QuotaType
from extensions.otel import AppGenerateHandler, trace_span
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow, WorkflowRun
from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow_service import WorkflowService
from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task

@ -106,7 +107,7 @@ class AppGenerateService:
        quota_charge = unlimited()
        if dify_config.BILLING_ENABLED:
            try:
                quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id)
                quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id)
            except QuotaExceededError:
                raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}")

@ -116,6 +117,7 @@ class AppGenerateService:
        request_id = RateLimit.gen_request_key()
        try:
            request_id = rate_limit.enter(request_id)
            quota_charge.commit()
            effective_mode = (
                AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode
            )

@ -22,6 +22,7 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict
from models.workflow import Workflow
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
from services.quota_service import QuotaService, unlimited
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
from services.workflow_service import WorkflowService
@ -131,9 +132,10 @@ class AsyncWorkflowService:
        trigger_log = trigger_log_repo.create(trigger_log)
        session.commit()

        # 7. Check and consume quota
        # 7. Reserve quota (commit after successful dispatch)
        quota_charge = unlimited()
        try:
            QuotaType.WORKFLOW.consume(trigger_data.tenant_id)
            quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id)
        except QuotaExceededError as e:
            # Update trigger log status
            trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED
@ -153,13 +155,18 @@ class AsyncWorkflowService:
        # 9. Dispatch to appropriate queue
        task_data_dict = task_data.model_dump(mode="json")

        task: AsyncResult[Any] | None = None
        if queue_name == QueuePriority.PROFESSIONAL:
            task = execute_workflow_professional.delay(task_data_dict)
        elif queue_name == QueuePriority.TEAM:
            task = execute_workflow_team.delay(task_data_dict)
        else:  # SANDBOX
            task = execute_workflow_sandbox.delay(task_data_dict)
        try:
            task: AsyncResult[Any] | None = None
            if queue_name == QueuePriority.PROFESSIONAL:
                task = execute_workflow_professional.delay(task_data_dict)
            elif queue_name == QueuePriority.TEAM:
                task = execute_workflow_team.delay(task_data_dict)
            else:  # SANDBOX
                task = execute_workflow_sandbox.delay(task_data_dict)
            quota_charge.commit()
        except Exception:
            quota_charge.refund()
            raise

        # 10. Update trigger log with task info
        trigger_log.status = WorkflowTriggerStatus.QUEUED

@ -32,6 +32,50 @@ class SubscriptionPlan(TypedDict):
    expiration_date: int


class QuotaReserveResult(TypedDict):
    reservation_id: str
    available: int
    reserved: int


class QuotaCommitResult(TypedDict):
    available: int
    reserved: int
    refunded: int


class QuotaReleaseResult(TypedDict):
    available: int
    reserved: int
    released: int


_quota_reserve_adapter = TypeAdapter(QuotaReserveResult)
_quota_commit_adapter = TypeAdapter(QuotaCommitResult)
_quota_release_adapter = TypeAdapter(QuotaReleaseResult)


class _TenantFeatureQuota(TypedDict):
    usage: int
    limit: int
    reset_date: NotRequired[int]


class TenantFeatureQuotaInfo(TypedDict):
    """Response of /quota/info.

    NOTE (hj24):
    - Same convention as BillingInfo: billing may return int fields as str,
      always keep non-strict mode to auto-coerce.
    """

    trigger_event: _TenantFeatureQuota
    api_rate_limit: _TenantFeatureQuota


_tenant_feature_quota_info_adapter = TypeAdapter(TenantFeatureQuotaInfo)


class _BillingQuota(TypedDict):
    size: int
    limit: int
@ -149,11 +193,63 @@ class BillingService:

    @classmethod
    def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
        """Deprecated: Use get_quota_info instead."""
        params = {"tenant_id": tenant_id}

        usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params)
        return usage_info

    @classmethod
    def get_quota_info(cls, tenant_id: str) -> TenantFeatureQuotaInfo:
        params = {"tenant_id": tenant_id}
        return _tenant_feature_quota_info_adapter.validate_python(
            cls._send_request("GET", "/quota/info", params=params)
        )

    @classmethod
    def quota_reserve(
        cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None
    ) -> QuotaReserveResult:
        """Reserve quota before task execution."""
        payload: dict = {
            "tenant_id": tenant_id,
            "feature_key": feature_key,
            "request_id": request_id,
            "amount": amount,
        }
        if meta:
            payload["meta"] = meta
        return _quota_reserve_adapter.validate_python(cls._send_request("POST", "/quota/reserve", json=payload))

    @classmethod
    def quota_commit(
        cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None
    ) -> QuotaCommitResult:
        """Commit a reservation with actual consumption."""
        payload: dict = {
            "tenant_id": tenant_id,
            "feature_key": feature_key,
            "reservation_id": reservation_id,
            "actual_amount": actual_amount,
        }
        if meta:
            payload["meta"] = meta
        return _quota_commit_adapter.validate_python(cls._send_request("POST", "/quota/commit", json=payload))

    @classmethod
    def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult:
        """Release a reservation (cancel, return frozen quota)."""
        return _quota_release_adapter.validate_python(
            cls._send_request(
                "POST",
                "/quota/release",
                json={
                    "tenant_id": tenant_id,
                    "feature_key": feature_key,
                    "reservation_id": reservation_id,
                },
            )
        )

    @classmethod
    def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict:
        params = {"tenant_id": tenant_id}
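As an aside on the coercion convention noted in `TenantFeatureQuotaInfo` above, a minimal self-contained sketch of how `TypeAdapter` in its default (non-strict) mode coerces string-encoded ints; the payload values here are invented:

```python
from typing import TypedDict

from pydantic import TypeAdapter


class QuotaReserveResult(TypedDict):
    reservation_id: str
    available: int
    reserved: int


adapter = TypeAdapter(QuotaReserveResult)
# Non-strict mode coerces "100" -> 100, which tolerates the billing API's
# habit of returning int fields as strings.
result = adapter.validate_python({"reservation_id": "r-1", "available": "100", "reserved": 1})
assert result["available"] == 100
```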
@ -283,7 +283,7 @@ class FeatureService:
    def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
        billing_info = BillingService.get_info(tenant_id)

        features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id)
        features_usage_info = BillingService.get_quota_info(tenant_id)

        features.billing.enabled = billing_info["enabled"]
        features.billing.subscription.plan = billing_info["subscription"]["plan"]

api/services/quota_service.py (new file, 233 lines)
@ -0,0 +1,233 @@
from __future__ import annotations

import logging
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

from configs import dify_config

if TYPE_CHECKING:
    from enums.quota_type import QuotaType

logger = logging.getLogger(__name__)


@dataclass
class QuotaCharge:
    """
    Result of a quota reservation (Reserve phase).

    Lifecycle:
        charge = QuotaService.reserve(QuotaType.TRIGGER, tenant_id)
        try:
            do_work()
            charge.commit()  # Confirm consumption
        except:
            charge.refund()  # Release frozen quota

    If neither commit() nor refund() is called, the billing system's
    cleanup CronJob will auto-release the reservation within ~75 seconds.
    """

    success: bool
    charge_id: str | None  # reservation_id
    _quota_type: QuotaType
    _tenant_id: str | None = None
    _feature_key: str | None = None
    _amount: int = 0
    _committed: bool = field(default=False, repr=False)

    def commit(self, actual_amount: int | None = None) -> None:
        """
        Confirm the consumption with the actual amount.

        Args:
            actual_amount: Actual amount consumed. Defaults to the reserved amount.
                If less than reserved, the difference is refunded automatically.
        """
        if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key:
            return

        try:
            from services.billing_service import BillingService

            amount = actual_amount if actual_amount is not None else self._amount
            BillingService.quota_commit(
                tenant_id=self._tenant_id,
                feature_key=self._feature_key,
                reservation_id=self.charge_id,
                actual_amount=amount,
            )
            self._committed = True
            logger.debug(
                "Committed %s quota for tenant %s, reservation_id: %s, amount: %d",
                self._quota_type,
                self._tenant_id,
                self.charge_id,
                amount,
            )
        except Exception:
            logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id)

    def refund(self) -> None:
        """
        Release the reserved quota (cancel the charge).

        Safe to call even if:
        - the charge failed or was disabled (charge_id is None)
        - already committed (Release after Commit is a no-op)
        - already refunded (idempotent)

        This method guarantees no exceptions will be raised.
        """
        if not self.charge_id or not self._tenant_id or not self._feature_key:
            return

        QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key)


def unlimited() -> QuotaCharge:
    from enums.quota_type import QuotaType

    return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)


class QuotaService:
    """Orchestrates the quota reserve / commit / release lifecycle via BillingService."""

    @staticmethod
    def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
        """
        Reserve + immediate Commit (one-shot mode).

        The returned QuotaCharge supports .refund(), which calls Release.
        For two-phase usage (e.g. streaming), use reserve() directly.
        """
        charge = QuotaService.reserve(quota_type, tenant_id, amount)
        if charge.success and charge.charge_id:
            charge.commit()
        return charge

    @staticmethod
    def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge:
        """
        Reserve quota before task execution (Reserve phase only).

        The caller MUST call charge.commit() after the task succeeds,
        or charge.refund() if the task fails.

        Raises:
            QuotaExceededError: When quota is insufficient
        """
        from services.billing_service import BillingService
        from services.errors.app import QuotaExceededError

        if not dify_config.BILLING_ENABLED:
            logger.debug("Billing disabled, allowing request for %s", tenant_id)
            return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type)

        logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id)

        if amount <= 0:
            raise ValueError("Amount to reserve must be greater than 0")

        request_id = str(uuid.uuid4())
        feature_key = quota_type.billing_key

        try:
            reserve_resp = BillingService.quota_reserve(
                tenant_id=tenant_id,
                feature_key=feature_key,
                request_id=request_id,
                amount=amount,
            )

            reservation_id = reserve_resp.get("reservation_id")
            if not reservation_id:
                logger.warning(
                    "Reserve returned no reservation_id for %s, feature %s, response: %s",
                    tenant_id,
                    quota_type.value,
                    reserve_resp,
                )
                raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount)

            logger.debug(
                "Reserved %d %s quota for tenant %s, reservation_id: %s",
                amount,
                quota_type.value,
                tenant_id,
                reservation_id,
            )
            return QuotaCharge(
                success=True,
                charge_id=reservation_id,
                _quota_type=quota_type,
                _tenant_id=tenant_id,
                _feature_key=feature_key,
                _amount=amount,
            )

        except QuotaExceededError:
            raise
        except ValueError:
            raise
        except Exception:
            logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value)
            return unlimited()

    @staticmethod
    def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool:
        if not dify_config.BILLING_ENABLED:
            return True

        if amount <= 0:
            raise ValueError("Amount to check must be greater than 0")

        try:
            remaining = QuotaService.get_remaining(quota_type, tenant_id)
            return remaining >= amount if remaining != -1 else True
        except Exception:
            logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value)
            return True

    @staticmethod
    def release(quota_type: QuotaType, reservation_id: str, tenant_id: str, feature_key: str) -> None:
        """Release a reservation. Guarantees no exceptions."""
        try:
            from services.billing_service import BillingService

            if not dify_config.BILLING_ENABLED:
                return

            if not reservation_id:
                return

            logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id)
            BillingService.quota_release(
                tenant_id=tenant_id,
                feature_key=feature_key,
                reservation_id=reservation_id,
            )
        except Exception:
            logger.exception("Failed to release quota, reservation_id: %s", reservation_id)

    @staticmethod
    def get_remaining(quota_type: QuotaType, tenant_id: str) -> int:
        from services.billing_service import BillingService

        try:
            usage_info = BillingService.get_quota_info(tenant_id)
            if isinstance(usage_info, dict):
                feature_info = usage_info.get(quota_type.billing_key, {})
                if isinstance(feature_info, dict):
                    limit = feature_info.get("limit", 0)
                    usage = feature_info.get("usage", 0)
                    if limit == -1:
                        return -1
                    return max(0, limit - usage)
            return 0
        except Exception:
            logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value)
            return -1
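Taken together, the call sites changed in this branch all follow the same two-phase shape. A minimal sketch using only names defined in the file above; `do_work` is a hypothetical placeholder for the dispatch or execution step:

```python
from enums.quota_type import QuotaType
from services.quota_service import QuotaService


def do_work() -> None:
    """Hypothetical placeholder for the dispatch / execution step."""


def run_with_quota(tenant_id: str) -> None:
    # Phase 1: reserve; raises QuotaExceededError when the tenant is out of quota
    charge = QuotaService.reserve(QuotaType.WORKFLOW, tenant_id)
    try:
        do_work()
        charge.commit()  # Phase 2: confirm actual consumption
    except Exception:
        charge.refund()  # return the frozen quota; guaranteed not to raise
        raise
```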
@ -349,6 +349,7 @@ class SummaryIndexService:
            summary_record_id,
        )
        summary_record_in_session = DocumentSegmentSummary(
            id=summary_record_id,  # Use the same ID if available
            dataset_id=dataset.id,
            document_id=segment.document_id,
            chunk_id=segment.id,
@ -359,9 +360,6 @@ class SummaryIndexService:
            status=SummaryStatus.COMPLETED,
            enabled=True,
        )
        if summary_record_in_session is None:
            raise RuntimeError("summary_record_in_session should not be None at this point")
        summary_record_in_session.id = summary_record_id
        session.add(summary_record_in_session)
        logger.info(
            "Created new summary record (id=%s) for segment %s after vectorization",

@ -38,6 +38,7 @@ from models.workflow import Workflow
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService
from services.trigger.app_trigger_service import AppTriggerService
from services.workflow.entities import WebhookTriggerData

@ -819,9 +820,9 @@ class WebhookService:
                user_id=None,
            )

            # consume quota before triggering workflow execution
            # reserve quota before triggering workflow execution
            try:
                QuotaType.TRIGGER.consume(webhook_trigger.tenant_id)
                quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id)
            except QuotaExceededError:
                AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id)
                logger.info(
@ -832,11 +833,16 @@ class WebhookService:
                raise

            # Trigger workflow execution asynchronously
            AsyncWorkflowService.trigger_workflow_async(
                session,
                end_user,
                trigger_data,
            )
            try:
                AsyncWorkflowService.trigger_workflow_async(
                    session,
                    end_user,
                    trigger_data,
                )
                quota_charge.commit()
            except Exception:
                quota_charge.refund()
                raise

        except Exception:
            logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id)

@ -27,7 +27,7 @@ from core.trigger.entities.entities import TriggerProviderEntity
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from enums.quota_type import QuotaType, unlimited
from enums.quota_type import QuotaType
from graphon.enums import WorkflowExecutionStatus
from models.enums import (
    AppTriggerType,
@ -42,6 +42,7 @@ from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom,
from services.async_workflow_service import AsyncWorkflowService
from services.end_user_service import EndUserService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService, unlimited
from services.trigger.app_trigger_service import AppTriggerService
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_request_service import TriggerHttpRequestCachingService
@ -298,10 +299,10 @@ def dispatch_triggered_workflow(
        icon_dark_filename=trigger_entity.identity.icon_dark or "",
    )

    # consume quota before invoking trigger
    # reserve quota before invoking trigger
    quota_charge = unlimited()
    try:
        quota_charge = QuotaType.TRIGGER.consume(subscription.tenant_id)
        quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id)
    except QuotaExceededError:
        AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id)
        logger.info(
@ -387,6 +388,7 @@ def dispatch_triggered_workflow(
            raise ValueError(f"End user not found for app {plugin_trigger.app_id}")

        AsyncWorkflowService.trigger_workflow_async(session=session, user=end_user, trigger_data=trigger_data)
        quota_charge.commit()
        dispatched_count += 1
        logger.info(
            "Triggered workflow for app %s with trigger event %s",

@ -8,10 +8,11 @@ from core.workflow.nodes.trigger_schedule.exc import (
    ScheduleNotFoundError,
    TenantOwnerNotFoundError,
)
from enums.quota_type import QuotaType, unlimited
from enums.quota_type import QuotaType
from models.trigger import WorkflowSchedulePlan
from services.async_workflow_service import AsyncWorkflowService
from services.errors.app import QuotaExceededError
from services.quota_service import QuotaService, unlimited
from services.trigger.app_trigger_service import AppTriggerService
from services.trigger.schedule_service import ScheduleService
from services.workflow.entities import ScheduleTriggerData
@ -43,7 +44,7 @@ def run_schedule_trigger(schedule_id: str) -> None:

    quota_charge = unlimited()
    try:
        quota_charge = QuotaType.TRIGGER.consume(schedule.tenant_id)
        quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id)
    except QuotaExceededError:
        AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id)
        logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id)
@ -61,6 +62,7 @@ def run_schedule_trigger(schedule_id: str) -> None:
                tenant_id=schedule.tenant_id,
            ),
        )
        quota_charge.commit()
        logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id)
    except Exception as e:
        quota_charge.refund()

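Both hunks above replace the one-shot QuotaType.TRIGGER.consume(...) with a reserve/commit/refund cycle: reserve before dispatching, commit only after the workflow was handed off, refund when dispatch fails after the hold succeeded. The actual QuotaService and charge object are not part of this diff; the following is only a minimal sketch of how such a charge could behave, with every method body an assumption:

# Illustrative sketch only: the real QuotaService/charge object are not shown in this diff.
class QuotaCharge:
    def __init__(self, reservation_id: str | None = None):
        self.reservation_id = reservation_id  # None means "unlimited", nothing to settle
        self.settled = False

    def commit(self) -> None:
        # Finalize the hold once the workflow was actually dispatched.
        if self.reservation_id and not self.settled:
            ...  # confirm the reservation with the billing backend
        self.settled = True

    def refund(self) -> None:
        # Release the hold if dispatch failed after the reservation succeeded.
        if self.reservation_id and not self.settled:
            ...  # return the reserved amount to the billing backend
        self.settled = True


def unlimited() -> QuotaCharge:
    # No-op charge for tenants without quota enforcement; commit/refund do nothing.
    return QuotaCharge(reservation_id=None)
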
@ -36,12 +36,19 @@ class TestAppGenerateService:
            ) as mock_message_based_generator,
            patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service,
            patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config,
            patch("services.quota_service.dify_config", autospec=True) as mock_quota_dify_config,
            patch("configs.dify_config", autospec=True) as mock_global_dify_config,
        ):
            # Setup default mock returns for billing service
            mock_billing_service.update_tenant_feature_plan_usage.return_value = {
                "result": "success",
                "history_id": "test_history_id",
            mock_billing_service.quota_reserve.return_value = {
                "reservation_id": "test-reservation-id",
                "available": 100,
                "reserved": 1,
            }
            mock_billing_service.quota_commit.return_value = {
                "available": 99,
                "reserved": 0,
                "refunded": 0,
            }

            # Setup default mock returns for workflow service
@ -101,6 +108,8 @@ class TestAppGenerateService:
            mock_dify_config.APP_DEFAULT_ACTIVE_REQUESTS = 100
            mock_dify_config.APP_DAILY_RATE_LIMIT = 1000

            mock_quota_dify_config.BILLING_ENABLED = False

            mock_global_dify_config.BILLING_ENABLED = False
            mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100
            mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000
@ -118,6 +127,7 @@ class TestAppGenerateService:
                "message_based_generator": mock_message_based_generator,
                "account_feature_service": mock_account_feature_service,
                "dify_config": mock_dify_config,
                "quota_dify_config": mock_quota_dify_config,
                "global_dify_config": mock_global_dify_config,
            }

@ -465,6 +475,7 @@ class TestAppGenerateService:

        # Set BILLING_ENABLED to True for this test
        mock_external_service_dependencies["dify_config"].BILLING_ENABLED = True
        mock_external_service_dependencies["quota_dify_config"].BILLING_ENABLED = True
        mock_external_service_dependencies["global_dify_config"].BILLING_ENABLED = True

        # Setup test arguments
@ -478,8 +489,10 @@ class TestAppGenerateService:
        # Verify the result
        assert result == ["test_response"]

        # Verify billing service was called to consume quota
        mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once()
        # Verify billing two-phase quota (reserve + commit)
        billing = mock_external_service_dependencies["billing_service"]
        billing.quota_reserve.assert_called_once()
        billing.quota_commit.assert_called_once()

    def test_generate_with_invalid_app_mode(
        self, db_session_with_containers: Session, mock_external_service_dependencies

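The fixture above now stubs the two billing phases separately: quota_reserve hands back a reservation id, quota_commit reports the settled balance. A hedged test-double sketch that reproduces those response shapes (the field names are copied from the mocked returns above; the class itself is hypothetical, not a real Dify service):

# Hypothetical in-memory double; response shapes mirror the mocked returns above.
class FakeBillingService:
    def __init__(self, available: int = 100) -> None:
        self.available = available
        self.reserved = 0

    def quota_reserve(self, tenant_id: str, amount: int = 1) -> dict:
        self.reserved += amount
        return {"reservation_id": "test-reservation-id", "available": self.available, "reserved": self.reserved}

    def quota_commit(self, reservation_id: str) -> dict:
        self.available -= self.reserved
        self.reserved = 0
        return {"available": self.available, "reserved": 0, "refunded": 0}
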
@ -1,650 +0,0 @@
"""Testcontainers integration tests for SQL-backed DocumentService paths."""

import datetime
import json
from unittest.mock import create_autospec, patch
from uuid import uuid4

import pytest
from werkzeug.exceptions import Forbidden, NotFound

from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
from models import Account
from models.dataset import Dataset, Document
from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus
from models.model import UploadFile
from services.dataset_service import DocumentService
from services.errors.account import NoPermissionError

FIXED_UPLOAD_CREATED_AT = datetime.datetime(2024, 1, 1, 0, 0, 0)


class DocumentServiceIntegrationFactory:
    @staticmethod
    def create_dataset(
        db_session_with_containers,
        *,
        tenant_id: str | None = None,
        created_by: str | None = None,
        name: str | None = None,
    ) -> Dataset:
        dataset = Dataset(
            tenant_id=tenant_id or str(uuid4()),
            name=name or f"dataset-{uuid4()}",
            data_source_type=DataSourceType.UPLOAD_FILE,
            created_by=created_by or str(uuid4()),
        )
        db_session_with_containers.add(dataset)
        db_session_with_containers.commit()
        return dataset

    @staticmethod
    def create_document(
        db_session_with_containers,
        *,
        dataset: Dataset,
        name: str = "doc.txt",
        position: int = 1,
        tenant_id: str | None = None,
        indexing_status: str = IndexingStatus.COMPLETED,
        enabled: bool = True,
        archived: bool = False,
        is_paused: bool = False,
        need_summary: bool = False,
        doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
        batch: str | None = None,
        data_source_type: str = DataSourceType.UPLOAD_FILE,
        data_source_info: dict | None = None,
        created_by: str | None = None,
    ) -> Document:
        document = Document(
            tenant_id=tenant_id or dataset.tenant_id,
            dataset_id=dataset.id,
            position=position,
            data_source_type=data_source_type,
            data_source_info=json.dumps(data_source_info or {}),
            batch=batch or f"batch-{uuid4()}",
            name=name,
            created_from=DocumentCreatedFrom.WEB,
            created_by=created_by or dataset.created_by,
            doc_form=doc_form,
        )
        document.indexing_status = indexing_status
        document.enabled = enabled
        document.archived = archived
        document.is_paused = is_paused
        document.need_summary = need_summary
        if indexing_status == IndexingStatus.COMPLETED:
            document.completed_at = FIXED_UPLOAD_CREATED_AT
        db_session_with_containers.add(document)
        db_session_with_containers.commit()
        return document

    @staticmethod
    def create_upload_file(
        db_session_with_containers,
        *,
        tenant_id: str,
        created_by: str,
        file_id: str | None = None,
        name: str = "source.txt",
    ) -> UploadFile:
        upload_file = UploadFile(
            tenant_id=tenant_id,
            storage_type=StorageType.LOCAL,
            key=f"uploads/{uuid4()}",
            name=name,
            size=128,
            extension="txt",
            mime_type="text/plain",
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by=created_by,
            created_at=FIXED_UPLOAD_CREATED_AT,
            used=False,
        )
        if file_id:
            upload_file.id = file_id
        db_session_with_containers.add(upload_file)
        db_session_with_containers.commit()
        return upload_file


@pytest.fixture
def current_user_mock():
    with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user:
        current_user.id = str(uuid4())
        current_user.current_tenant_id = str(uuid4())
        current_user.current_role = None
        yield current_user


def test_get_document_returns_none_when_document_id_is_missing(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)

    assert DocumentService.get_document(dataset.id, None) is None


def test_get_document_queries_by_dataset_and_document_id(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    document = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset)

    result = DocumentService.get_document(dataset.id, document.id)

    assert result is not None
    assert result.id == document.id


def test_get_documents_by_ids_returns_empty_for_empty_input(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)

    result = DocumentService.get_documents_by_ids(dataset.id, [])

    assert result == []


def test_get_documents_by_ids_uses_single_batch_query(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    doc_a = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, name="a.txt")
    doc_b = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        name="b.txt",
        position=2,
    )

    result = DocumentService.get_documents_by_ids(dataset.id, [doc_a.id, doc_b.id])

    assert {document.id for document in result} == {doc_a.id, doc_b.id}


def test_update_documents_need_summary_returns_zero_for_empty_input(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)

    assert DocumentService.update_documents_need_summary(dataset.id, []) == 0


def test_update_documents_need_summary_updates_matching_non_qa_documents(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    paragraph_doc = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        need_summary=True,
    )
    qa_doc = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        position=2,
        need_summary=True,
        doc_form=IndexStructureType.QA_INDEX,
    )

    updated_count = DocumentService.update_documents_need_summary(
        dataset.id,
        [paragraph_doc.id, qa_doc.id],
        need_summary=False,
    )

    db_session_with_containers.expire_all()
    refreshed_paragraph = db_session_with_containers.get(Document, paragraph_doc.id)
    refreshed_qa = db_session_with_containers.get(Document, qa_doc.id)
    assert updated_count == 1
    assert refreshed_paragraph is not None
    assert refreshed_qa is not None
    assert refreshed_paragraph.need_summary is False
    assert refreshed_qa.need_summary is True


def test_get_document_download_url_uses_signed_url_helper(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    upload_file = DocumentServiceIntegrationFactory.create_upload_file(
        db_session_with_containers,
        tenant_id=dataset.tenant_id,
        created_by=dataset.created_by,
    )
    document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        data_source_info={"upload_file_id": upload_file.id},
    )

    with patch("services.dataset_service.file_helpers.get_signed_file_url", return_value="signed-url") as get_url:
        result = DocumentService.get_document_download_url(document)

    assert result == "signed-url"
    get_url.assert_called_once_with(upload_file_id=upload_file.id, as_attachment=True)


def test_get_upload_file_id_for_upload_file_document_rejects_invalid_source_type(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        data_source_type=DataSourceType.WEBSITE_CRAWL,
        data_source_info={"url": "https://example.com"},
    )

    with pytest.raises(NotFound, match="invalid source"):
        DocumentService._get_upload_file_id_for_upload_file_document(
            document,
            invalid_source_message="invalid source",
            missing_file_message="missing file",
        )


def test_get_upload_file_id_for_upload_file_document_rejects_missing_upload_file_id(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        data_source_info={},
    )

    with pytest.raises(NotFound, match="missing file"):
        DocumentService._get_upload_file_id_for_upload_file_document(
            document,
            invalid_source_message="invalid source",
            missing_file_message="missing file",
        )


def test_get_upload_file_id_for_upload_file_document_returns_string_id(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        data_source_info={"upload_file_id": 99},
    )

    result = DocumentService._get_upload_file_id_for_upload_file_document(
        document,
        invalid_source_message="invalid source",
        missing_file_message="missing file",
    )

    assert result == "99"


def test_get_upload_file_for_upload_file_document_raises_when_file_service_returns_nothing(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        data_source_info={"upload_file_id": "missing-file"},
    )

    with patch("services.dataset_service.FileService.get_upload_files_by_ids", return_value={}):
        with pytest.raises(NotFound, match="Uploaded file not found"):
            DocumentService._get_upload_file_for_upload_file_document(document)


def test_get_upload_file_for_upload_file_document_returns_upload_file(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    upload_file = DocumentServiceIntegrationFactory.create_upload_file(
        db_session_with_containers,
        tenant_id=dataset.tenant_id,
        created_by=dataset.created_by,
    )
    document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        data_source_info={"upload_file_id": upload_file.id},
    )

    result = DocumentService._get_upload_file_for_upload_file_document(document)

    assert result.id == upload_file.id


def test_get_upload_files_by_document_id_for_zip_download_raises_for_missing_documents(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)

    with pytest.raises(NotFound, match="Document not found"):
        DocumentService._get_upload_files_by_document_id_for_zip_download(
            dataset_id=dataset.id,
            document_ids=[str(uuid4())],
            tenant_id=dataset.tenant_id,
        )


def test_get_upload_files_by_document_id_for_zip_download_rejects_cross_tenant_access(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    upload_file = DocumentServiceIntegrationFactory.create_upload_file(
        db_session_with_containers,
        tenant_id=dataset.tenant_id,
        created_by=dataset.created_by,
    )
    document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        tenant_id=str(uuid4()),
        data_source_info={"upload_file_id": upload_file.id},
    )

    with pytest.raises(Forbidden, match="No permission"):
        DocumentService._get_upload_files_by_document_id_for_zip_download(
            dataset_id=dataset.id,
            document_ids=[document.id],
            tenant_id=dataset.tenant_id,
        )


def test_get_upload_files_by_document_id_for_zip_download_rejects_missing_upload_files(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        data_source_info={"upload_file_id": str(uuid4())},
    )

    with pytest.raises(NotFound, match="Only uploaded-file documents can be downloaded as ZIP"):
        DocumentService._get_upload_files_by_document_id_for_zip_download(
            dataset_id=dataset.id,
            document_ids=[document.id],
            tenant_id=dataset.tenant_id,
        )


def test_get_upload_files_by_document_id_for_zip_download_returns_document_keyed_mapping(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    upload_file_a = DocumentServiceIntegrationFactory.create_upload_file(
        db_session_with_containers,
        tenant_id=dataset.tenant_id,
        created_by=dataset.created_by,
        name="a.txt",
    )
    upload_file_b = DocumentServiceIntegrationFactory.create_upload_file(
        db_session_with_containers,
        tenant_id=dataset.tenant_id,
        created_by=dataset.created_by,
        name="b.txt",
    )
    document_a = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        data_source_info={"upload_file_id": upload_file_a.id},
    )
    document_b = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        position=2,
        data_source_info={"upload_file_id": upload_file_b.id},
    )

    mapping = DocumentService._get_upload_files_by_document_id_for_zip_download(
        dataset_id=dataset.id,
        document_ids=[document_a.id, document_b.id],
        tenant_id=dataset.tenant_id,
    )

    assert mapping[document_a.id].id == upload_file_a.id
    assert mapping[document_b.id].id == upload_file_b.id


def test_prepare_document_batch_download_zip_raises_not_found_for_missing_dataset(
    current_user_mock, flask_app_with_containers
):
    with flask_app_with_containers.app_context():
        with pytest.raises(NotFound, match="Dataset not found"):
            DocumentService.prepare_document_batch_download_zip(
                dataset_id=str(uuid4()),
                document_ids=[str(uuid4())],
                tenant_id=current_user_mock.current_tenant_id,
                current_user=current_user_mock,
            )


def test_prepare_document_batch_download_zip_translates_permission_error_to_forbidden(
    db_session_with_containers,
    current_user_mock,
):
    dataset = DocumentServiceIntegrationFactory.create_dataset(
        db_session_with_containers,
        tenant_id=current_user_mock.current_tenant_id,
        created_by=current_user_mock.id,
    )

    with patch(
        "services.dataset_service.DatasetService.check_dataset_permission",
        side_effect=NoPermissionError("denied"),
    ):
        with pytest.raises(Forbidden, match="denied"):
            DocumentService.prepare_document_batch_download_zip(
                dataset_id=dataset.id,
                document_ids=[],
                tenant_id=current_user_mock.current_tenant_id,
                current_user=current_user_mock,
            )


def test_prepare_document_batch_download_zip_returns_upload_files_in_requested_order(
    db_session_with_containers,
    current_user_mock,
):
    dataset = DocumentServiceIntegrationFactory.create_dataset(
        db_session_with_containers,
        tenant_id=current_user_mock.current_tenant_id,
        created_by=current_user_mock.id,
    )
    upload_file_a = DocumentServiceIntegrationFactory.create_upload_file(
        db_session_with_containers,
        tenant_id=dataset.tenant_id,
        created_by=dataset.created_by,
        name="a.txt",
    )
    upload_file_b = DocumentServiceIntegrationFactory.create_upload_file(
        db_session_with_containers,
        tenant_id=dataset.tenant_id,
        created_by=dataset.created_by,
        name="b.txt",
    )
    document_a = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        data_source_info={"upload_file_id": upload_file_a.id},
    )
    document_b = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        position=2,
        data_source_info={"upload_file_id": upload_file_b.id},
    )

    upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
        dataset_id=dataset.id,
        document_ids=[document_b.id, document_a.id],
        tenant_id=current_user_mock.current_tenant_id,
        current_user=current_user_mock,
    )

    assert [upload_file.id for upload_file in upload_files] == [upload_file_b.id, upload_file_a.id]
    assert download_name.endswith(".zip")


def test_get_document_by_dataset_id_returns_enabled_documents(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    enabled_document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        enabled=True,
    )
    DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        position=2,
        enabled=False,
    )

    result = DocumentService.get_document_by_dataset_id(dataset.id)

    assert [document.id for document in result] == [enabled_document.id]


def test_get_working_documents_by_dataset_id_returns_completed_enabled_unarchived_documents(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    available_document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        indexing_status=IndexingStatus.COMPLETED,
        enabled=True,
        archived=False,
    )
    DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        position=2,
        indexing_status=IndexingStatus.ERROR,
    )

    result = DocumentService.get_working_documents_by_dataset_id(dataset.id)

    assert [document.id for document in result] == [available_document.id]


def test_get_error_documents_by_dataset_id_returns_error_and_paused_documents(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    error_document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        indexing_status=IndexingStatus.ERROR,
    )
    paused_document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        position=2,
        indexing_status=IndexingStatus.PAUSED,
    )
    DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        position=3,
        indexing_status=IndexingStatus.COMPLETED,
    )

    result = DocumentService.get_error_documents_by_dataset_id(dataset.id)

    assert {document.id for document in result} == {error_document.id, paused_document.id}


def test_get_batch_documents_filters_by_current_user_tenant(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    batch = f"batch-{uuid4()}"
    matching_document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        batch=batch,
    )
    DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        position=2,
        tenant_id=str(uuid4()),
        batch=batch,
    )

    with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user:
        current_user.current_tenant_id = dataset.tenant_id
        result = DocumentService.get_batch_documents(dataset.id, batch)

    assert [document.id for document in result] == [matching_document.id]


def test_get_document_file_detail_returns_upload_file(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    upload_file = DocumentServiceIntegrationFactory.create_upload_file(
        db_session_with_containers,
        tenant_id=dataset.tenant_id,
        created_by=dataset.created_by,
    )

    result = DocumentService.get_document_file_detail(upload_file.id)

    assert result is not None
    assert result.id == upload_file.id


def test_delete_document_emits_signal_and_commits(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    upload_file = DocumentServiceIntegrationFactory.create_upload_file(
        db_session_with_containers,
        tenant_id=dataset.tenant_id,
        created_by=dataset.created_by,
    )
    document = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        data_source_info={"upload_file_id": upload_file.id},
    )

    with patch("services.dataset_service.document_was_deleted.send") as signal_send:
        DocumentService.delete_document(document)

    assert db_session_with_containers.get(Document, document.id) is None
    signal_send.assert_called_once_with(
        document.id,
        dataset_id=document.dataset_id,
        doc_form=document.doc_form,
        file_id=upload_file.id,
    )


def test_delete_documents_ignores_empty_input(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)

    with patch("services.dataset_service.batch_clean_document_task.delay") as delay:
        DocumentService.delete_documents(dataset, [])

    delay.assert_not_called()


def test_delete_documents_deletes_rows_and_dispatches_cleanup_task(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    dataset.chunk_structure = IndexStructureType.PARAGRAPH_INDEX
    db_session_with_containers.commit()
    upload_file_a = DocumentServiceIntegrationFactory.create_upload_file(
        db_session_with_containers,
        tenant_id=dataset.tenant_id,
        created_by=dataset.created_by,
        name="a.txt",
    )
    upload_file_b = DocumentServiceIntegrationFactory.create_upload_file(
        db_session_with_containers,
        tenant_id=dataset.tenant_id,
        created_by=dataset.created_by,
        name="b.txt",
    )
    document_a = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        data_source_info={"upload_file_id": upload_file_a.id},
    )
    document_b = DocumentServiceIntegrationFactory.create_document(
        db_session_with_containers,
        dataset=dataset,
        position=2,
        data_source_info={"upload_file_id": upload_file_b.id},
    )

    with patch("services.dataset_service.batch_clean_document_task.delay") as delay:
        DocumentService.delete_documents(dataset, [document_a.id, document_b.id])

    assert db_session_with_containers.get(Document, document_a.id) is None
    assert db_session_with_containers.get(Document, document_b.id) is None
    delay.assert_called_once()
    args = delay.call_args.args
    assert args[0] == [document_a.id, document_b.id]
    assert args[1] == dataset.id
    assert set(args[3]) == {upload_file_a.id, upload_file_b.id}


def test_get_documents_position_returns_next_position_when_documents_exist(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
    DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, position=3)

    assert DocumentService.get_documents_position(dataset.id) == 4


def test_get_documents_position_defaults_to_one_when_dataset_is_empty(db_session_with_containers):
    dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)

    assert DocumentService.get_documents_position(dataset.id) == 1

@ -605,9 +605,9 @@ def test_schedule_trigger_creates_trigger_log(
    )

    # Mock quota to avoid rate limiting
    from enums import quota_type
    from services import quota_service

    monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited())
    monkeypatch.setattr(quota_service.QuotaService, "reserve", lambda *_args, **_kwargs: quota_service.unlimited())

    # Execute schedule trigger
    workflow_schedule_tasks.run_schedule_trigger(plan.id)

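The monkeypatched stub above swaps QuotaService.reserve for a no-op unlimited() charge so the schedule test never trips rate limiting. Wrapped as a reusable fixture, the same idea might look like this (a sketch; only the patched attribute and import paths come from the diff):

import pytest


@pytest.fixture
def no_quota_limits(monkeypatch):
    # Sketch: stub two-phase quota out entirely so tests never get rate limited.
    from services import quota_service

    monkeypatch.setattr(
        quota_service.QuotaService,
        "reserve",
        lambda *_args, **_kwargs: quota_service.unlimited(),
    )
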
@ -41,22 +41,17 @@ class TestTenantUserPayload:
class TestGetUser:
    """Test get_user function"""

    @patch("controllers.inner_api.plugin.wraps.select")
    @patch("controllers.inner_api.plugin.wraps.EndUser")
    @patch("controllers.inner_api.plugin.wraps.sessionmaker")
    @patch("controllers.inner_api.plugin.wraps.db")
    def test_should_return_existing_user_by_id(
        self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask
    ):
    def test_should_return_existing_user_by_id(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask):
        """Test returning existing user when found by ID"""
        # Arrange
        mock_user = MagicMock()
        mock_user.id = "user123"
        mock_session = MagicMock()
        mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
        mock_session.scalar.return_value = mock_user
        mock_query = MagicMock()
        mock_select.return_value.where.return_value.limit.return_value = mock_query
        mock_session.get.return_value = mock_user

        # Act
        with app.app_context():
@ -64,45 +59,13 @@ class TestGetUser:

        # Assert
        assert result == mock_user
        mock_session.scalar.assert_called_once()
        mock_session.get.assert_called_once()

    @patch("controllers.inner_api.plugin.wraps.select")
    @patch("controllers.inner_api.plugin.wraps.EndUser")
    @patch("controllers.inner_api.plugin.wraps.sessionmaker")
    @patch("controllers.inner_api.plugin.wraps.db")
    def test_should_not_resolve_non_anonymous_users_across_tenants(
        self,
        mock_db,
        mock_sessionmaker,
        mock_enduser_class,
        mock_select,
        app: Flask,
    ):
        """Test that explicit user IDs remain scoped to the current tenant."""
        # Arrange
        mock_session = MagicMock()
        mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
        mock_session.scalar.return_value = None
        mock_new_user = MagicMock()
        mock_new_user.tenant_id = "tenant-current"
        mock_enduser_class.return_value = mock_new_user

        # Act
        with app.app_context():
            result = get_user("tenant-current", "foreign-user-id")

        # Assert
        assert result == mock_new_user
        mock_session.get.assert_not_called()
        mock_session.scalar.assert_called_once()
        mock_session.add.assert_called_once_with(mock_new_user)

    @patch("controllers.inner_api.plugin.wraps.select")
    @patch("controllers.inner_api.plugin.wraps.EndUser")
    @patch("controllers.inner_api.plugin.wraps.sessionmaker")
    @patch("controllers.inner_api.plugin.wraps.db")
    def test_should_return_existing_anonymous_user_by_session_id(
        self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask
        self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask
    ):
        """Test returning existing anonymous user by session_id"""
        # Arrange
@ -110,9 +73,8 @@ class TestGetUser:
        mock_user.session_id = "anonymous_session"
        mock_session = MagicMock()
        mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
        mock_session.scalar.return_value = mock_user
        mock_query = MagicMock()
        mock_select.return_value.where.return_value.limit.return_value = mock_query
        # non-anonymous path uses session.get(); anonymous uses session.scalar()
        mock_session.get.return_value = mock_user

        # Act
        with app.app_context():
@ -121,22 +83,17 @@ class TestGetUser:
        # Assert
        assert result == mock_user

    @patch("controllers.inner_api.plugin.wraps.select")
    @patch("controllers.inner_api.plugin.wraps.EndUser")
    @patch("controllers.inner_api.plugin.wraps.sessionmaker")
    @patch("controllers.inner_api.plugin.wraps.db")
    def test_should_create_new_user_when_not_found(
        self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask
    ):
    def test_should_create_new_user_when_not_found(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask):
        """Test creating new user when not found in database"""
        # Arrange
        mock_session = MagicMock()
        mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
        mock_session.scalar.return_value = None
        mock_session.get.return_value = None
        mock_new_user = MagicMock()
        mock_enduser_class.return_value = mock_new_user
        mock_query = MagicMock()
        mock_select.return_value.where.return_value.limit.return_value = mock_query

        # Act
        with app.app_context():
@ -177,7 +134,7 @@ class TestGetUser:
        # Arrange
        mock_session = MagicMock()
        mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
        mock_session.scalar.side_effect = Exception("Database error")
        mock_session.get.side_effect = Exception("Database error")

        # Act & Assert
        with app.app_context():

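Every test in this class fakes the sessionmaker(...).begin() context manager with the same one-liner. If that setup keeps repeating, it could be factored into a helper along these lines (a sketch, not part of the diff):

from unittest.mock import MagicMock


def fake_session_for(sessionmaker_mock: MagicMock) -> MagicMock:
    # Sketch: route `with sessionmaker(engine).begin() as session:` onto one mock.
    session = MagicMock()
    sessionmaker_mock.return_value.begin.return_value.__enter__.return_value = session
    return session
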
@ -345,26 +345,22 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None:
            )
        ]
    )
    mock_session = Mock()
    mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
        encrypted_config="encrypted-old-key"
    )
    session = Mock()
    session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="encrypted-old-key")
    mock_factory = Mock()
    mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"}

    with _patched_session(mock_session):
        with patch(
            "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
        ):
            with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"):
                with patch(
                    "core.entities.provider_configuration.encrypter.encrypt_token",
                    side_effect=lambda tenant_id, value: f"enc::{value}",
                ):
                    validated = configuration.validate_provider_credentials(
                        credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"},
                        credential_id="credential-1",
                    )
    with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
        with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"):
            with patch(
                "core.entities.provider_configuration.encrypter.encrypt_token",
                side_effect=lambda tenant_id, value: f"enc::{value}",
            ):
                validated = configuration.validate_provider_credentials(
                    credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"},
                    credential_id="credential-1",
                    session=session,
                )

    assert validated["openai_api_key"] == "enc::restored-key"
    assert validated["region"] == "us"
@ -374,15 +370,23 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None:
    )


def test_validate_provider_credentials_without_credential_id() -> None:
def test_validate_provider_credentials_opens_session_when_not_passed() -> None:
    configuration = _build_provider_configuration()
    mock_session = Mock()
    mock_factory = Mock()
    mock_factory.provider_credentials_validate.return_value = {"region": "us"}

    with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
        validated = configuration.validate_provider_credentials(credentials={"region": "us"})
    with patch("core.entities.provider_configuration.Session") as mock_session_cls:
        with patch("core.entities.provider_configuration.db") as mock_db:
            mock_db.engine = Mock()
            mock_session_cls.return_value.__enter__.return_value = mock_session
            with patch(
                "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
            ):
                validated = configuration.validate_provider_credentials(credentials={"region": "us"})

    assert validated == {"region": "us"}
    mock_session_cls.assert_called_once()

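The renamed test pins down the new contract: when no session= is injected, validate_provider_credentials builds one from db.engine itself. The general shape of an optional-session parameter, sketched with generic names rather than the real ProviderConfiguration internals:

from sqlalchemy.orm import Session


def validate(payload: dict, engine, session: Session | None = None) -> dict:
    # Sketch of the optional-session pattern: reuse the injected session when
    # given (as the tests do), otherwise open a short-lived one from the engine.
    def _run(s: Session) -> dict:
        ...  # perform lookups and validation against `s`
        return payload

    if session is not None:
        return _run(session)
    with Session(engine) as s:
        return _run(s)
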
def test_switch_preferred_provider_type_returns_early_when_no_change_or_unsupported() -> None:

@ -713,22 +717,18 @@ def test_check_provider_credential_name_exists_and_model_setting_lookup() -> None:
def test_validate_provider_credentials_handles_invalid_original_json() -> None:
    configuration = _build_provider_configuration()
    configuration.provider.provider_credential_schema = _build_secret_provider_schema()
    mock_session = Mock()
    mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
        encrypted_config="{invalid-json"
    )
    session = Mock()
    session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json")
    mock_factory = Mock()
    mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"}

    with _patched_session(mock_session):
        with patch(
            "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
        ):
            with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"):
                validated = configuration.validate_provider_credentials(
                    credentials={"openai_api_key": HIDDEN_VALUE},
                    credential_id="cred-1",
                )
    with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
        with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"):
            validated = configuration.validate_provider_credentials(
                credentials={"openai_api_key": HIDDEN_VALUE},
                credential_id="cred-1",
                session=session,
            )

    assert validated == {"openai_api_key": "enc-key"}

@ -1060,35 +1060,37 @@ def test_get_custom_model_credential_uses_specific_id_or_configuration_fallback(
def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless_path() -> None:
    configuration = _build_provider_configuration()
    configuration.provider.model_credential_schema = _build_secret_model_schema()
    mock_session = Mock()
    mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
    session = Mock()
    session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
        encrypted_config='{"openai_api_key":"enc"}'
    )
    mock_factory = Mock()
    mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"}

    with _patched_session(mock_session):
    with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
        with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"):
            with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
                validated = configuration.validate_custom_model_credentials(
                    model_type=ModelType.LLM,
                    model="gpt-4o",
                    credentials={"openai_api_key": HIDDEN_VALUE},
                    credential_id="cred-1",
                    session=session,
                )
    assert validated == {"openai_api_key": "enc-new"}

    session = Mock()
    mock_factory = Mock()
    mock_factory.model_credentials_validate.return_value = {"region": "us"}
    with _patched_session(session):
        with patch(
            "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
        ):
            with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"):
                with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
                    validated = configuration.validate_custom_model_credentials(
                        model_type=ModelType.LLM,
                        model="gpt-4o",
                        credentials={"openai_api_key": HIDDEN_VALUE},
                        credential_id="cred-1",
                    )
    assert validated == {"openai_api_key": "enc-new"}

    mock_factory2 = Mock()
    mock_factory2.model_credentials_validate.return_value = {"region": "us"}
    with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory2):
        validated = configuration.validate_custom_model_credentials(
            model_type=ModelType.LLM,
            model="gpt-4o",
            credentials={"region": "us"},
        )
            validated = configuration.validate_custom_model_credentials(
                model_type=ModelType.LLM,
                model="gpt-4o",
                credentials={"region": "us"},
            )
    assert validated == {"region": "us"}


@ -1568,20 +1570,18 @@ def test_get_specific_provider_credential_logs_when_decrypt_fails() -> None:
def test_validate_provider_credentials_uses_empty_original_when_record_missing() -> None:
    configuration = _build_provider_configuration()
    configuration.provider.provider_credential_schema = _build_secret_provider_schema()
    mock_session = Mock()
    mock_session.execute.return_value.scalar_one_or_none.return_value = None
    session = Mock()
    session.execute.return_value.scalar_one_or_none.return_value = None
    mock_factory = Mock()
    mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"}

    with _patched_session(mock_session):
        with patch(
            "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
        ):
            with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
                validated = configuration.validate_provider_credentials(
                    credentials={"openai_api_key": HIDDEN_VALUE},
                    credential_id="cred-1",
                )
    with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
        with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
            validated = configuration.validate_provider_credentials(
                credentials={"openai_api_key": HIDDEN_VALUE},
                credential_id="cred-1",
                session=session,
            )

    assert validated == {"openai_api_key": "enc-new"}

@ -1692,24 +1692,20 @@ def test_get_specific_custom_model_credential_logs_when_decrypt_fails() -> None:
def test_validate_custom_model_credentials_handles_invalid_original_json() -> None:
    configuration = _build_provider_configuration()
    configuration.provider.model_credential_schema = _build_secret_model_schema()
    mock_session = Mock()
    mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
        encrypted_config="{invalid-json"
    )
    session = Mock()
    session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json")
    mock_factory = Mock()
    mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"}

    with _patched_session(mock_session):
        with patch(
            "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
        ):
            with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
                validated = configuration.validate_custom_model_credentials(
                    model_type=ModelType.LLM,
                    model="gpt-4o",
                    credentials={"openai_api_key": HIDDEN_VALUE},
                    credential_id="cred-1",
                )
    with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
        with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
            validated = configuration.validate_custom_model_credentials(
                model_type=ModelType.LLM,
                model="gpt-4o",
                credentials={"openai_api_key": HIDDEN_VALUE},
                credential_id="cred-1",
                session=session,
            )

    assert validated == {"openai_api_key": "enc-new"}

@ -5,7 +5,10 @@ from unittest.mock import MagicMock, patch

import httpx
import pytest
from dify_trace_aliyun.data_exporter.traceclient import (
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.trace import SpanKind, Status, StatusCode

from core.ops.aliyun_trace.data_exporter.traceclient import (
    INVALID_SPAN_ID,
    SpanBuilder,
    TraceClient,
@ -17,9 +20,7 @@ from dify_trace_aliyun.data_exporter.traceclient import (
    create_link,
    generate_span_id,
)
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.trace import SpanKind, Status, StatusCode
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData


@pytest.fixture
@ -40,8 +41,8 @@ def trace_client_factory():


class TestTraceClient:
    @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
    @patch("dify_trace_aliyun.data_exporter.traceclient.socket.gethostname")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.socket.gethostname")
    def test_init(self, mock_gethostname, mock_exporter_class, trace_client_factory):
        mock_gethostname.return_value = "test-host"
        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
@ -55,7 +56,7 @@ class TestTraceClient:
        client.shutdown()
        assert client.done is True

    @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_export(self, mock_exporter_class, trace_client_factory):
        mock_exporter = mock_exporter_class.return_value
        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
@ -63,8 +64,8 @@ class TestTraceClient:
        client.export(spans)
        mock_exporter.export.assert_called_once_with(spans)

    @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head")
    @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_api_check_success(self, mock_exporter_class, mock_head, trace_client_factory):
        mock_response = MagicMock()
        mock_response.status_code = 405
@ -73,8 +74,8 @@ class TestTraceClient:
        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
        assert client.api_check() is True

    @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head")
    @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_api_check_failure_status(self, mock_exporter_class, mock_head, trace_client_factory):
        mock_response = MagicMock()
        mock_response.status_code = 500
@ -83,8 +84,8 @@ class TestTraceClient:
        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
        assert client.api_check() is False

    @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head")
    @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_api_check_exception(self, mock_exporter_class, mock_head, trace_client_factory):
        mock_head.side_effect = httpx.RequestError("Connection error")

@ -92,12 +93,12 @@ class TestTraceClient:
        with pytest.raises(ValueError, match="AliyunTrace API check failed: Connection error"):
            client.api_check()

    @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_get_project_url(self, mock_exporter_class, trace_client_factory):
        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
        assert client.get_project_url() == "https://arms.console.aliyun.com/#/llm"

    @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_add_span(self, mock_exporter_class, trace_client_factory):
        client = trace_client_factory(
            service_name="test-service",
@ -133,8 +134,8 @@ class TestTraceClient:
        assert len(client.queue) == 2
        mock_notify.assert_called_once()

    @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
    @patch("dify_trace_aliyun.data_exporter.traceclient.logger")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.logger")
    def test_add_span_queue_full(self, mock_logger, mock_exporter_class, trace_client_factory):
        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint", max_queue_size=1)

@ -158,7 +159,7 @@ class TestTraceClient:
        assert len(client.queue) == 1
        mock_logger.warning.assert_called_with("Queue is full, likely spans will be dropped.")

    @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_export_batch_error(self, mock_exporter_class, trace_client_factory):
        mock_exporter = mock_exporter_class.return_value
        mock_exporter.export.side_effect = Exception("Export failed")
@ -167,11 +168,11 @@ class TestTraceClient:
        mock_span = MagicMock(spec=ReadableSpan)
        client.queue.append(mock_span)

        with patch("dify_trace_aliyun.data_exporter.traceclient.logger") as mock_logger:
        with patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") as mock_logger:
            client._export_batch()
            mock_logger.warning.assert_called()

    @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_worker_loop(self, mock_exporter_class, trace_client_factory):
        # We need to test the wait timeout in _worker
        # But _worker runs in a thread. Let's mock condition.wait.
@ -188,7 +189,7 @@ class TestTraceClient:
        # mock_wait might have been called
        assert mock_wait.called or client.done

    @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_shutdown_flushes(self, mock_exporter_class, trace_client_factory):
        mock_exporter = mock_exporter_class.return_value
        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
@ -267,7 +268,7 @@ def test_generate_span_id():
    assert span_id != INVALID_SPAN_ID

    # Test retry loop
    with patch("dify_trace_aliyun.data_exporter.traceclient.random.getrandbits") as mock_rand:
    with patch("core.ops.aliyun_trace.data_exporter.traceclient.random.getrandbits") as mock_rand:
        mock_rand.side_effect = [INVALID_SPAN_ID, 999]
        span_id = generate_span_id()
        assert span_id == 999
@ -289,7 +290,7 @@ def test_convert_to_trace_id():
def test_convert_string_to_id():
    assert convert_string_to_id("test") > 0
    # Test with None string
    with patch("dify_trace_aliyun.data_exporter.traceclient.generate_span_id") as mock_gen:
    with patch("core.ops.aliyun_trace.data_exporter.traceclient.generate_span_id") as mock_gen:
        mock_gen.return_value = 12345
        assert convert_string_to_id(None) == 12345

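Every @patch target in this file changes only its module prefix: the Aliyun exporter moved from the standalone dify_trace_aliyun package into core.ops.aliyun_trace. Since unittest.mock.patch resolves the target string at lookup time, it must always name the module where the object is now looked up, e.g.:

from unittest.mock import patch

# Patch where the name is looked up at call time, not where it was first defined.
with patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") as exporter_cls:
    exporter_cls.return_value.export.return_value = None
    ...  # exercise TraceClient against the mocked exporter
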
@ -1,10 +1,11 @@
|
||||
import pytest
|
||||
from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.sdk.trace import Event
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
|
||||
|
||||
|
||||
class TestTraceMetadata:
|
||||
def test_trace_metadata_init(self):
|
||||
@ -1,4 +1,4 @@
|
||||
from dify_trace_aliyun.entities.semconv import (
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
ACS_ARMS_SERVICE_FEATURE,
|
||||
GEN_AI_COMPLETION,
|
||||
GEN_AI_FRAMEWORK,
|
||||
@ -4,11 +4,12 @@ from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import dify_trace_aliyun.aliyun_trace as aliyun_trace_module
|
||||
import pytest
|
||||
from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
|
||||
from dify_trace_aliyun.config import AliyunConfig
|
||||
from dify_trace_aliyun.entities.semconv import (
|
||||
from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags
|
||||
|
||||
import core.ops.aliyun_trace.aliyun_trace as aliyun_trace_module
|
||||
from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
GEN_AI_COMPLETION,
|
||||
GEN_AI_INPUT_MESSAGE,
|
||||
GEN_AI_OUTPUT_MESSAGE,
|
||||
@ -23,8 +24,7 @@ from dify_trace_aliyun.entities.semconv import (
|
||||
TOOL_PARAMETERS,
|
||||
GenAISpanKind,
|
||||
)
|
||||
from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags
|
||||
|
||||
from core.ops.entities.config_entity import AliyunConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
@ -1,7 +1,9 @@
import json
from unittest.mock import MagicMock

from dify_trace_aliyun.entities.semconv import (
from opentelemetry.trace import Link, StatusCode

from core.ops.aliyun_trace.entities.semconv import (
GEN_AI_FRAMEWORK,
GEN_AI_SESSION_ID,
GEN_AI_SPAN_KIND,
@ -9,7 +11,7 @@ from dify_trace_aliyun.entities.semconv import (
INPUT_VALUE,
OUTPUT_VALUE,
)
from dify_trace_aliyun.utils import (
from core.ops.aliyun_trace.utils import (
create_common_span_attributes,
create_links_from_trace_id,
create_status_from_error,
@ -21,8 +23,6 @@ from dify_trace_aliyun.utils import (
get_workflow_node_status,
serialize_json_data,
)
from opentelemetry.trace import Link, StatusCode

from core.rag.models.document import Document
from graphon.entities import WorkflowNodeExecution
from graphon.enums import WorkflowNodeExecutionStatus
@ -48,7 +48,7 @@ def test_get_user_id_from_message_data_with_end_user(monkeypatch):
mock_session = MagicMock()
mock_session.get.return_value = end_user_data

from dify_trace_aliyun.utils import db
from core.ops.aliyun_trace.utils import db

monkeypatch.setattr(db, "session", mock_session)

@ -63,7 +63,7 @@ def test_get_user_id_from_message_data_end_user_not_found(monkeypatch):
mock_session = MagicMock()
mock_session.get.return_value = None

from dify_trace_aliyun.utils import db
from core.ops.aliyun_trace.utils import db

monkeypatch.setattr(db, "session", mock_session)

@ -112,9 +112,9 @@ def test_get_workflow_node_status():
def test_create_links_from_trace_id(monkeypatch):
# Mock create_link
mock_link = MagicMock(spec=Link)
import dify_trace_aliyun.data_exporter.traceclient
import core.ops.aliyun_trace.data_exporter.traceclient

monkeypatch.setattr(dify_trace_aliyun.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link)
monkeypatch.setattr(core.ops.aliyun_trace.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link)

# Trace ID None
assert create_links_from_trace_id(None) == []
@ -2,7 +2,11 @@ from datetime import UTC, datetime, timedelta
from unittest.mock import MagicMock, patch

import pytest
from dify_trace_arize_phoenix.arize_phoenix_trace import (
from opentelemetry.sdk.trace import Tracer
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
from opentelemetry.trace import StatusCode

from core.ops.arize_phoenix_trace.arize_phoenix_trace import (
ArizePhoenixDataTrace,
datetime_to_nanos,
error_to_string,
@ -11,11 +15,7 @@ from dify_trace_arize_phoenix.arize_phoenix_trace import (
setup_tracer,
wrap_span_metadata,
)
from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
from opentelemetry.sdk.trace import Tracer
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
from opentelemetry.trace import StatusCode

from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@ -80,7 +80,7 @@ def test_datetime_to_nanos():
expected = int(dt.timestamp() * 1_000_000_000)
assert datetime_to_nanos(dt) == expected

with patch("dify_trace_arize_phoenix.arize_phoenix_trace.datetime") as mock_dt:
with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.datetime") as mock_dt:
mock_now = MagicMock()
mock_now.timestamp.return_value = 1704110400.0
mock_dt.now.return_value = mock_now
@ -142,8 +142,8 @@ def test_wrap_span_metadata():
assert res == {"a": 1, "b": 2, "created_from": "Dify"}


@patch("dify_trace_arize_phoenix.arize_phoenix_trace.GrpcOTLPSpanExporter")
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider")
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.GrpcOTLPSpanExporter")
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider")
def test_setup_tracer_arize(mock_provider, mock_exporter):
config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p")
setup_tracer(config)
@ -151,8 +151,8 @@ def test_setup_tracer_arize(mock_provider, mock_exporter):
assert mock_exporter.call_args[1]["endpoint"] == "http://a.com/v1"


@patch("dify_trace_arize_phoenix.arize_phoenix_trace.HttpOTLPSpanExporter")
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider")
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.HttpOTLPSpanExporter")
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider")
def test_setup_tracer_phoenix(mock_provider, mock_exporter):
config = PhoenixConfig(endpoint="http://p.com", project="p")
setup_tracer(config)
@ -162,7 +162,7 @@ def test_setup_tracer_phoenix(mock_provider, mock_exporter):

def test_setup_tracer_exception():
config = ArizeConfig(endpoint="http://a.com", project="p")
with patch("dify_trace_arize_phoenix.arize_phoenix_trace.urlparse", side_effect=Exception("boom")):
with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.urlparse", side_effect=Exception("boom")):
with pytest.raises(Exception, match="boom"):
setup_tracer(config)

@ -172,7 +172,7 @@ def test_setup_tracer_exception():

@pytest.fixture
def trace_instance():
with patch("dify_trace_arize_phoenix.arize_phoenix_trace.setup_tracer") as mock_setup:
with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.setup_tracer") as mock_setup:
mock_tracer = MagicMock(spec=Tracer)
mock_processor = MagicMock()
mock_setup.return_value = (mock_tracer, mock_processor)
@ -228,9 +228,9 @@ def test_trace_exception(trace_instance):
trace_instance.trace(_make_workflow_info())


@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker")
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory")
@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db")
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.sessionmaker")
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.DifyCoreRepositoryFactory")
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trace_instance):
mock_db.engine = MagicMock()
info = _make_workflow_info()
@ -262,7 +262,7 @@ def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trac
assert trace_instance.tracer.start_span.call_count >= 2


@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db")
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
def test_workflow_trace_no_app_id(mock_db, trace_instance):
mock_db.engine = MagicMock()
info = _make_workflow_info()
@ -271,7 +271,7 @@ def test_workflow_trace_no_app_id(mock_db, trace_instance):
trace_instance.workflow_trace(info)


@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db")
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
def test_message_trace_success(mock_db, trace_instance):
mock_db.engine = MagicMock()
info = _make_message_info()
@ -291,7 +291,7 @@ def test_message_trace_success(mock_db, trace_instance):
assert trace_instance.tracer.start_span.call_count >= 1


@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db")
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
def test_message_trace_with_error(mock_db, trace_instance):
mock_db.engine = MagicMock()
info = _make_message_info()
@ -5,16 +5,8 @@ from types import SimpleNamespace
from unittest.mock import MagicMock

import pytest
from dify_trace_langfuse.config import LangfuseConfig
from dify_trace_langfuse.entities.langfuse_trace_entity import (
LangfuseGeneration,
LangfuseSpan,
LangfuseTrace,
LevelEnum,
UnitEnum,
)
from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace

from core.ops.entities.config_entity import LangfuseConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@ -25,6 +17,14 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
LangfuseGeneration,
LangfuseSpan,
LangfuseTrace,
LevelEnum,
UnitEnum,
)
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
from graphon.enums import BuiltinNodeTypes
from models import EndUser
from models.enums import MessageStatus
@ -43,7 +43,7 @@ def langfuse_config():
def trace_instance(langfuse_config, monkeypatch):
# Mock Langfuse client to avoid network calls
mock_client = MagicMock()
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", lambda **kwargs: mock_client)
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", lambda **kwargs: mock_client)

instance = LangFuseDataTrace(langfuse_config)
return instance
@ -51,7 +51,7 @@ def trace_instance(langfuse_config, monkeypatch):

def test_init(langfuse_config, monkeypatch):
mock_langfuse = MagicMock()
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", mock_langfuse)
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", mock_langfuse)
monkeypatch.setenv("FILES_URL", "http://test.url")

instance = LangFuseDataTrace(langfuse_config)
@ -140,8 +140,8 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):

# Mock DB and Repositories
mock_session = MagicMock()
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session)
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session)
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))

# Mock node executions
node_llm = MagicMock()
@ -178,7 +178,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):

mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)

monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())

@ -241,13 +241,13 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
error="",
)

monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
repo = MagicMock()
repo.get_by_workflow_execution.return_value = []
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())

trace_instance.add_trace = MagicMock()
@ -280,8 +280,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
workflow_app_log_id="log-1",
error="",
)
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))

with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
trace_instance.workflow_trace(trace_info)
@ -365,7 +365,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch):
mock_end_user = MagicMock(spec=EndUser)
mock_end_user.session_id = "session-id-123"

monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db.session.get", lambda model, pk: mock_end_user)
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.get", lambda model, pk: mock_end_user)

trace_instance.add_trace = MagicMock()
trace_instance.add_generation = MagicMock()
@ -681,9 +681,9 @@ def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypat
repo.get_by_workflow_execution.return_value = [node]
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())

trace_instance.add_trace = MagicMock()
@ -3,14 +3,8 @@ from datetime import datetime, timedelta
from unittest.mock import MagicMock

import pytest
from dify_trace_langsmith.config import LangSmithConfig
from dify_trace_langsmith.entities.langsmith_trace_entity import (
LangSmithRunModel,
LangSmithRunType,
LangSmithRunUpdateModel,
)
from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace

from core.ops.entities.config_entity import LangSmithConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@ -21,6 +15,12 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunModel,
LangSmithRunType,
LangSmithRunUpdateModel,
)
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from models import EndUser

@ -38,7 +38,7 @@ def langsmith_config():
def trace_instance(langsmith_config, monkeypatch):
# Mock LangSmith client
mock_client = MagicMock()
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", lambda **kwargs: mock_client)
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", lambda **kwargs: mock_client)

instance = LangSmithDataTrace(langsmith_config)
return instance
@ -46,7 +46,7 @@ def trace_instance(langsmith_config, monkeypatch):

def test_init(langsmith_config, monkeypatch):
mock_client_class = MagicMock()
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", mock_client_class)
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", mock_client_class)
monkeypatch.setenv("FILES_URL", "http://test.url")

instance = LangSmithDataTrace(langsmith_config)
@ -138,8 +138,8 @@ def test_workflow_trace(trace_instance, monkeypatch):

# Mock dependencies
mock_session = MagicMock()
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))

# Mock node executions
node_llm = MagicMock()
@ -188,7 +188,7 @@ def test_workflow_trace(trace_instance, monkeypatch):

mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)

monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())

@ -252,13 +252,13 @@ def test_workflow_trace_no_start_time(trace_instance, monkeypatch):
)

mock_session = MagicMock()
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
repo = MagicMock()
repo.get_by_workflow_execution.return_value = []
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())

trace_instance.add_run = MagicMock()
@ -283,8 +283,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
trace_info.error = ""

mock_session = MagicMock()
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))

with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
trace_instance.workflow_trace(trace_info)
@ -319,7 +319,7 @@ def test_message_trace(trace_instance, monkeypatch):
# Mock EndUser lookup
mock_end_user = MagicMock(spec=EndUser)
mock_end_user.session_id = "session-id-123"
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db.session.get", lambda model, pk: mock_end_user)
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.get", lambda model, pk: mock_end_user)

trace_instance.add_run = MagicMock()

@ -567,9 +567,9 @@ def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, capl

mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())

trace_instance.add_run = MagicMock()
@ -1,4 +1,4 @@
"""Comprehensive tests for dify_trace_mlflow.mlflow_trace module."""
"""Comprehensive tests for core.ops.mlflow_trace.mlflow_trace module."""

from __future__ import annotations

@ -9,9 +9,8 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest
from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig
from dify_trace_mlflow.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds

from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@ -21,6 +20,7 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds
from graphon.enums import BuiltinNodeTypes

# ── Helpers ──────────────────────────────────────────────────────────────────
@ -179,7 +179,7 @@ def _make_node(**overrides):

@pytest.fixture
def mock_mlflow():
with patch("dify_trace_mlflow.mlflow_trace.mlflow") as mock:
with patch("core.ops.mlflow_trace.mlflow_trace.mlflow") as mock:
yield mock


@ -187,10 +187,10 @@ def mock_mlflow():
def mock_tracing():
"""Patch all MLflow tracing functions used by the module."""
with (
patch("dify_trace_mlflow.mlflow_trace.start_span_no_context") as mock_start,
patch("dify_trace_mlflow.mlflow_trace.update_current_trace") as mock_update,
patch("dify_trace_mlflow.mlflow_trace.set_span_in_context") as mock_set,
patch("dify_trace_mlflow.mlflow_trace.detach_span_from_context") as mock_detach,
patch("core.ops.mlflow_trace.mlflow_trace.start_span_no_context") as mock_start,
patch("core.ops.mlflow_trace.mlflow_trace.update_current_trace") as mock_update,
patch("core.ops.mlflow_trace.mlflow_trace.set_span_in_context") as mock_set,
patch("core.ops.mlflow_trace.mlflow_trace.detach_span_from_context") as mock_detach,
):
yield {
"start": mock_start,
@ -202,7 +202,7 @@ def mock_tracing():

@pytest.fixture
def mock_db():
with patch("dify_trace_mlflow.mlflow_trace.db") as mock:
with patch("core.ops.mlflow_trace.mlflow_trace.db") as mock:
yield mock

@ -5,9 +5,8 @@ from types import SimpleNamespace
from unittest.mock import MagicMock

import pytest
from dify_trace_opik.config import OpikConfig
from dify_trace_opik.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata

from core.ops.entities.config_entity import OpikConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@ -18,6 +17,7 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from models import EndUser
from models.enums import MessageStatus
@ -37,7 +37,7 @@ def opik_config():
@pytest.fixture
def trace_instance(opik_config, monkeypatch):
mock_client = MagicMock()
monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", lambda **kwargs: mock_client)
monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", lambda **kwargs: mock_client)

instance = OpikDataTrace(opik_config)
return instance
@ -67,7 +67,7 @@ def test_prepare_opik_uuid():

def test_init(opik_config, monkeypatch):
mock_opik = MagicMock()
monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", mock_opik)
monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", mock_opik)
monkeypatch.setenv("FILES_URL", "http://test.url")

instance = OpikDataTrace(opik_config)
@ -166,8 +166,8 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
)

mock_session = MagicMock()
monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: mock_session)
monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: mock_session)
monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))

node_llm = MagicMock()
node_llm.id = LLM_NODE_ID
@ -203,7 +203,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):

mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)

monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())

@ -250,13 +250,13 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
error="",
)

monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
repo = MagicMock()
repo.get_by_workflow_execution.return_value = []
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())

trace_instance.add_trace = MagicMock()
@ -286,8 +286,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
workflow_app_log_id="339760b2-4b94-4532-8c81-133a97e4680e",
error="",
)
monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))

with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
trace_instance.workflow_trace(trace_info)
@ -373,7 +373,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch):
mock_end_user = MagicMock(spec=EndUser)
mock_end_user.session_id = "session-id-123"

monkeypatch.setattr("dify_trace_opik.opik_trace.db.session.get", lambda model, pk: mock_end_user)
monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.get", lambda model, pk: mock_end_user)

trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_2"))
trace_instance.add_span = MagicMock()
@ -658,9 +658,9 @@ def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch
repo.get_by_workflow_execution.return_value = [node]
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())

trace_instance.add_trace = MagicMock()
@ -8,12 +8,13 @@ from types import SimpleNamespace
from unittest.mock import MagicMock

import pytest
from dify_trace_tencent import client as client_module
from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version
from dify_trace_tencent.entities.tencent_trace_entity import SpanData
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode

from core.ops.tencent_trace import client as client_module
from core.ops.tencent_trace.client import TencentTraceClient, _get_opentelemetry_sdk_version
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData

metric_reader_instances: list[DummyMetricReader] = []
meter_provider_instances: list[DummyMeterProvider] = []
Some files were not shown because too many files have changed in this diff