Merge branch 'main' into feat/rag-2

This commit is contained in:
twwu
2025-08-18 11:16:18 +08:00
99 changed files with 3421 additions and 1810 deletions

View File

@ -140,7 +140,9 @@ class ChatAppGenerator(MessageBasedAppGenerator):
)
# get tracing instance
trace_manager = TraceQueueManager(app_id=app_model.id)
trace_manager = TraceQueueManager(
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
)
# init application generate entity
application_generate_entity = ChatAppGenerateEntity(

View File

@ -124,7 +124,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
)
# get tracing instance
trace_manager = TraceQueueManager(app_model.id)
trace_manager = TraceQueueManager(
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
)
# init application generate entity
application_generate_entity = CompletionAppGenerateEntity(

View File

@ -6,7 +6,6 @@ from core.app.entities.queue_entities import (
MessageQueueMessage,
QueueAdvancedChatMessageEndEvent,
QueueErrorEvent,
QueueMessage,
QueueMessageEndEvent,
QueueStopEvent,
)
@ -22,15 +21,6 @@ class MessageBasedAppQueueManager(AppQueueManager):
self._app_mode = app_mode
self._message_id = str(message_id)
def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage:
return MessageQueueMessage(
task_id=self._task_id,
message_id=self._message_id,
conversation_id=self._conversation_id,
app_mode=self._app_mode,
event=event,
)
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
Publish event to queue

View File

@ -5,7 +5,7 @@ from base64 import b64encode
from collections.abc import Mapping
from typing import Any
from core.variables.utils import SegmentJSONEncoder
from core.variables.utils import dumps_with_segments
class TemplateTransformer(ABC):
@ -93,7 +93,7 @@ class TemplateTransformer(ABC):
@classmethod
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
inputs_json_str = json.dumps(inputs, ensure_ascii=False, cls=SegmentJSONEncoder).encode()
inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode()
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
return input_base64_encoded

View File

@ -16,15 +16,33 @@ def get_external_trace_id(request: Any) -> Optional[str]:
"""
Retrieve the trace_id from the request.
Priority: header ('X-Trace-Id'), then parameters, then JSON body. Returns None if not provided or invalid.
Priority:
1. header ('X-Trace-Id')
2. parameters
3. JSON body
4. Current OpenTelemetry context (if enabled)
5. OpenTelemetry traceparent header (if present and valid)
Returns None if no valid trace_id is provided.
"""
trace_id = request.headers.get("X-Trace-Id")
if not trace_id:
trace_id = request.args.get("trace_id")
if not trace_id and getattr(request, "is_json", False):
json_data = getattr(request, "json", None)
if json_data:
trace_id = json_data.get("trace_id")
if not trace_id:
trace_id = get_trace_id_from_otel_context()
if not trace_id:
traceparent = request.headers.get("traceparent")
if traceparent:
trace_id = parse_traceparent_header(traceparent)
if isinstance(trace_id, str) and is_valid_trace_id(trace_id):
return trace_id
return None
@ -40,3 +58,49 @@ def extract_external_trace_id_from_args(args: Mapping[str, Any]) -> dict:
if trace_id:
return {"external_trace_id": trace_id}
return {}
def get_trace_id_from_otel_context() -> Optional[str]:
    """
    Return the trace ID of the currently active OpenTelemetry span.

    The ID is rendered as a zero-padded, 32-digit lowercase hex string.

    None is returned when:
        1. The OpenTelemetry API cannot be imported.
        2. There is no current span or span context, or the context carries
           the INVALID_TRACE_ID sentinel.
        3. Any other exception is raised during the lookup.
    """
    try:
        from opentelemetry.trace import SpanContext, get_current_span
        from opentelemetry.trace.span import INVALID_TRACE_ID

        current_span = get_current_span()
        if current_span:
            context: SpanContext = current_span.get_span_context()
            if context and context.trace_id != INVALID_TRACE_ID:
                return format(context.trace_id, "032x")
        return None
    except Exception:
        # Best-effort lookup: any failure is reported as "no trace ID".
        return None
def parse_traceparent_header(traceparent: str) -> Optional[str]:
    """
    Parse the `traceparent` header to extract the trace_id.

    Expected format:
        'version-trace_id-span_id-flags'

    The trace_id is returned only when the header consists of exactly four
    dash-separated fields and the second field is 32 hexadecimal characters
    that are not all zeros (the all-zero trace-id is invalid per the spec).
    Hex digits of either case are accepted and the value is returned
    unchanged. Any other input, including a non-string argument, yields None.

    Reference:
        W3C Trace Context Specification: https://www.w3.org/TR/trace-context/
    """
    if not isinstance(traceparent, str):
        return None

    parts = traceparent.split("-")
    if len(parts) != 4:
        return None

    trace_id = parts[1]
    if len(trace_id) != 32:
        return None
    # A length check alone would let arbitrary 32-character strings through.
    if any(ch not in "0123456789abcdefABCDEF" for ch in trace_id):
        return None
    # The all-zero trace-id is explicitly forbidden.
    if trace_id.strip("0") == "":
        return None
    return trace_id

View File

@ -10,8 +10,6 @@ from core.mcp.types import (
from models.tools import MCPToolProvider
from services.tools.mcp_tools_manage_service import MCPToolManageService
LATEST_PROTOCOL_VERSION = "1.0"
class OAuthClientProvider:
mcp_provider: MCPToolProvider

View File

@ -7,6 +7,7 @@ from typing import Any, TypeAlias, final
from urllib.parse import urljoin, urlparse
import httpx
from httpx_sse import EventSource, ServerSentEvent
from sseclient import SSEClient
from core.mcp import types
@ -37,11 +38,6 @@ WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
def remove_request_params(url: str) -> str:
"""Remove request parameters from URL, keeping only the path."""
return urljoin(url, urlparse(url).path)
class SSETransport:
"""SSE client transport implementation."""
@ -114,7 +110,7 @@ class SSETransport:
logger.exception("Error parsing server message")
read_queue.put(exc)
def _handle_sse_event(self, sse, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
"""Handle a single SSE event.
Args:
@ -130,7 +126,7 @@ class SSETransport:
case _:
logger.warning("Unknown SSE event: %s", sse.event)
def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
"""Read and process SSE events.
Args:
@ -225,7 +221,7 @@ class SSETransport:
self,
executor: ThreadPoolExecutor,
client: httpx.Client,
event_source,
event_source: EventSource,
) -> tuple[ReadQueue, WriteQueue]:
"""Establish connection and start worker threads.

View File

@ -16,13 +16,14 @@ from extensions.ext_database import db
from models.model import App, AppMCPServer, AppMode, EndUser
from services.app_generate_service import AppGenerateService
"""
Apply to MCP HTTP streamable server with stateless http
"""
logger = logging.getLogger(__name__)
class MCPServerStreamableHTTPRequestHandler:
"""
Apply to MCP HTTP streamable server with stateless http
"""
def __init__(
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
):

View File

@ -1,6 +1,10 @@
import json
from collections.abc import Generator
from contextlib import AbstractContextManager
import httpx
import httpx_sse
from httpx_sse import connect_sse
from configs import dify_config
from core.mcp.types import ErrorData, JSONRPCError
@ -55,20 +59,42 @@ def create_ssrf_proxy_mcp_http_client(
)
def ssrf_proxy_sse_connect(url, **kwargs):
def ssrf_proxy_sse_connect(url: str, **kwargs) -> AbstractContextManager[httpx_sse.EventSource]:
"""Connect to SSE endpoint with SSRF proxy protection.
This function creates an SSE connection using the configured proxy settings
to prevent SSRF attacks when connecting to external endpoints.
to prevent SSRF attacks when connecting to external endpoints. It returns
a context manager that yields an EventSource object for SSE streaming.
The function handles HTTP client creation and cleanup automatically, but
also accepts a pre-configured client via kwargs.
Args:
url: The SSE endpoint URL
**kwargs: Additional arguments passed to the SSE connection
url (str): The SSE endpoint URL to connect to
**kwargs: Additional arguments passed to the SSE connection, including:
- client (httpx.Client, optional): Pre-configured HTTP client.
If not provided, one will be created with SSRF protection.
- method (str, optional): HTTP method to use, defaults to "GET"
- headers (dict, optional): HTTP headers to include in the request
- timeout (httpx.Timeout, optional): Timeout configuration for the connection
Returns:
EventSource object for SSE streaming
AbstractContextManager[httpx_sse.EventSource]: A context manager that yields an EventSource
object for SSE streaming. The EventSource provides access to server-sent events.
Example:
```python
with ssrf_proxy_sse_connect(url, headers=headers) as event_source:
for sse in event_source.iter_sse():
print(sse.event, sse.data)
```
Note:
If a client is not provided in kwargs, one will be automatically created
with SSRF protection based on the application's configuration. If an
exception occurs during connection, any automatically created client
will be cleaned up automatically.
"""
from httpx_sse import connect_sse
# Extract client if provided, otherwise create one
client = kwargs.pop("client", None)
@ -101,7 +127,9 @@ def ssrf_proxy_sse_connect(url, **kwargs):
raise
def create_mcp_error_response(request_id: int | str | None, code: int, message: str, data=None):
def create_mcp_error_response(
request_id: int | str | None, code: int, message: str, data=None
) -> Generator[bytes, None, None]:
"""Create MCP error response"""
error_data = ErrorData(code=code, message=message, data=data)
json_response = JSONRPCError(

View File

@ -151,12 +151,9 @@ def jsonable_encoder(
return format(obj, "f")
if isinstance(obj, dict):
encoded_dict = {}
allowed_keys = set(obj.keys())
for key, value in obj.items():
if (
(not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa")))
and (value is not None or not exclude_none)
and key in allowed_keys
if (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) and (
value is not None or not exclude_none
):
encoded_key = jsonable_encoder(
key,

View File

@ -4,15 +4,15 @@ from collections.abc import Sequence
from typing import Optional
from urllib.parse import urljoin
from opentelemetry.trace import Status, StatusCode
from opentelemetry.trace import Link, Status, StatusCode
from sqlalchemy.orm import Session, sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
TraceClient,
convert_datetime_to_nanoseconds,
convert_string_to_id,
convert_to_span_id,
convert_to_trace_id,
create_link,
generate_span_id,
)
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
@ -103,10 +103,11 @@ class AliyunDataTrace(BaseTraceInstance):
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = convert_to_trace_id(trace_info.workflow_run_id)
links = []
if trace_info.trace_id:
trace_id = convert_string_to_id(trace_info.trace_id)
links.append(create_link(trace_id_str=trace_info.trace_id))
workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow")
self.add_workflow_span(trace_id, workflow_span_id, trace_info)
self.add_workflow_span(trace_id, workflow_span_id, trace_info, links)
workflow_node_executions = self.get_workflow_node_executions(trace_info)
for node_execution in workflow_node_executions:
@ -132,8 +133,9 @@ class AliyunDataTrace(BaseTraceInstance):
status = Status(StatusCode.ERROR, trace_info.error)
trace_id = convert_to_trace_id(message_id)
links = []
if trace_info.trace_id:
trace_id = convert_string_to_id(trace_info.trace_id)
links.append(create_link(trace_id_str=trace_info.trace_id))
message_span_id = convert_to_span_id(message_id, "message")
message_span = SpanData(
@ -152,6 +154,7 @@ class AliyunDataTrace(BaseTraceInstance):
OUTPUT_VALUE: str(trace_info.outputs),
},
status=status,
links=links,
)
self.trace_client.add_span(message_span)
@ -192,8 +195,9 @@ class AliyunDataTrace(BaseTraceInstance):
message_id = trace_info.message_id
trace_id = convert_to_trace_id(message_id)
links = []
if trace_info.trace_id:
trace_id = convert_string_to_id(trace_info.trace_id)
links.append(create_link(trace_id_str=trace_info.trace_id))
documents_data = extract_retrieval_documents(trace_info.documents)
dataset_retrieval_span = SpanData(
@ -211,6 +215,7 @@ class AliyunDataTrace(BaseTraceInstance):
INPUT_VALUE: str(trace_info.inputs),
OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False),
},
links=links,
)
self.trace_client.add_span(dataset_retrieval_span)
@ -224,8 +229,9 @@ class AliyunDataTrace(BaseTraceInstance):
status = Status(StatusCode.ERROR, trace_info.error)
trace_id = convert_to_trace_id(message_id)
links = []
if trace_info.trace_id:
trace_id = convert_string_to_id(trace_info.trace_id)
links.append(create_link(trace_id_str=trace_info.trace_id))
tool_span = SpanData(
trace_id=trace_id,
@ -244,6 +250,7 @@ class AliyunDataTrace(BaseTraceInstance):
OUTPUT_VALUE: str(trace_info.tool_outputs),
},
status=status,
links=links,
)
self.trace_client.add_span(tool_span)
@ -413,7 +420,9 @@ class AliyunDataTrace(BaseTraceInstance):
status=self.get_workflow_node_status(node_execution),
)
def add_workflow_span(self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo):
def add_workflow_span(
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, links: Sequence[Link]
):
message_span_id = None
if trace_info.message_id:
message_span_id = convert_to_span_id(trace_info.message_id, "message")
@ -438,6 +447,7 @@ class AliyunDataTrace(BaseTraceInstance):
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
},
status=status,
links=links,
)
self.trace_client.add_span(message_span)
@ -456,6 +466,7 @@ class AliyunDataTrace(BaseTraceInstance):
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
},
status=status,
links=links,
)
self.trace_client.add_span(workflow_span)
@ -466,8 +477,9 @@ class AliyunDataTrace(BaseTraceInstance):
status = Status(StatusCode.ERROR, trace_info.error)
trace_id = convert_to_trace_id(message_id)
links = []
if trace_info.trace_id:
trace_id = convert_string_to_id(trace_info.trace_id)
links.append(create_link(trace_id_str=trace_info.trace_id))
suggested_question_span = SpanData(
trace_id=trace_id,
@ -487,6 +499,7 @@ class AliyunDataTrace(BaseTraceInstance):
OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
},
status=status,
links=links,
)
self.trace_client.add_span(suggested_question_span)

View File

@ -16,6 +16,7 @@ from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.util.instrumentation import InstrumentationScope
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.trace import Link, SpanContext, TraceFlags
from configs import dify_config
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
@ -166,6 +167,16 @@ class SpanBuilder:
return span
def create_link(trace_id_str: str) -> Link:
    """
    Build an OpenTelemetry Link that refers to the trace named by `trace_id_str`.

    The string is parsed as a base-16 integer, so a non-hexadecimal value
    raises ValueError. The link's span context uses a zero span ID as a
    placeholder, is marked as not remote, and carries the SAMPLED trace flag.
    """
    linked_context = SpanContext(
        trace_id=int(trace_id_str, 16),
        span_id=0,  # placeholder: no concrete span is being referenced
        is_remote=False,
        trace_flags=TraceFlags(TraceFlags.SAMPLED),
    )
    return Link(linked_context)
def generate_span_id() -> int:
span_id = random.getrandbits(64)
while span_id == INVALID_SPAN_ID:

View File

@ -523,7 +523,7 @@ class ProviderManager:
# Init trial provider records if not exists
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
try:
# FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
# FIXME: ignore the type error; only TrialHostingQuota has a limit, so this logic needs to change
new_provider_record = Provider(
tenant_id=tenant_id,
# TODO: Use provider name with prefix after the data migration.

View File

@ -1,7 +1,7 @@
import json
from collections import defaultdict
from typing import Any, Optional
import orjson
from pydantic import BaseModel
from configs import dify_config
@ -135,13 +135,13 @@ class Jieba(BaseKeyword):
dataset_keyword_table = self.dataset.dataset_keyword_table
keyword_data_source_type = dataset_keyword_table.data_source_type
if keyword_data_source_type == "database":
dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict)
db.session.commit()
else:
file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
if storage.exists(file_key):
storage.delete(file_key)
storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode("utf-8"))
storage.save(file_key, dumps_with_sets(keyword_table_dict).encode("utf-8"))
def _get_dataset_keyword_table(self) -> Optional[dict]:
dataset_keyword_table = self.dataset.dataset_keyword_table
@ -157,12 +157,11 @@ class Jieba(BaseKeyword):
data_source_type=keyword_data_source_type,
)
if keyword_data_source_type == "database":
dataset_keyword_table.keyword_table = json.dumps(
dataset_keyword_table.keyword_table = dumps_with_sets(
{
"__type__": "keyword_table",
"__data__": {"index_id": self.dataset.id, "summary": None, "table": {}},
},
cls=SetEncoder,
}
)
db.session.add(dataset_keyword_table)
db.session.commit()
@ -257,8 +256,13 @@ class Jieba(BaseKeyword):
self._save_dataset_keyword_table(keyword_table)
class SetEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, set):
return list(obj)
return super().default(obj)
def set_orjson_default(obj: Any) -> Any:
    """
    Fallback hook for orjson serialization: convert a set into a list.

    Raises:
        TypeError: if `obj` is not a set.
    """
    if not isinstance(obj, set):
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
    return list(obj)
def dumps_with_sets(obj: Any) -> str:
    """
    Serialize `obj` to a JSON string with orjson.

    Values orjson cannot serialize on its own are passed to
    `set_orjson_default`, which converts sets to lists.
    """
    encoded = orjson.dumps(obj, default=set_orjson_default)
    return encoded.decode("utf-8")

View File

@ -1 +0,0 @@

View File

@ -108,10 +108,18 @@ class ApiProviderAuthType(Enum):
:param value: mode value
:return: mode
"""
# 'api_key' deprecated in PR #21656
# normalize & tiny alias for backward compatibility
v = (value or "").strip().lower()
if v == "api_key":
v = cls.API_KEY_HEADER.value
for mode in cls:
if mode.value == value:
if mode.value == v:
return mode
raise ValueError(f"invalid mode value {value}")
valid = ", ".join(m.value for m in cls)
raise ValueError(f"invalid mode value '{value}', expected one of: {valid}")
class ToolInvokeMessage(BaseModel):

View File

@ -1,5 +1,7 @@
import json
from collections.abc import Iterable, Sequence
from typing import Any
import orjson
from .segment_group import SegmentGroup
from .segments import ArrayFileSegment, FileSegment, Segment
@ -12,15 +14,20 @@ def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[
return selectors
class SegmentJSONEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, ArrayFileSegment):
return [v.model_dump() for v in o.value]
elif isinstance(o, FileSegment):
return o.value.model_dump()
elif isinstance(o, SegmentGroup):
return [self.default(seg) for seg in o.value]
elif isinstance(o, Segment):
return o.value
else:
super().default(o)
def segment_orjson_default(o: Any) -> Any:
    """
    Fallback hook for orjson serialization that unwraps Segment objects.

    - ArrayFileSegment: list of `model_dump()` results, one per contained value
    - FileSegment: `model_dump()` of the wrapped value
    - SegmentGroup: list with each contained segment converted by this function
    - Segment: the wrapped value

    Raises:
        TypeError: if `o` is none of the types above.
    """
    # NOTE(review): the check order is preserved from the original; presumably
    # the file/group types are Segment subclasses - confirm before reordering.
    if isinstance(o, ArrayFileSegment):
        return [item.model_dump() for item in o.value]
    if isinstance(o, FileSegment):
        return o.value.model_dump()
    if isinstance(o, SegmentGroup):
        return list(map(segment_orjson_default, o.value))
    if isinstance(o, Segment):
        return o.value
    raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")
def _escape_non_ascii(text: str) -> str:
    """Replace every non-ASCII character in `text` with its JSON unicode escape."""
    pieces = []
    for ch in text:
        code = ord(ch)
        if code < 0x80:
            pieces.append(ch)
        elif code <= 0xFFFF:
            pieces.append(f"\\u{code:04x}")
        else:
            # Characters outside the BMP are written as a UTF-16 surrogate pair.
            code -= 0x10000
            pieces.append(f"\\u{0xD800 + (code >> 10):04x}\\u{0xDC00 + (code & 0x3FF):04x}")
    return "".join(pieces)


def dumps_with_segments(obj: Any, ensure_ascii: bool = False) -> str:
    """
    JSON dumps with segment support using orjson.

    Args:
        obj: The object to serialize. Values orjson cannot handle natively are
            converted by `segment_orjson_default`; non-string dict keys are
            allowed via `OPT_NON_STR_KEYS`.
        ensure_ascii: When False (the default) the result is orjson's output
            decoded as UTF-8, with non-ASCII characters left as-is. When True,
            every non-ASCII character is replaced by a unicode escape so the
            result is pure ASCII. Previously this flag was accepted but ignored.

    Returns:
        The serialized JSON document as a string.
    """
    option = orjson.OPT_NON_STR_KEYS
    text = orjson.dumps(obj, default=segment_orjson_default, option=option).decode("utf-8")
    # In a JSON document non-ASCII characters can only occur inside string
    # values, so escaping them across the whole text keeps it valid JSON.
    if ensure_ascii and not text.isascii():
        text = _escape_non_ascii(text)
    return text

View File

@ -5,7 +5,7 @@ import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
@ -194,17 +194,6 @@ class LLMNode(BaseNode):
else []
)
# single step run fetch file from sys files
if not files and self.invoke_from == InvokeFrom.DEBUGGER and not self.previous_node_id:
files = (
llm_utils.fetch_files(
variable_pool=variable_pool,
selector=["sys", "files"],
)
if self._node_data.vision.enabled
else []
)
if files:
node_inputs["#files#"] = [file.to_dict() for file in files]

View File

@ -318,33 +318,6 @@ class ToolNode(BaseNode):
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
if message.meta:
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
else:
transfer_method = FileTransferMethod.TOOL_FILE
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
mapping = {
"tool_file_id": tool_file_id,
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
"transfer_method": transfer_method,
"url": message.message.text,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
files.append(file)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])