mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
Merge branch 'main' into feat/rag-2
# Conflicts: # web/app/components/workflow/hooks/use-workflow.ts
This commit is contained in:
@ -1,12 +1,13 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic import Field, PositiveInt, model_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class ElasticsearchConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for Elasticsearch
|
||||
Configuration settings for both self-managed and Elastic Cloud deployments.
|
||||
Can load from environment variables or .env files.
|
||||
"""
|
||||
|
||||
ELASTICSEARCH_HOST: Optional[str] = Field(
|
||||
@ -28,3 +29,50 @@ class ElasticsearchConfig(BaseSettings):
|
||||
description="Password for authenticating with Elasticsearch (default is 'elastic')",
|
||||
default="elastic",
|
||||
)
|
||||
|
||||
# Elastic Cloud (optional)
|
||||
ELASTICSEARCH_USE_CLOUD: Optional[bool] = Field(
|
||||
description="Set to True to use Elastic Cloud instead of self-hosted Elasticsearch", default=False
|
||||
)
|
||||
ELASTICSEARCH_CLOUD_URL: Optional[str] = Field(
|
||||
description="Full URL for Elastic Cloud deployment (e.g., 'https://example.es.region.aws.found.io:443')",
|
||||
default=None,
|
||||
)
|
||||
ELASTICSEARCH_API_KEY: Optional[str] = Field(
|
||||
description="API key for authenticating with Elastic Cloud", default=None
|
||||
)
|
||||
|
||||
# Common options
|
||||
ELASTICSEARCH_CA_CERTS: Optional[str] = Field(
|
||||
description="Path to CA certificate file for SSL verification", default=None
|
||||
)
|
||||
ELASTICSEARCH_VERIFY_CERTS: bool = Field(
|
||||
description="Whether to verify SSL certificates (default is False)", default=False
|
||||
)
|
||||
ELASTICSEARCH_REQUEST_TIMEOUT: int = Field(
|
||||
description="Request timeout in milliseconds (default is 100000)", default=100000
|
||||
)
|
||||
ELASTICSEARCH_RETRY_ON_TIMEOUT: bool = Field(
|
||||
description="Whether to retry requests on timeout (default is True)", default=True
|
||||
)
|
||||
ELASTICSEARCH_MAX_RETRIES: int = Field(
|
||||
description="Maximum number of retry attempts (default is 10000)", default=10000
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_elasticsearch_config(self):
|
||||
"""Validate Elasticsearch configuration based on deployment type."""
|
||||
if self.ELASTICSEARCH_USE_CLOUD:
|
||||
if not self.ELASTICSEARCH_CLOUD_URL:
|
||||
raise ValueError("ELASTICSEARCH_CLOUD_URL is required when using Elastic Cloud")
|
||||
if not self.ELASTICSEARCH_API_KEY:
|
||||
raise ValueError("ELASTICSEARCH_API_KEY is required when using Elastic Cloud")
|
||||
else:
|
||||
if not self.ELASTICSEARCH_HOST:
|
||||
raise ValueError("ELASTICSEARCH_HOST is required for self-hosted Elasticsearch")
|
||||
if not self.ELASTICSEARCH_USERNAME:
|
||||
raise ValueError("ELASTICSEARCH_USERNAME is required for self-hosted Elasticsearch")
|
||||
if not self.ELASTICSEARCH_PASSWORD:
|
||||
raise ValueError("ELASTICSEARCH_PASSWORD is required for self-hosted Elasticsearch")
|
||||
|
||||
return self
|
||||
|
||||
@ -100,7 +100,7 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
|
||||
|
||||
|
||||
class AnnotationListApi(Resource):
|
||||
class AnnotationApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -123,6 +123,23 @@ class AnnotationListApi(Resource):
|
||||
}
|
||||
return response, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("question", required=True, type=str, location="json")
|
||||
parser.add_argument("answer", required=True, type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
|
||||
return annotation
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -131,8 +148,25 @@ class AnnotationListApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
AppAnnotationService.clear_all_annotations(app_id)
|
||||
return {"result": "success"}, 204
|
||||
|
||||
# Use request.args.getlist to get annotation_ids array directly
|
||||
annotation_ids = request.args.getlist("annotation_id")
|
||||
|
||||
# If annotation_ids are provided, handle batch deletion
|
||||
if annotation_ids:
|
||||
# Check if any annotation_ids contain empty strings or invalid values
|
||||
if not all(annotation_id.strip() for annotation_id in annotation_ids if annotation_id):
|
||||
return {
|
||||
"code": "bad_request",
|
||||
"message": "annotation_ids are required if the parameter is provided.",
|
||||
}, 400
|
||||
|
||||
result = AppAnnotationService.delete_app_annotations_in_batch(app_id, annotation_ids)
|
||||
return result, 204
|
||||
# If no annotation_ids are provided, handle clearing all annotations
|
||||
else:
|
||||
AppAnnotationService.clear_all_annotations(app_id)
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class AnnotationExportApi(Resource):
|
||||
@ -149,25 +183,6 @@ class AnnotationExportApi(Resource):
|
||||
return response, 200
|
||||
|
||||
|
||||
class AnnotationCreateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("question", required=True, type=str, location="json")
|
||||
parser.add_argument("answer", required=True, type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
|
||||
return annotation
|
||||
|
||||
|
||||
class AnnotationUpdateDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -276,7 +291,7 @@ api.add_resource(AnnotationReplyActionApi, "/apps/<uuid:app_id>/annotation-reply
|
||||
api.add_resource(
|
||||
AnnotationReplyActionStatusApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>"
|
||||
)
|
||||
api.add_resource(AnnotationListApi, "/apps/<uuid:app_id>/annotations")
|
||||
api.add_resource(AnnotationApi, "/apps/<uuid:app_id>/annotations")
|
||||
api.add_resource(AnnotationExportApi, "/apps/<uuid:app_id>/annotations/export")
|
||||
api.add_resource(AnnotationUpdateDeleteApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
|
||||
api.add_resource(AnnotationBatchImportApi, "/apps/<uuid:app_id>/annotations/batch-import")
|
||||
|
||||
@ -22,8 +22,8 @@ class DatasetMetadataCreateApi(Resource):
|
||||
@marshal_with(dataset_metadata_fields)
|
||||
def post(self, dataset_id):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("type", type=str, required=True, nullable=True, location="json")
|
||||
parser.add_argument("name", type=str, required=True, nullable=True, location="json")
|
||||
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
metadata_args = MetadataArgs(**args)
|
||||
|
||||
@ -56,7 +56,7 @@ class DatasetMetadataApi(Resource):
|
||||
@marshal_with(dataset_metadata_fields)
|
||||
def patch(self, dataset_id, metadata_id):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=True, location="json")
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -127,7 +127,7 @@ class DocumentMetadataEditApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json")
|
||||
parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
metadata_args = MetadataOperationData(**args)
|
||||
|
||||
|
||||
@ -47,6 +47,9 @@ class CompletionApi(Resource):
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
if external_trace_id:
|
||||
args["external_trace_id"] = external_trace_id
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
|
||||
|
||||
@ -17,8 +17,8 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("type", type=str, required=True, nullable=True, location="json")
|
||||
parser.add_argument("name", type=str, required=True, nullable=True, location="json")
|
||||
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
metadata_args = MetadataArgs(**args)
|
||||
|
||||
@ -43,7 +43,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, tenant_id, dataset_id, metadata_id):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=True, location="json")
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -101,7 +101,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json")
|
||||
parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
metadata_args = MetadataOperationData(**args)
|
||||
|
||||
|
||||
@ -159,6 +159,8 @@ SupportedComparisonOperator = Literal[
|
||||
"is not",
|
||||
"empty",
|
||||
"not empty",
|
||||
"in",
|
||||
"not in",
|
||||
# for number
|
||||
"=",
|
||||
"≠",
|
||||
|
||||
@ -10,6 +10,7 @@ from sqlalchemy.orm import Session, sessionmaker
|
||||
from core.ops.aliyun_trace.data_exporter.traceclient import (
|
||||
TraceClient,
|
||||
convert_datetime_to_nanoseconds,
|
||||
convert_string_to_id,
|
||||
convert_to_span_id,
|
||||
convert_to_trace_id,
|
||||
generate_span_id,
|
||||
@ -101,8 +102,9 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
raise ValueError(f"Aliyun get run url failed: {str(e)}")
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
external_trace_id = trace_info.metadata.get("external_trace_id")
|
||||
trace_id = external_trace_id or convert_to_trace_id(trace_info.workflow_run_id)
|
||||
trace_id = convert_to_trace_id(trace_info.workflow_run_id)
|
||||
if trace_info.trace_id:
|
||||
trace_id = convert_string_to_id(trace_info.trace_id)
|
||||
workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow")
|
||||
self.add_workflow_span(trace_id, workflow_span_id, trace_info)
|
||||
|
||||
@ -130,6 +132,9 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
|
||||
trace_id = convert_to_trace_id(message_id)
|
||||
if trace_info.trace_id:
|
||||
trace_id = convert_string_to_id(trace_info.trace_id)
|
||||
|
||||
message_span_id = convert_to_span_id(message_id, "message")
|
||||
message_span = SpanData(
|
||||
trace_id=trace_id,
|
||||
@ -186,9 +191,13 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
return
|
||||
message_id = trace_info.message_id
|
||||
|
||||
trace_id = convert_to_trace_id(message_id)
|
||||
if trace_info.trace_id:
|
||||
trace_id = convert_string_to_id(trace_info.trace_id)
|
||||
|
||||
documents_data = extract_retrieval_documents(trace_info.documents)
|
||||
dataset_retrieval_span = SpanData(
|
||||
trace_id=convert_to_trace_id(message_id),
|
||||
trace_id=trace_id,
|
||||
parent_span_id=convert_to_span_id(message_id, "message"),
|
||||
span_id=generate_span_id(),
|
||||
name="dataset_retrieval",
|
||||
@ -214,8 +223,12 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
|
||||
trace_id = convert_to_trace_id(message_id)
|
||||
if trace_info.trace_id:
|
||||
trace_id = convert_string_to_id(trace_info.trace_id)
|
||||
|
||||
tool_span = SpanData(
|
||||
trace_id=convert_to_trace_id(message_id),
|
||||
trace_id=trace_id,
|
||||
parent_span_id=convert_to_span_id(message_id, "message"),
|
||||
span_id=generate_span_id(),
|
||||
name=trace_info.tool_name,
|
||||
@ -451,8 +464,13 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
status: Status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
|
||||
trace_id = convert_to_trace_id(message_id)
|
||||
if trace_info.trace_id:
|
||||
trace_id = convert_string_to_id(trace_info.trace_id)
|
||||
|
||||
suggested_question_span = SpanData(
|
||||
trace_id=convert_to_trace_id(message_id),
|
||||
trace_id=trace_id,
|
||||
parent_span_id=convert_to_span_id(message_id, "message"),
|
||||
span_id=convert_to_span_id(message_id, "suggested_question"),
|
||||
name="suggested_question",
|
||||
|
||||
@ -181,15 +181,21 @@ def convert_to_trace_id(uuid_v4: Optional[str]) -> int:
|
||||
raise ValueError(f"Invalid UUID input: {e}")
|
||||
|
||||
|
||||
def convert_string_to_id(string: Optional[str]) -> int:
|
||||
if not string:
|
||||
return generate_span_id()
|
||||
hash_bytes = hashlib.sha256(string.encode("utf-8")).digest()
|
||||
id = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
|
||||
return id
|
||||
|
||||
|
||||
def convert_to_span_id(uuid_v4: Optional[str], span_type: str) -> int:
|
||||
try:
|
||||
uuid_obj = uuid.UUID(uuid_v4)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid UUID input: {e}")
|
||||
combined_key = f"{uuid_obj.hex}-{span_type}"
|
||||
hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest()
|
||||
span_id = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
|
||||
return span_id
|
||||
return convert_string_to_id(combined_key)
|
||||
|
||||
|
||||
def convert_datetime_to_nanoseconds(start_time_a: Optional[datetime]) -> Optional[int]:
|
||||
|
||||
@ -4,6 +4,7 @@ import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Optional, Union, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
|
||||
from opentelemetry import trace
|
||||
@ -40,8 +41,14 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra
|
||||
try:
|
||||
# Choose the appropriate exporter based on config type
|
||||
exporter: Union[GrpcOTLPSpanExporter, HttpOTLPSpanExporter]
|
||||
|
||||
# Inspect the provided endpoint to determine its structure
|
||||
parsed = urlparse(arize_phoenix_config.endpoint)
|
||||
base_endpoint = f"{parsed.scheme}://{parsed.netloc}"
|
||||
path = parsed.path.rstrip("/")
|
||||
|
||||
if isinstance(arize_phoenix_config, ArizeConfig):
|
||||
arize_endpoint = f"{arize_phoenix_config.endpoint}/v1"
|
||||
arize_endpoint = f"{base_endpoint}/v1"
|
||||
arize_headers = {
|
||||
"api_key": arize_phoenix_config.api_key or "",
|
||||
"space_id": arize_phoenix_config.space_id or "",
|
||||
@ -53,7 +60,7 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra
|
||||
timeout=30,
|
||||
)
|
||||
else:
|
||||
phoenix_endpoint = f"{arize_phoenix_config.endpoint}/v1/traces"
|
||||
phoenix_endpoint = f"{base_endpoint}{path}/v1/traces"
|
||||
phoenix_headers = {
|
||||
"api_key": arize_phoenix_config.api_key or "",
|
||||
"authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
|
||||
@ -91,16 +98,21 @@ def datetime_to_nanos(dt: Optional[datetime]) -> int:
|
||||
return int(dt.timestamp() * 1_000_000_000)
|
||||
|
||||
|
||||
def uuid_to_trace_id(string: Optional[str]) -> int:
|
||||
"""Convert UUID string to a valid trace ID (16-byte integer)."""
|
||||
def string_to_trace_id128(string: Optional[str]) -> int:
|
||||
"""
|
||||
Convert any input string into a stable 128-bit integer trace ID.
|
||||
|
||||
This uses SHA-256 hashing and takes the first 16 bytes (128 bits) of the digest.
|
||||
It's suitable for generating consistent, unique identifiers from strings.
|
||||
"""
|
||||
if string is None:
|
||||
string = ""
|
||||
hash_object = hashlib.sha256(string.encode())
|
||||
|
||||
# Take the first 16 bytes (128 bits) of the hash
|
||||
# Take the first 16 bytes (128 bits) of the hash digest
|
||||
digest = hash_object.digest()[:16]
|
||||
|
||||
# Convert to integer (128 bits)
|
||||
# Convert to a 128-bit integer
|
||||
return int.from_bytes(digest, byteorder="big")
|
||||
|
||||
|
||||
@ -153,8 +165,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
}
|
||||
workflow_metadata.update(trace_info.metadata)
|
||||
|
||||
external_trace_id = trace_info.metadata.get("external_trace_id")
|
||||
trace_id = external_trace_id or uuid_to_trace_id(trace_info.workflow_run_id)
|
||||
trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.workflow_run_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
@ -310,7 +321,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
|
||||
}
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.message_id)
|
||||
message_span_id = RandomIdGenerator().generate_span_id()
|
||||
span_context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
@ -406,7 +417,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
trace_id = string_to_trace_id128(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
@ -468,7 +479,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
trace_id = string_to_trace_id128(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
@ -521,7 +532,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
trace_id = string_to_trace_id128(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
@ -568,7 +579,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
"tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
|
||||
}
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
trace_id = string_to_trace_id128(trace_info.message_id)
|
||||
tool_span_id = RandomIdGenerator().generate_span_id()
|
||||
logger.info("[Arize/Phoenix] Creating tool trace with trace_id: %s, span_id: %s", trace_id, tool_span_id)
|
||||
|
||||
@ -629,7 +640,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
trace_id = string_to_trace_id128(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
|
||||
@ -87,7 +87,7 @@ class PhoenixConfig(BaseTracingConfig):
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_endpoint_url(v, "https://app.phoenix.arize.com")
|
||||
return validate_url_with_path(v, "https://app.phoenix.arize.com")
|
||||
|
||||
|
||||
class LangfuseConfig(BaseTracingConfig):
|
||||
|
||||
@ -14,6 +14,7 @@ class BaseTraceInfo(BaseModel):
|
||||
start_time: Optional[datetime] = None
|
||||
end_time: Optional[datetime] = None
|
||||
metadata: dict[str, Any]
|
||||
trace_id: Optional[str] = None
|
||||
|
||||
@field_validator("inputs", "outputs")
|
||||
@classmethod
|
||||
|
||||
@ -67,14 +67,13 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
external_trace_id = trace_info.metadata.get("external_trace_id")
|
||||
trace_id = external_trace_id or trace_info.workflow_run_id
|
||||
trace_id = trace_info.trace_id or trace_info.workflow_run_id
|
||||
user_id = trace_info.metadata.get("user_id")
|
||||
metadata = trace_info.metadata
|
||||
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
|
||||
|
||||
if trace_info.message_id:
|
||||
trace_id = external_trace_id or trace_info.message_id
|
||||
trace_id = trace_info.trace_id or trace_info.message_id
|
||||
name = TraceTaskName.MESSAGE_TRACE.value
|
||||
trace_data = LangfuseTrace(
|
||||
id=trace_id,
|
||||
@ -250,8 +249,10 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
user_id = end_user_data.session_id
|
||||
metadata["user_id"] = user_id
|
||||
|
||||
trace_id = trace_info.trace_id or message_id
|
||||
|
||||
trace_data = LangfuseTrace(
|
||||
id=message_id,
|
||||
id=trace_id,
|
||||
user_id=user_id,
|
||||
name=TraceTaskName.MESSAGE_TRACE.value,
|
||||
input={
|
||||
@ -285,7 +286,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
|
||||
langfuse_generation_data = LangfuseGeneration(
|
||||
name="llm",
|
||||
trace_id=message_id,
|
||||
trace_id=trace_id,
|
||||
start_time=trace_info.start_time,
|
||||
end_time=trace_info.end_time,
|
||||
model=message_data.model_id,
|
||||
@ -311,7 +312,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
"preset_response": trace_info.preset_response,
|
||||
"inputs": trace_info.inputs,
|
||||
},
|
||||
trace_id=trace_info.message_id,
|
||||
trace_id=trace_info.trace_id or trace_info.message_id,
|
||||
start_time=trace_info.start_time or trace_info.message_data.created_at,
|
||||
end_time=trace_info.end_time or trace_info.message_data.created_at,
|
||||
metadata=trace_info.metadata,
|
||||
@ -334,7 +335,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
|
||||
input=trace_info.inputs,
|
||||
output=str(trace_info.suggested_question),
|
||||
trace_id=trace_info.message_id,
|
||||
trace_id=trace_info.trace_id or trace_info.message_id,
|
||||
start_time=trace_info.start_time,
|
||||
end_time=trace_info.end_time,
|
||||
metadata=trace_info.metadata,
|
||||
@ -352,7 +353,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
|
||||
input=trace_info.inputs,
|
||||
output={"documents": trace_info.documents},
|
||||
trace_id=trace_info.message_id,
|
||||
trace_id=trace_info.trace_id or trace_info.message_id,
|
||||
start_time=trace_info.start_time or trace_info.message_data.created_at,
|
||||
end_time=trace_info.end_time or trace_info.message_data.updated_at,
|
||||
metadata=trace_info.metadata,
|
||||
@ -365,7 +366,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
name=trace_info.tool_name,
|
||||
input=trace_info.tool_inputs,
|
||||
output=trace_info.tool_outputs,
|
||||
trace_id=trace_info.message_id,
|
||||
trace_id=trace_info.trace_id or trace_info.message_id,
|
||||
start_time=trace_info.start_time,
|
||||
end_time=trace_info.end_time,
|
||||
metadata=trace_info.metadata,
|
||||
|
||||
@ -65,8 +65,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
external_trace_id = trace_info.metadata.get("external_trace_id")
|
||||
trace_id = external_trace_id or trace_info.message_id or trace_info.workflow_run_id
|
||||
trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
|
||||
if trace_info.start_time is None:
|
||||
trace_info.start_time = datetime.now()
|
||||
message_dotted_order = (
|
||||
@ -290,7 +289,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
trace_id=None,
|
||||
trace_id=trace_info.trace_id,
|
||||
dotted_order=None,
|
||||
parent_run_id=None,
|
||||
)
|
||||
@ -319,7 +318,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
trace_id=None,
|
||||
trace_id=trace_info.trace_id,
|
||||
dotted_order=None,
|
||||
id=str(uuid.uuid4()),
|
||||
)
|
||||
@ -351,7 +350,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
trace_id=None,
|
||||
trace_id=trace_info.trace_id,
|
||||
dotted_order=None,
|
||||
error="",
|
||||
file_list=[],
|
||||
@ -381,7 +380,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
trace_id=None,
|
||||
trace_id=trace_info.trace_id,
|
||||
dotted_order=None,
|
||||
error="",
|
||||
file_list=[],
|
||||
@ -410,7 +409,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
trace_id=None,
|
||||
trace_id=trace_info.trace_id,
|
||||
dotted_order=None,
|
||||
error="",
|
||||
file_list=[],
|
||||
@ -440,7 +439,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
trace_id=None,
|
||||
trace_id=trace_info.trace_id,
|
||||
dotted_order=None,
|
||||
error=trace_info.error or "",
|
||||
)
|
||||
@ -465,7 +464,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
reference_example_id=None,
|
||||
input_attachments={},
|
||||
output_attachments={},
|
||||
trace_id=None,
|
||||
trace_id=trace_info.trace_id,
|
||||
dotted_order=None,
|
||||
error="",
|
||||
file_list=[],
|
||||
|
||||
@ -96,8 +96,7 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
external_trace_id = trace_info.metadata.get("external_trace_id")
|
||||
dify_trace_id = external_trace_id or trace_info.workflow_run_id
|
||||
dify_trace_id = trace_info.trace_id or trace_info.workflow_run_id
|
||||
opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
|
||||
workflow_metadata = wrap_metadata(
|
||||
trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id
|
||||
@ -105,7 +104,7 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
root_span_id = None
|
||||
|
||||
if trace_info.message_id:
|
||||
dify_trace_id = external_trace_id or trace_info.message_id
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
|
||||
|
||||
trace_data = {
|
||||
@ -276,7 +275,7 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
return
|
||||
|
||||
metadata = trace_info.metadata
|
||||
message_id = trace_info.message_id
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id
|
||||
|
||||
user_id = message_data.from_account_id
|
||||
metadata["user_id"] = user_id
|
||||
@ -291,7 +290,7 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
metadata["end_user_id"] = end_user_id
|
||||
|
||||
trace_data = {
|
||||
"id": prepare_opik_uuid(trace_info.start_time, message_id),
|
||||
"id": prepare_opik_uuid(trace_info.start_time, dify_trace_id),
|
||||
"name": TraceTaskName.MESSAGE_TRACE.value,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
@ -330,7 +329,7 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
|
||||
span_data = {
|
||||
"trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
|
||||
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
|
||||
"name": TraceTaskName.MODERATION_TRACE.value,
|
||||
"type": "tool",
|
||||
"start_time": start_time,
|
||||
@ -356,7 +355,7 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
start_time = trace_info.start_time or message_data.created_at
|
||||
|
||||
span_data = {
|
||||
"trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
|
||||
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
|
||||
"name": TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
|
||||
"type": "tool",
|
||||
"start_time": start_time,
|
||||
@ -376,7 +375,7 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
|
||||
span_data = {
|
||||
"trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
|
||||
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
|
||||
"name": TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
|
||||
"type": "tool",
|
||||
"start_time": start_time,
|
||||
@ -391,7 +390,7 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
span_data = {
|
||||
"trace_id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id),
|
||||
"trace_id": prepare_opik_uuid(trace_info.start_time, trace_info.trace_id or trace_info.message_id),
|
||||
"name": trace_info.tool_name,
|
||||
"type": "tool",
|
||||
"start_time": trace_info.start_time,
|
||||
@ -406,7 +405,7 @@ class OpikDataTrace(BaseTraceInstance):
|
||||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
trace_data = {
|
||||
"id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id),
|
||||
"id": prepare_opik_uuid(trace_info.start_time, trace_info.trace_id or trace_info.message_id),
|
||||
"name": TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
|
||||
@ -422,8 +422,11 @@ class TraceTask:
|
||||
self.timer = timer
|
||||
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
self.app_id = None
|
||||
|
||||
self.trace_id = None
|
||||
self.kwargs = kwargs
|
||||
external_trace_id = kwargs.get("external_trace_id")
|
||||
if external_trace_id:
|
||||
self.trace_id = external_trace_id
|
||||
|
||||
def execute(self):
|
||||
return self.preprocess()
|
||||
@ -520,11 +523,8 @@ class TraceTask:
|
||||
"app_id": workflow_run.app_id,
|
||||
}
|
||||
|
||||
external_trace_id = self.kwargs.get("external_trace_id")
|
||||
if external_trace_id:
|
||||
metadata["external_trace_id"] = external_trace_id
|
||||
|
||||
workflow_trace_info = WorkflowTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
workflow_data=workflow_run.to_dict(),
|
||||
conversation_id=conversation_id,
|
||||
workflow_id=workflow_id,
|
||||
@ -584,6 +584,7 @@ class TraceTask:
|
||||
message_tokens = message_data.message_tokens
|
||||
|
||||
message_trace_info = MessageTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
message_id=message_id,
|
||||
message_data=message_data.to_dict(),
|
||||
conversation_model=conversation_mode,
|
||||
@ -627,6 +628,7 @@ class TraceTask:
|
||||
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
|
||||
|
||||
moderation_trace_info = ModerationTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
message_id=workflow_app_log_id or message_id,
|
||||
inputs=inputs,
|
||||
message_data=message_data.to_dict(),
|
||||
@ -667,6 +669,7 @@ class TraceTask:
|
||||
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
|
||||
|
||||
suggested_question_trace_info = SuggestedQuestionTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
message_id=workflow_app_log_id or message_id,
|
||||
message_data=message_data.to_dict(),
|
||||
inputs=message_data.message,
|
||||
@ -708,6 +711,7 @@ class TraceTask:
|
||||
}
|
||||
|
||||
dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
message_id=message_id,
|
||||
inputs=message_data.query or message_data.inputs,
|
||||
documents=[doc.model_dump() for doc in documents] if documents else [],
|
||||
@ -772,6 +776,7 @@ class TraceTask:
|
||||
)
|
||||
|
||||
tool_trace_info = ToolTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
message_id=message_id,
|
||||
message_data=message_data.to_dict(),
|
||||
tool_name=tool_name,
|
||||
@ -807,6 +812,7 @@ class TraceTask:
|
||||
}
|
||||
|
||||
generate_name_trace_info = GenerateNameTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
conversation_id=conversation_id,
|
||||
inputs=inputs,
|
||||
outputs=generate_conversation_name,
|
||||
|
||||
@ -87,8 +87,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
external_trace_id = trace_info.metadata.get("external_trace_id")
|
||||
trace_id = external_trace_id or trace_info.message_id or trace_info.workflow_run_id
|
||||
trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
|
||||
if trace_info.start_time is None:
|
||||
trace_info.start_time = datetime.now()
|
||||
|
||||
@ -245,8 +244,12 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
attributes["start_time"] = trace_info.start_time
|
||||
attributes["end_time"] = trace_info.end_time
|
||||
attributes["tags"] = ["message", str(trace_info.conversation_mode)]
|
||||
|
||||
trace_id = trace_info.trace_id or message_id
|
||||
attributes["trace_id"] = trace_id
|
||||
|
||||
message_run = WeaveTraceModel(
|
||||
id=message_id,
|
||||
id=trace_id,
|
||||
op=str(TraceTaskName.MESSAGE_TRACE.value),
|
||||
input_tokens=trace_info.message_tokens,
|
||||
output_tokens=trace_info.answer_tokens,
|
||||
@ -274,7 +277,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
)
|
||||
self.start_call(
|
||||
llm_run,
|
||||
parent_run_id=message_id,
|
||||
parent_run_id=trace_id,
|
||||
)
|
||||
self.finish_call(llm_run)
|
||||
self.finish_call(message_run)
|
||||
@ -289,6 +292,9 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
attributes["start_time"] = trace_info.start_time or trace_info.message_data.created_at
|
||||
attributes["end_time"] = trace_info.end_time or trace_info.message_data.updated_at
|
||||
|
||||
trace_id = trace_info.trace_id or trace_info.message_id
|
||||
attributes["trace_id"] = trace_id
|
||||
|
||||
moderation_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.MODERATION_TRACE.value),
|
||||
@ -303,7 +309,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
exception=getattr(trace_info, "error", None),
|
||||
file_list=[],
|
||||
)
|
||||
self.start_call(moderation_run, parent_run_id=trace_info.message_id)
|
||||
self.start_call(moderation_run, parent_run_id=trace_id)
|
||||
self.finish_call(moderation_run)
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||
@ -316,6 +322,9 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
attributes["start_time"] = (trace_info.start_time or message_data.created_at,)
|
||||
attributes["end_time"] = (trace_info.end_time or message_data.updated_at,)
|
||||
|
||||
trace_id = trace_info.trace_id or trace_info.message_id
|
||||
attributes["trace_id"] = trace_id
|
||||
|
||||
suggested_question_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE.value),
|
||||
@ -326,7 +335,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
file_list=[],
|
||||
)
|
||||
|
||||
self.start_call(suggested_question_run, parent_run_id=trace_info.message_id)
|
||||
self.start_call(suggested_question_run, parent_run_id=trace_id)
|
||||
self.finish_call(suggested_question_run)
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||
@ -338,6 +347,9 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,)
|
||||
attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,)
|
||||
|
||||
trace_id = trace_info.trace_id or trace_info.message_id
|
||||
attributes["trace_id"] = trace_id
|
||||
|
||||
dataset_retrieval_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE.value),
|
||||
@ -348,7 +360,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
file_list=[],
|
||||
)
|
||||
|
||||
self.start_call(dataset_retrieval_run, parent_run_id=trace_info.message_id)
|
||||
self.start_call(dataset_retrieval_run, parent_run_id=trace_id)
|
||||
self.finish_call(dataset_retrieval_run)
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
@ -357,6 +369,11 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
attributes["start_time"] = trace_info.start_time
|
||||
attributes["end_time"] = trace_info.end_time
|
||||
|
||||
message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None)
|
||||
message_id = message_id or None
|
||||
trace_id = trace_info.trace_id or message_id
|
||||
attributes["trace_id"] = trace_id
|
||||
|
||||
tool_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=trace_info.tool_name,
|
||||
@ -366,9 +383,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
||||
attributes=attributes,
|
||||
exception=trace_info.error,
|
||||
)
|
||||
message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None)
|
||||
message_id = message_id or None
|
||||
self.start_call(tool_run, parent_run_id=message_id)
|
||||
self.start_call(tool_run, parent_run_id=trace_id)
|
||||
self.finish_call(tool_run)
|
||||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
|
||||
@ -22,22 +22,50 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ElasticSearchConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
username: str
|
||||
password: str
|
||||
# Regular Elasticsearch config
|
||||
host: Optional[str] = None
|
||||
port: Optional[int] = None
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
|
||||
# Elastic Cloud specific config
|
||||
cloud_url: Optional[str] = None # Cloud URL for Elasticsearch Cloud
|
||||
api_key: Optional[str] = None
|
||||
|
||||
# Common config
|
||||
use_cloud: bool = False
|
||||
ca_certs: Optional[str] = None
|
||||
verify_certs: bool = False
|
||||
request_timeout: int = 100000
|
||||
retry_on_timeout: bool = True
|
||||
max_retries: int = 10000
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config HOST is required")
|
||||
if not values["port"]:
|
||||
raise ValueError("config PORT is required")
|
||||
if not values["username"]:
|
||||
raise ValueError("config USERNAME is required")
|
||||
if not values["password"]:
|
||||
raise ValueError("config PASSWORD is required")
|
||||
use_cloud = values.get("use_cloud", False)
|
||||
cloud_url = values.get("cloud_url")
|
||||
|
||||
if use_cloud:
|
||||
# Cloud configuration validation - requires cloud_url and api_key
|
||||
if not cloud_url:
|
||||
raise ValueError("cloud_url is required for Elastic Cloud")
|
||||
|
||||
api_key = values.get("api_key")
|
||||
if not api_key:
|
||||
raise ValueError("api_key is required for Elastic Cloud")
|
||||
|
||||
else:
|
||||
# Regular Elasticsearch validation
|
||||
if not values.get("host"):
|
||||
raise ValueError("config HOST is required for regular Elasticsearch")
|
||||
if not values.get("port"):
|
||||
raise ValueError("config PORT is required for regular Elasticsearch")
|
||||
if not values.get("username"):
|
||||
raise ValueError("config USERNAME is required for regular Elasticsearch")
|
||||
if not values.get("password"):
|
||||
raise ValueError("config PASSWORD is required for regular Elasticsearch")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
@ -50,21 +78,69 @@ class ElasticSearchVector(BaseVector):
|
||||
self._attributes = attributes
|
||||
|
||||
def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
|
||||
"""
|
||||
Initialize Elasticsearch client for both regular Elasticsearch and Elastic Cloud.
|
||||
"""
|
||||
try:
|
||||
parsed_url = urlparse(config.host)
|
||||
if parsed_url.scheme in {"http", "https"}:
|
||||
hosts = f"{config.host}:{config.port}"
|
||||
# Check if using Elastic Cloud
|
||||
client_config: dict[str, Any]
|
||||
if config.use_cloud and config.cloud_url:
|
||||
client_config = {
|
||||
"request_timeout": config.request_timeout,
|
||||
"retry_on_timeout": config.retry_on_timeout,
|
||||
"max_retries": config.max_retries,
|
||||
"verify_certs": config.verify_certs,
|
||||
}
|
||||
|
||||
# Parse cloud URL and configure hosts
|
||||
parsed_url = urlparse(config.cloud_url)
|
||||
host = f"{parsed_url.scheme}://{parsed_url.hostname}"
|
||||
if parsed_url.port:
|
||||
host += f":{parsed_url.port}"
|
||||
|
||||
client_config["hosts"] = [host]
|
||||
|
||||
# API key authentication for cloud
|
||||
client_config["api_key"] = config.api_key
|
||||
|
||||
# SSL settings
|
||||
if config.ca_certs:
|
||||
client_config["ca_certs"] = config.ca_certs
|
||||
|
||||
else:
|
||||
hosts = f"http://{config.host}:{config.port}"
|
||||
client = Elasticsearch(
|
||||
hosts=hosts,
|
||||
basic_auth=(config.username, config.password),
|
||||
request_timeout=100000,
|
||||
retry_on_timeout=True,
|
||||
max_retries=10000,
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
raise ConnectionError("Vector database connection error")
|
||||
# Regular Elasticsearch configuration
|
||||
parsed_url = urlparse(config.host or "")
|
||||
if parsed_url.scheme in {"http", "https"}:
|
||||
hosts = f"{config.host}:{config.port}"
|
||||
use_https = parsed_url.scheme == "https"
|
||||
else:
|
||||
hosts = f"http://{config.host}:{config.port}"
|
||||
use_https = False
|
||||
|
||||
client_config = {
|
||||
"hosts": [hosts],
|
||||
"basic_auth": (config.username, config.password),
|
||||
"request_timeout": config.request_timeout,
|
||||
"retry_on_timeout": config.retry_on_timeout,
|
||||
"max_retries": config.max_retries,
|
||||
}
|
||||
|
||||
# Only add SSL settings if using HTTPS
|
||||
if use_https:
|
||||
client_config["verify_certs"] = config.verify_certs
|
||||
if config.ca_certs:
|
||||
client_config["ca_certs"] = config.ca_certs
|
||||
|
||||
client = Elasticsearch(**client_config)
|
||||
|
||||
# Test connection
|
||||
if not client.ping():
|
||||
raise ConnectionError("Failed to connect to Elasticsearch")
|
||||
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
raise ConnectionError(f"Vector database connection error: {str(e)}")
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")
|
||||
|
||||
return client
|
||||
|
||||
@ -209,7 +285,11 @@ class ElasticSearchVector(BaseVector):
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
self._client.indices.create(index=self._collection_name, mappings=mappings)
|
||||
logger.info("Created index %s with dimension %s", self._collection_name, dim)
|
||||
else:
|
||||
logger.info("Collection %s already exists.", self._collection_name)
|
||||
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
@ -225,13 +305,51 @@ class ElasticSearchVectorFactory(AbstractVectorFactory):
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
|
||||
# Check if ELASTICSEARCH_USE_CLOUD is explicitly set to false (boolean)
|
||||
use_cloud_env = config.get("ELASTICSEARCH_USE_CLOUD", False)
|
||||
|
||||
if use_cloud_env is False:
|
||||
# Use regular Elasticsearch with config values
|
||||
config_dict = {
|
||||
"use_cloud": False,
|
||||
"host": config.get("ELASTICSEARCH_HOST", "elasticsearch"),
|
||||
"port": config.get("ELASTICSEARCH_PORT", 9200),
|
||||
"username": config.get("ELASTICSEARCH_USERNAME", "elastic"),
|
||||
"password": config.get("ELASTICSEARCH_PASSWORD", "elastic"),
|
||||
}
|
||||
else:
|
||||
# Check for cloud configuration
|
||||
cloud_url = config.get("ELASTICSEARCH_CLOUD_URL")
|
||||
if cloud_url:
|
||||
config_dict = {
|
||||
"use_cloud": True,
|
||||
"cloud_url": cloud_url,
|
||||
"api_key": config.get("ELASTICSEARCH_API_KEY"),
|
||||
}
|
||||
else:
|
||||
# Fallback to regular Elasticsearch
|
||||
config_dict = {
|
||||
"use_cloud": False,
|
||||
"host": config.get("ELASTICSEARCH_HOST", "localhost"),
|
||||
"port": config.get("ELASTICSEARCH_PORT", 9200),
|
||||
"username": config.get("ELASTICSEARCH_USERNAME", "elastic"),
|
||||
"password": config.get("ELASTICSEARCH_PASSWORD", ""),
|
||||
}
|
||||
|
||||
# Common configuration
|
||||
config_dict.update(
|
||||
{
|
||||
"ca_certs": str(config.get("ELASTICSEARCH_CA_CERTS")) if config.get("ELASTICSEARCH_CA_CERTS") else None,
|
||||
"verify_certs": bool(config.get("ELASTICSEARCH_VERIFY_CERTS", False)),
|
||||
"request_timeout": int(config.get("ELASTICSEARCH_REQUEST_TIMEOUT", 100000)),
|
||||
"retry_on_timeout": bool(config.get("ELASTICSEARCH_RETRY_ON_TIMEOUT", True)),
|
||||
"max_retries": int(config.get("ELASTICSEARCH_MAX_RETRIES", 10000)),
|
||||
}
|
||||
)
|
||||
|
||||
return ElasticSearchVector(
|
||||
index_name=collection_name,
|
||||
config=ElasticSearchConfig(
|
||||
host=config.get("ELASTICSEARCH_HOST", "localhost"),
|
||||
port=config.get("ELASTICSEARCH_PORT", 9200),
|
||||
username=config.get("ELASTICSEARCH_USERNAME", ""),
|
||||
password=config.get("ELASTICSEARCH_PASSWORD", ""),
|
||||
),
|
||||
config=ElasticSearchConfig(**config_dict),
|
||||
attributes=[],
|
||||
)
|
||||
|
||||
@ -13,6 +13,8 @@ SupportedComparisonOperator = Literal[
|
||||
"is not",
|
||||
"empty",
|
||||
"not empty",
|
||||
"in",
|
||||
"not in",
|
||||
# for number
|
||||
"=",
|
||||
"≠",
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import operator
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import requests
|
||||
@ -132,13 +133,15 @@ class NotionExtractor(BaseExtractor):
|
||||
data[property_name] = value
|
||||
row_dict = {k: v for k, v in data.items() if v}
|
||||
row_content = ""
|
||||
for key, value in row_dict.items():
|
||||
for key, value in sorted(row_dict.items(), key=operator.itemgetter(0)):
|
||||
if isinstance(value, dict):
|
||||
value_dict = {k: v for k, v in value.items() if v}
|
||||
value_content = "".join(f"{k}:{v} " for k, v in value_dict.items())
|
||||
row_content = row_content + f"{key}:{value_content}\n"
|
||||
else:
|
||||
row_content = row_content + f"{key}:{value}\n"
|
||||
if "url" in result:
|
||||
row_content = row_content + f"Row Page URL:{result.get('url', '')}\n"
|
||||
database_content.append(row_content)
|
||||
|
||||
has_more = response_data.get("has_more", False)
|
||||
|
||||
@ -142,7 +142,7 @@ class WorkflowTool(Tool):
|
||||
if not version:
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.where(Workflow.app_id == app_id, Workflow.version != "draft")
|
||||
.where(Workflow.app_id == app_id, Workflow.version != Workflow.VERSION_DRAFT)
|
||||
.order_by(Workflow.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -265,9 +265,9 @@ class Executor:
|
||||
if not authorization.config.header:
|
||||
authorization.config.header = "Authorization"
|
||||
|
||||
if self.auth.config.type == "bearer":
|
||||
if self.auth.config.type == "bearer" and authorization.config.api_key:
|
||||
headers[authorization.config.header] = f"Bearer {authorization.config.api_key}"
|
||||
elif self.auth.config.type == "basic":
|
||||
elif self.auth.config.type == "basic" and authorization.config.api_key:
|
||||
credentials = authorization.config.api_key
|
||||
if ":" in credentials:
|
||||
encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
|
||||
|
||||
@ -74,6 +74,8 @@ SupportedComparisonOperator = Literal[
|
||||
"is not",
|
||||
"empty",
|
||||
"not empty",
|
||||
"in",
|
||||
"not in",
|
||||
# for number
|
||||
"=",
|
||||
"≠",
|
||||
|
||||
@ -602,6 +602,28 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
**{key: metadata_name, key_value: f"%{value}"}
|
||||
)
|
||||
)
|
||||
case "in":
|
||||
if isinstance(value, str):
|
||||
escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")]
|
||||
escaped_value_str = ",".join(escaped_values)
|
||||
else:
|
||||
escaped_value_str = str(value)
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} = any(string_to_array(:{key_value},','))")).params(
|
||||
**{key: metadata_name, key_value: escaped_value_str}
|
||||
)
|
||||
)
|
||||
case "not in":
|
||||
if isinstance(value, str):
|
||||
escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")]
|
||||
escaped_value_str = ",".join(escaped_values)
|
||||
else:
|
||||
escaped_value_str = str(value)
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} != all(string_to_array(:{key_value},','))")).params(
|
||||
**{key: metadata_name, key_value: escaped_value_str}
|
||||
)
|
||||
)
|
||||
case "=" | "is":
|
||||
if isinstance(value, str):
|
||||
filters.append(Document.doc_metadata[metadata_name] == f'"{value}"')
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import click
|
||||
from kombu.utils.url import parse_url # type: ignore
|
||||
from redis import Redis
|
||||
|
||||
import app
|
||||
@ -10,16 +10,13 @@ from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from libs.email_i18n import EmailType, get_email_i18n_service
|
||||
|
||||
# Create a dedicated Redis connection (using the same configuration as Celery)
|
||||
celery_broker_url = dify_config.CELERY_BROKER_URL
|
||||
|
||||
parsed = urlparse(celery_broker_url)
|
||||
host = parsed.hostname or "localhost"
|
||||
port = parsed.port or 6379
|
||||
password = parsed.password or None
|
||||
redis_db = parsed.path.strip("/") or "1" # type: ignore
|
||||
|
||||
celery_redis = Redis(host=host, port=port, password=password, db=redis_db)
|
||||
redis_config = parse_url(dify_config.CELERY_BROKER_URL)
|
||||
celery_redis = Redis(
|
||||
host=redis_config["hostname"],
|
||||
port=redis_config["port"],
|
||||
password=redis_config["password"],
|
||||
db=int(redis_config["virtual_host"]) if redis_config["virtual_host"] else 1,
|
||||
)
|
||||
|
||||
|
||||
@app.celery.task(queue="monitor")
|
||||
|
||||
@ -266,6 +266,54 @@ class AppAnnotationService:
|
||||
annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]):
|
||||
# get app info
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
# Fetch annotations and their settings in a single query
|
||||
annotations_to_delete = (
|
||||
db.session.query(MessageAnnotation, AppAnnotationSetting)
|
||||
.outerjoin(AppAnnotationSetting, MessageAnnotation.app_id == AppAnnotationSetting.app_id)
|
||||
.filter(MessageAnnotation.id.in_(annotation_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
if not annotations_to_delete:
|
||||
return {"deleted_count": 0}
|
||||
|
||||
# Step 1: Extract IDs for bulk operations
|
||||
annotation_ids_to_delete = [annotation.id for annotation, _ in annotations_to_delete]
|
||||
|
||||
# Step 2: Bulk delete hit histories in a single query
|
||||
db.session.query(AppAnnotationHitHistory).filter(
|
||||
AppAnnotationHitHistory.annotation_id.in_(annotation_ids_to_delete)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
# Step 3: Trigger async tasks for search index deletion
|
||||
for annotation, annotation_setting in annotations_to_delete:
|
||||
if annotation_setting:
|
||||
delete_annotation_index_task.delay(
|
||||
annotation.id, app_id, current_user.current_tenant_id, annotation_setting.collection_binding_id
|
||||
)
|
||||
|
||||
# Step 4: Bulk delete annotations in a single query
|
||||
deleted_count = (
|
||||
db.session.query(MessageAnnotation)
|
||||
.filter(MessageAnnotation.id.in_(annotation_ids_to_delete))
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
return {"deleted_count": deleted_count}
|
||||
|
||||
@classmethod
|
||||
def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict:
|
||||
# get app info
|
||||
@ -280,7 +328,7 @@ class AppAnnotationService:
|
||||
|
||||
try:
|
||||
# Skip the first row
|
||||
df = pd.read_csv(file)
|
||||
df = pd.read_csv(file, dtype=str)
|
||||
result = []
|
||||
for index, row in df.iterrows():
|
||||
content = {"question": row.iloc[0], "answer": row.iloc[1]}
|
||||
@ -452,6 +500,11 @@ class AppAnnotationService:
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
# if annotation reply is enabled, delete annotation index
|
||||
app_annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
|
||||
)
|
||||
|
||||
annotations_query = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id)
|
||||
for annotation in annotations_query.yield_per(100):
|
||||
annotation_hit_histories_query = db.session.query(AppAnnotationHitHistory).filter(
|
||||
@ -460,6 +513,12 @@ class AppAnnotationService:
|
||||
for annotation_hit_history in annotation_hit_histories_query.yield_per(100):
|
||||
db.session.delete(annotation_hit_history)
|
||||
|
||||
# if annotation reply is enabled, delete annotation index
|
||||
if app_annotation_setting:
|
||||
delete_annotation_index_task.delay(
|
||||
annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id
|
||||
)
|
||||
|
||||
db.session.delete(annotation)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
@ -46,9 +46,9 @@ class ExternalDatasetService:
|
||||
def validate_api_list(cls, api_settings: dict):
|
||||
if not api_settings:
|
||||
raise ValueError("api list is empty")
|
||||
if "endpoint" not in api_settings and not api_settings["endpoint"]:
|
||||
if not api_settings.get("endpoint"):
|
||||
raise ValueError("endpoint is required")
|
||||
if "api_key" not in api_settings and not api_settings["api_key"]:
|
||||
if not api_settings.get("api_key"):
|
||||
raise ValueError("api_key is required")
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -185,7 +185,7 @@ class WorkflowConverter:
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type=WorkflowType.from_app_mode(new_app_mode).value,
|
||||
version="draft",
|
||||
version=Workflow.VERSION_DRAFT,
|
||||
graph=json.dumps(graph),
|
||||
features=json.dumps(features),
|
||||
created_by=account_id,
|
||||
|
||||
@ -105,7 +105,9 @@ class WorkflowService:
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == "draft"
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
@ -219,7 +221,7 @@ class WorkflowService:
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type=WorkflowType.from_app_mode(app_model.mode).value,
|
||||
version="draft",
|
||||
version=Workflow.VERSION_DRAFT,
|
||||
graph=json.dumps(graph),
|
||||
features=json.dumps(features),
|
||||
created_by=account.id,
|
||||
@ -257,7 +259,7 @@ class WorkflowService:
|
||||
draft_workflow_stmt = select(Workflow).where(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.version == "draft",
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
)
|
||||
draft_workflow = session.scalar(draft_workflow_stmt)
|
||||
if not draft_workflow:
|
||||
@ -383,9 +385,9 @@ class WorkflowService:
|
||||
tenant_id=app_model.tenant_id,
|
||||
)
|
||||
|
||||
eclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
|
||||
if eclosing_node_type_and_id:
|
||||
_, enclosing_node_id = eclosing_node_type_and_id
|
||||
enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
|
||||
if enclosing_node_type_and_id:
|
||||
_, enclosing_node_id = enclosing_node_type_and_id
|
||||
else:
|
||||
enclosing_node_id = None
|
||||
|
||||
@ -645,7 +647,7 @@ class WorkflowService:
|
||||
raise ValueError(f"Workflow with ID {workflow_id} not found")
|
||||
|
||||
# Check if workflow is a draft version
|
||||
if workflow.version == "draft":
|
||||
if workflow.version == Workflow.VERSION_DRAFT:
|
||||
raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")
|
||||
|
||||
# Check if this workflow is currently referenced by an app
|
||||
|
||||
@ -32,6 +32,7 @@ def add_document_to_index_task(dataset_document_id: str):
|
||||
return
|
||||
|
||||
if dataset_document.indexing_status != "completed":
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
indexing_cache_key = f"document_{dataset_document.id}_indexing"
|
||||
@ -112,3 +113,4 @@ def add_document_to_index_task(dataset_document_id: str):
|
||||
db.session.commit()
|
||||
finally:
|
||||
redis_client.delete(indexing_cache_key)
|
||||
db.session.close()
|
||||
|
||||
@ -31,6 +31,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
|
||||
return
|
||||
|
||||
if segment.status != "waiting":
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
|
||||
@ -113,3 +113,5 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
logging.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
logging.exception("document_indexing_sync_task failed, document_id: %s", document_id)
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
@ -95,8 +95,8 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
|
||||
logging.info(click.style(str(ex), fg="yellow"))
|
||||
redis_client.delete(retry_indexing_cache_key)
|
||||
logging.exception("retry_document_indexing_task failed, document_id: %s", document_id)
|
||||
end_at = time.perf_counter()
|
||||
logging.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
|
||||
end_at = time.perf_counter()
|
||||
logging.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
|
||||
except Exception as e:
|
||||
logging.exception(
|
||||
"retry_document_indexing_task failed, dataset_id: %s, document_ids: %s", dataset_id, document_ids
|
||||
|
||||
@ -11,7 +11,9 @@ class ElasticSearchVectorTest(AbstractVectorTest):
|
||||
self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
|
||||
self.vector = ElasticSearchVector(
|
||||
index_name=self.collection_name.lower(),
|
||||
config=ElasticSearchConfig(host="http://localhost", port="9200", username="elastic", password="elastic"),
|
||||
config=ElasticSearchConfig(
|
||||
use_cloud=False, host="http://localhost", port="9200", username="elastic", password="elastic"
|
||||
),
|
||||
attributes=self.attributes,
|
||||
)
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from packaging.version import Version
|
||||
from yarl import URL
|
||||
@ -137,3 +138,61 @@ def test_db_extras_options_merging(monkeypatch):
|
||||
options = engine_options["connect_args"]["options"]
|
||||
assert "search_path=myschema" in options
|
||||
assert "timezone=UTC" in options
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"),
|
||||
[
|
||||
("redis://localhost:6379/1", "localhost", 6379, None, None, "1"),
|
||||
("redis://:password@localhost:6379/1", "localhost", 6379, None, "password", "1"),
|
||||
("redis://:mypass%23123@localhost:6379/1", "localhost", 6379, None, "mypass#123", "1"),
|
||||
("redis://user:pass%40word@redis-host:6380/2", "redis-host", 6380, "user", "pass@word", "2"),
|
||||
("redis://admin:complex%23pass%40word@127.0.0.1:6379/0", "127.0.0.1", 6379, "admin", "complex#pass@word", "0"),
|
||||
(
|
||||
"redis://user%40domain:secret%23123@redis.example.com:6380/3",
|
||||
"redis.example.com",
|
||||
6380,
|
||||
"user@domain",
|
||||
"secret#123",
|
||||
"3",
|
||||
),
|
||||
# Password containing %23 substring (double encoding scenario)
|
||||
("redis://:mypass%2523@localhost:6379/1", "localhost", 6379, None, "mypass%23", "1"),
|
||||
# Username and password both containing encoded characters
|
||||
("redis://user%2525%40:pass%2523@localhost:6379/1", "localhost", 6379, "user%25@", "pass%23", "1"),
|
||||
],
|
||||
)
|
||||
def test_celery_broker_url_with_special_chars_password(
|
||||
monkeypatch, broker_url, expected_host, expected_port, expected_username, expected_password, expected_db
|
||||
):
|
||||
"""Test that CELERY_BROKER_URL with various formats are handled correctly."""
|
||||
from kombu.utils.url import parse_url
|
||||
|
||||
# clear system environment variables
|
||||
os.environ.clear()
|
||||
|
||||
# Set up basic required environment variables (following existing pattern)
|
||||
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
|
||||
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
|
||||
monkeypatch.setenv("DB_USERNAME", "postgres")
|
||||
monkeypatch.setenv("DB_PASSWORD", "postgres")
|
||||
monkeypatch.setenv("DB_HOST", "localhost")
|
||||
monkeypatch.setenv("DB_PORT", "5432")
|
||||
monkeypatch.setenv("DB_DATABASE", "dify")
|
||||
|
||||
# Set the CELERY_BROKER_URL to test
|
||||
monkeypatch.setenv("CELERY_BROKER_URL", broker_url)
|
||||
|
||||
# Create config and verify the URL is stored correctly
|
||||
config = DifyConfig()
|
||||
assert broker_url == config.CELERY_BROKER_URL
|
||||
|
||||
# Test actual parsing behavior using kombu's parse_url (same as production)
|
||||
redis_config = parse_url(config.CELERY_BROKER_URL)
|
||||
|
||||
# Verify the parsing results match expectations (using kombu's field names)
|
||||
assert redis_config["hostname"] == expected_host
|
||||
assert redis_config["port"] == expected_port
|
||||
assert redis_config["userid"] == expected_username # kombu uses 'userid' not 'username'
|
||||
assert redis_config["password"] == expected_password
|
||||
assert redis_config["virtual_host"] == expected_db # kombu uses 'virtual_host' not 'db'
|
||||
|
||||
@ -102,9 +102,14 @@ class TestPhoenixConfig:
|
||||
assert config.project == "default"
|
||||
|
||||
def test_endpoint_validation_with_path(self):
|
||||
"""Test endpoint validation normalizes URL by removing path"""
|
||||
config = PhoenixConfig(endpoint="https://custom.phoenix.com/api/v1")
|
||||
assert config.endpoint == "https://custom.phoenix.com"
|
||||
"""Test endpoint validation with path"""
|
||||
config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration")
|
||||
assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration"
|
||||
|
||||
def test_endpoint_validation_without_path(self):
|
||||
"""Test endpoint validation without path"""
|
||||
config = PhoenixConfig(endpoint="https://app.phoenix.arize.com")
|
||||
assert config.endpoint == "https://app.phoenix.arize.com"
|
||||
|
||||
|
||||
class TestLangfuseConfig:
|
||||
@ -118,7 +123,7 @@ class TestLangfuseConfig:
|
||||
assert config.host == "https://custom.langfuse.com"
|
||||
|
||||
def test_valid_config_with_path(self):
|
||||
host = host = "https://custom.langfuse.com/api/v1"
|
||||
host = "https://custom.langfuse.com/api/v1"
|
||||
config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host)
|
||||
assert config.public_key == "public_key"
|
||||
assert config.secret_key == "secret_key"
|
||||
@ -368,13 +373,15 @@ class TestConfigIntegration:
|
||||
"""Test that URL normalization works consistently across configs"""
|
||||
# Test that paths are removed from endpoints
|
||||
arize_config = ArizeConfig(endpoint="https://arize.com/api/v1/test")
|
||||
phoenix_config = PhoenixConfig(endpoint="https://phoenix.com/api/v2/")
|
||||
phoenix_with_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration")
|
||||
phoenix_without_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com")
|
||||
aliyun_config = AliyunConfig(
|
||||
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
|
||||
)
|
||||
|
||||
assert arize_config.endpoint == "https://arize.com"
|
||||
assert phoenix_config.endpoint == "https://phoenix.com"
|
||||
assert phoenix_with_path_config.endpoint == "https://app.phoenix.arize.com/s/dify-integration"
|
||||
assert phoenix_without_path_config.endpoint == "https://app.phoenix.arize.com"
|
||||
assert aliyun_config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
|
||||
|
||||
def test_project_default_values(self):
|
||||
|
||||
189
api/tests/unit_tests/services/test_metadata_bug_complete.py
Normal file
189
api/tests/unit_tests/services/test_metadata_bug_complete.py
Normal file
@ -0,0 +1,189 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from flask_restful import reqparse
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
class TestMetadataBugCompleteValidation:
|
||||
"""Complete test suite to verify the metadata nullable bug and its fix."""
|
||||
|
||||
def test_1_pydantic_layer_validation(self):
|
||||
"""Test Layer 1: Pydantic model validation correctly rejects None values."""
|
||||
# Pydantic should reject None values for required fields
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
MetadataArgs(type=None, name=None)
|
||||
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
MetadataArgs(type="string", name=None)
|
||||
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
MetadataArgs(type=None, name="test")
|
||||
|
||||
# Valid values should work
|
||||
valid_args = MetadataArgs(type="string", name="test_name")
|
||||
assert valid_args.type == "string"
|
||||
assert valid_args.name == "test_name"
|
||||
|
||||
def test_2_business_logic_layer_crashes_on_none(self):
|
||||
"""Test Layer 2: Business logic crashes when None values slip through."""
|
||||
# Create mock that bypasses Pydantic validation
|
||||
mock_metadata_args = Mock()
|
||||
mock_metadata_args.name = None
|
||||
mock_metadata_args.type = "string"
|
||||
|
||||
with patch("services.metadata_service.current_user") as mock_user:
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
|
||||
# Should crash with TypeError
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.create_metadata("dataset-123", mock_metadata_args)
|
||||
|
||||
# Test update method as well
|
||||
with patch("services.metadata_service.current_user") as mock_user:
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.update_metadata_name("dataset-123", "metadata-456", None)
|
||||
|
||||
def test_3_database_constraints_verification(self):
|
||||
"""Test Layer 3: Verify database model has nullable=False constraints."""
|
||||
from sqlalchemy import inspect
|
||||
|
||||
from models.dataset import DatasetMetadata
|
||||
|
||||
# Get table info
|
||||
mapper = inspect(DatasetMetadata)
|
||||
|
||||
# Check that type and name columns are not nullable
|
||||
type_column = mapper.columns["type"]
|
||||
name_column = mapper.columns["name"]
|
||||
|
||||
assert type_column.nullable is False, "type column should be nullable=False"
|
||||
assert name_column.nullable is False, "name column should be nullable=False"
|
||||
|
||||
def test_4_fixed_api_layer_rejects_null(self, app):
|
||||
"""Test Layer 4: Fixed API configuration properly rejects null values."""
|
||||
# Test Console API create endpoint (fixed)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
|
||||
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
|
||||
with pytest.raises(BadRequest):
|
||||
parser.parse_args()
|
||||
|
||||
# Test with just name being null
|
||||
with app.test_request_context(json={"type": "string", "name": None}, content_type="application/json"):
|
||||
with pytest.raises(BadRequest):
|
||||
parser.parse_args()
|
||||
|
||||
# Test with just type being null
|
||||
with app.test_request_context(json={"type": None, "name": "test"}, content_type="application/json"):
|
||||
with pytest.raises(BadRequest):
|
||||
parser.parse_args()
|
||||
|
||||
def test_5_fixed_api_accepts_valid_values(self, app):
|
||||
"""Test that fixed API still accepts valid non-null values."""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
|
||||
with app.test_request_context(json={"type": "string", "name": "valid_name"}, content_type="application/json"):
|
||||
args = parser.parse_args()
|
||||
assert args["type"] == "string"
|
||||
assert args["name"] == "valid_name"
|
||||
|
||||
def test_6_simulated_buggy_behavior(self, app):
|
||||
"""Test simulating the original buggy behavior with nullable=True."""
|
||||
# Simulate the old buggy configuration
|
||||
buggy_parser = reqparse.RequestParser()
|
||||
buggy_parser.add_argument("type", type=str, required=True, nullable=True, location="json")
|
||||
buggy_parser.add_argument("name", type=str, required=True, nullable=True, location="json")
|
||||
|
||||
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
|
||||
# This would pass in the buggy version
|
||||
args = buggy_parser.parse_args()
|
||||
assert args["type"] is None
|
||||
assert args["name"] is None
|
||||
|
||||
# But would crash when trying to create MetadataArgs
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
MetadataArgs(**args)
|
||||
|
||||
def test_7_end_to_end_validation_layers(self):
|
||||
"""Test all validation layers work together correctly."""
|
||||
# Layer 1: API should reject null at parameter level (with fix)
|
||||
# Layer 2: Pydantic should reject null at model level
|
||||
# Layer 3: Business logic expects non-null
|
||||
# Layer 4: Database enforces non-null
|
||||
|
||||
# Test that valid data flows through all layers
|
||||
valid_data = {"type": "string", "name": "test_metadata"}
|
||||
|
||||
# Should create valid Pydantic object
|
||||
metadata_args = MetadataArgs(**valid_data)
|
||||
assert metadata_args.type == "string"
|
||||
assert metadata_args.name == "test_metadata"
|
||||
|
||||
# Should not crash in business logic length check
|
||||
assert len(metadata_args.name) <= 255 # This should not crash
|
||||
assert len(metadata_args.type) > 0 # This should not crash
|
||||
|
||||
def test_8_verify_specific_fix_locations(self):
|
||||
"""Verify that the specific locations mentioned in bug report are fixed."""
|
||||
# Read the actual files to verify fixes
|
||||
import os
|
||||
|
||||
# Console API create
|
||||
console_create_file = "api/controllers/console/datasets/metadata.py"
|
||||
if os.path.exists(console_create_file):
|
||||
with open(console_create_file) as f:
|
||||
content = f.read()
|
||||
# Should contain nullable=False, not nullable=True
|
||||
assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0]
|
||||
|
||||
# Service API create
|
||||
service_create_file = "api/controllers/service_api/dataset/metadata.py"
|
||||
if os.path.exists(service_create_file):
|
||||
with open(service_create_file) as f:
|
||||
content = f.read()
|
||||
# Should contain nullable=False, not nullable=True
|
||||
create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0]
|
||||
assert "nullable=True" not in create_api_section
|
||||
|
||||
|
||||
class TestMetadataValidationSummary:
|
||||
"""Summary tests that demonstrate the complete validation architecture."""
|
||||
|
||||
def test_validation_layer_architecture(self):
|
||||
"""Document and test the 4-layer validation architecture."""
|
||||
# Layer 1: API Parameter Validation (Flask-RESTful reqparse)
|
||||
# - Role: First line of defense, validates HTTP request parameters
|
||||
# - Fixed: nullable=False ensures null values are rejected at API boundary
|
||||
|
||||
# Layer 2: Pydantic Model Validation
|
||||
# - Role: Validates data structure and types before business logic
|
||||
# - Working: Required fields without Optional[] reject None values
|
||||
|
||||
# Layer 3: Business Logic Validation
|
||||
# - Role: Domain-specific validation (length checks, uniqueness, etc.)
|
||||
# - Vulnerable: Direct len() calls crash on None values
|
||||
|
||||
# Layer 4: Database Constraints
|
||||
# - Role: Final data integrity enforcement
|
||||
# - Working: nullable=False prevents None values in database
|
||||
|
||||
# The bug was: Layer 1 allowed None, but Layers 2-4 expected non-None
|
||||
# The fix: Make Layer 1 consistent with Layers 2-4
|
||||
|
||||
assert True # This test documents the architecture
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
108
api/tests/unit_tests/services/test_metadata_nullable_bug.py
Normal file
108
api/tests/unit_tests/services/test_metadata_nullable_bug.py
Normal file
@ -0,0 +1,108 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from flask_restful import reqparse
|
||||
|
||||
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
class TestMetadataNullableBug:
|
||||
"""Test case to reproduce the metadata nullable validation bug."""
|
||||
|
||||
def test_metadata_args_with_none_values_should_fail(self):
|
||||
"""Test that MetadataArgs validation should reject None values."""
|
||||
# This test demonstrates the expected behavior - should fail validation
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
# This should fail because Pydantic expects non-None values
|
||||
MetadataArgs(type=None, name=None)
|
||||
|
||||
def test_metadata_service_create_with_none_name_crashes(self):
|
||||
"""Test that MetadataService.create_metadata crashes when name is None."""
|
||||
# Mock the MetadataArgs to bypass Pydantic validation
|
||||
mock_metadata_args = Mock()
|
||||
mock_metadata_args.name = None # This will cause len() to crash
|
||||
mock_metadata_args.type = "string"
|
||||
|
||||
with patch("services.metadata_service.current_user") as mock_user:
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
|
||||
# This should crash with TypeError when calling len(None)
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.create_metadata("dataset-123", mock_metadata_args)
|
||||
|
||||
def test_metadata_service_update_with_none_name_crashes(self):
|
||||
"""Test that MetadataService.update_metadata_name crashes when name is None."""
|
||||
with patch("services.metadata_service.current_user") as mock_user:
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
|
||||
# This should crash with TypeError when calling len(None)
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.update_metadata_name("dataset-123", "metadata-456", None)
|
||||
|
||||
def test_api_parser_accepts_null_values(self, app):
|
||||
"""Test that API parser configuration incorrectly accepts null values."""
|
||||
# Simulate the current API parser configuration
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("type", type=str, required=True, nullable=True, location="json")
|
||||
parser.add_argument("name", type=str, required=True, nullable=True, location="json")
|
||||
|
||||
# Simulate request data with null values
|
||||
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
|
||||
# This should parse successfully due to nullable=True
|
||||
args = parser.parse_args()
|
||||
|
||||
# Verify that null values are accepted
|
||||
assert args["type"] is None
|
||||
assert args["name"] is None
|
||||
|
||||
# This demonstrates the bug: API accepts None but business logic will crash
|
||||
|
||||
def test_integration_bug_scenario(self, app):
|
||||
"""Test the complete bug scenario from API to service layer."""
|
||||
# Step 1: API parser accepts null values (current buggy behavior)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("type", type=str, required=True, nullable=True, location="json")
|
||||
parser.add_argument("name", type=str, required=True, nullable=True, location="json")
|
||||
|
||||
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
|
||||
args = parser.parse_args()
|
||||
|
||||
# Step 2: Try to create MetadataArgs with None values
|
||||
# This should fail at Pydantic validation level
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
metadata_args = MetadataArgs(**args)
|
||||
|
||||
# Step 3: If we bypass Pydantic (simulating the bug scenario)
|
||||
# Move this outside the request context to avoid Flask-Login issues
|
||||
mock_metadata_args = Mock()
|
||||
mock_metadata_args.name = None # From args["name"]
|
||||
mock_metadata_args.type = None # From args["type"]
|
||||
|
||||
with patch("services.metadata_service.current_user") as mock_user:
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
|
||||
# Step 4: Service layer crashes on len(None)
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.create_metadata("dataset-123", mock_metadata_args)
|
||||
|
||||
def test_correct_nullable_false_configuration_works(self, app):
|
||||
"""Test that the correct nullable=False configuration works as expected."""
|
||||
# This tests the FIXED configuration
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
|
||||
with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"):
|
||||
# This should fail with BadRequest due to nullable=False
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
with pytest.raises(BadRequest):
|
||||
parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user