mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 01:48:04 +08:00
Merge branch 'main' into feat/rag-2
This commit is contained in:
@ -140,7 +140,9 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
)
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = TraceQueueManager(app_id=app_model.id)
|
||||
trace_manager = TraceQueueManager(
|
||||
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
|
||||
)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = ChatAppGenerateEntity(
|
||||
|
||||
@ -124,7 +124,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
)
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = TraceQueueManager(app_model.id)
|
||||
trace_manager = TraceQueueManager(
|
||||
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
|
||||
)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = CompletionAppGenerateEntity(
|
||||
|
||||
@ -6,7 +6,6 @@ from core.app.entities.queue_entities import (
|
||||
MessageQueueMessage,
|
||||
QueueAdvancedChatMessageEndEvent,
|
||||
QueueErrorEvent,
|
||||
QueueMessage,
|
||||
QueueMessageEndEvent,
|
||||
QueueStopEvent,
|
||||
)
|
||||
@ -22,15 +21,6 @@ class MessageBasedAppQueueManager(AppQueueManager):
|
||||
self._app_mode = app_mode
|
||||
self._message_id = str(message_id)
|
||||
|
||||
def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage:
|
||||
return MessageQueueMessage(
|
||||
task_id=self._task_id,
|
||||
message_id=self._message_id,
|
||||
conversation_id=self._conversation_id,
|
||||
app_mode=self._app_mode,
|
||||
event=event,
|
||||
)
|
||||
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish event to queue
|
||||
|
||||
@ -5,7 +5,7 @@ from base64 import b64encode
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.variables.utils import SegmentJSONEncoder
|
||||
from core.variables.utils import dumps_with_segments
|
||||
|
||||
|
||||
class TemplateTransformer(ABC):
|
||||
@ -93,7 +93,7 @@ class TemplateTransformer(ABC):
|
||||
|
||||
@classmethod
|
||||
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
|
||||
inputs_json_str = json.dumps(inputs, ensure_ascii=False, cls=SegmentJSONEncoder).encode()
|
||||
inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode()
|
||||
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
|
||||
return input_base64_encoded
|
||||
|
||||
|
||||
@ -16,15 +16,33 @@ def get_external_trace_id(request: Any) -> Optional[str]:
|
||||
"""
|
||||
Retrieve the trace_id from the request.
|
||||
|
||||
Priority: header ('X-Trace-Id'), then parameters, then JSON body. Returns None if not provided or invalid.
|
||||
Priority:
|
||||
1. header ('X-Trace-Id')
|
||||
2. parameters
|
||||
3. JSON body
|
||||
4. Current OpenTelemetry context (if enabled)
|
||||
5. OpenTelemetry traceparent header (if present and valid)
|
||||
|
||||
Returns None if no valid trace_id is provided.
|
||||
"""
|
||||
trace_id = request.headers.get("X-Trace-Id")
|
||||
|
||||
if not trace_id:
|
||||
trace_id = request.args.get("trace_id")
|
||||
|
||||
if not trace_id and getattr(request, "is_json", False):
|
||||
json_data = getattr(request, "json", None)
|
||||
if json_data:
|
||||
trace_id = json_data.get("trace_id")
|
||||
|
||||
if not trace_id:
|
||||
trace_id = get_trace_id_from_otel_context()
|
||||
|
||||
if not trace_id:
|
||||
traceparent = request.headers.get("traceparent")
|
||||
if traceparent:
|
||||
trace_id = parse_traceparent_header(traceparent)
|
||||
|
||||
if isinstance(trace_id, str) and is_valid_trace_id(trace_id):
|
||||
return trace_id
|
||||
return None
|
||||
@ -40,3 +58,49 @@ def extract_external_trace_id_from_args(args: Mapping[str, Any]) -> dict:
|
||||
if trace_id:
|
||||
return {"external_trace_id": trace_id}
|
||||
return {}
|
||||
|
||||
|
||||
def get_trace_id_from_otel_context() -> Optional[str]:
|
||||
"""
|
||||
Retrieve the current trace ID from the active OpenTelemetry trace context.
|
||||
Returns None if:
|
||||
1. OpenTelemetry SDK is not installed or enabled.
|
||||
2. There is no active span or trace context.
|
||||
"""
|
||||
try:
|
||||
from opentelemetry.trace import SpanContext, get_current_span
|
||||
from opentelemetry.trace.span import INVALID_TRACE_ID
|
||||
|
||||
span = get_current_span()
|
||||
if not span:
|
||||
return None
|
||||
|
||||
span_context: SpanContext = span.get_span_context()
|
||||
|
||||
if not span_context or span_context.trace_id == INVALID_TRACE_ID:
|
||||
return None
|
||||
|
||||
trace_id_hex = f"{span_context.trace_id:032x}"
|
||||
return trace_id_hex
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def parse_traceparent_header(traceparent: str) -> Optional[str]:
|
||||
"""
|
||||
Parse the `traceparent` header to extract the trace_id.
|
||||
|
||||
Expected format:
|
||||
'version-trace_id-span_id-flags'
|
||||
|
||||
Reference:
|
||||
W3C Trace Context Specification: https://www.w3.org/TR/trace-context/
|
||||
"""
|
||||
try:
|
||||
parts = traceparent.split("-")
|
||||
if len(parts) == 4 and len(parts[1]) == 32:
|
||||
return parts[1]
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
@ -10,8 +10,6 @@ from core.mcp.types import (
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
LATEST_PROTOCOL_VERSION = "1.0"
|
||||
|
||||
|
||||
class OAuthClientProvider:
|
||||
mcp_provider: MCPToolProvider
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import Any, TypeAlias, final
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
from httpx_sse import EventSource, ServerSentEvent
|
||||
from sseclient import SSEClient
|
||||
|
||||
from core.mcp import types
|
||||
@ -37,11 +38,6 @@ WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
|
||||
StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
|
||||
|
||||
|
||||
def remove_request_params(url: str) -> str:
|
||||
"""Remove request parameters from URL, keeping only the path."""
|
||||
return urljoin(url, urlparse(url).path)
|
||||
|
||||
|
||||
class SSETransport:
|
||||
"""SSE client transport implementation."""
|
||||
|
||||
@ -114,7 +110,7 @@ class SSETransport:
|
||||
logger.exception("Error parsing server message")
|
||||
read_queue.put(exc)
|
||||
|
||||
def _handle_sse_event(self, sse, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
|
||||
def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
|
||||
"""Handle a single SSE event.
|
||||
|
||||
Args:
|
||||
@ -130,7 +126,7 @@ class SSETransport:
|
||||
case _:
|
||||
logger.warning("Unknown SSE event: %s", sse.event)
|
||||
|
||||
def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
|
||||
def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
|
||||
"""Read and process SSE events.
|
||||
|
||||
Args:
|
||||
@ -225,7 +221,7 @@ class SSETransport:
|
||||
self,
|
||||
executor: ThreadPoolExecutor,
|
||||
client: httpx.Client,
|
||||
event_source,
|
||||
event_source: EventSource,
|
||||
) -> tuple[ReadQueue, WriteQueue]:
|
||||
"""Establish connection and start worker threads.
|
||||
|
||||
|
||||
@ -16,13 +16,14 @@ from extensions.ext_database import db
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
"""
|
||||
Apply to MCP HTTP streamable server with stateless http
|
||||
"""
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPServerStreamableHTTPRequestHandler:
|
||||
"""
|
||||
Apply to MCP HTTP streamable server with stateless http
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
|
||||
):
|
||||
|
||||
@ -1,6 +1,10 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from contextlib import AbstractContextManager
|
||||
|
||||
import httpx
|
||||
import httpx_sse
|
||||
from httpx_sse import connect_sse
|
||||
|
||||
from configs import dify_config
|
||||
from core.mcp.types import ErrorData, JSONRPCError
|
||||
@ -55,20 +59,42 @@ def create_ssrf_proxy_mcp_http_client(
|
||||
)
|
||||
|
||||
|
||||
def ssrf_proxy_sse_connect(url, **kwargs):
|
||||
def ssrf_proxy_sse_connect(url: str, **kwargs) -> AbstractContextManager[httpx_sse.EventSource]:
|
||||
"""Connect to SSE endpoint with SSRF proxy protection.
|
||||
|
||||
This function creates an SSE connection using the configured proxy settings
|
||||
to prevent SSRF attacks when connecting to external endpoints.
|
||||
to prevent SSRF attacks when connecting to external endpoints. It returns
|
||||
a context manager that yields an EventSource object for SSE streaming.
|
||||
|
||||
The function handles HTTP client creation and cleanup automatically, but
|
||||
also accepts a pre-configured client via kwargs.
|
||||
|
||||
Args:
|
||||
url: The SSE endpoint URL
|
||||
**kwargs: Additional arguments passed to the SSE connection
|
||||
url (str): The SSE endpoint URL to connect to
|
||||
**kwargs: Additional arguments passed to the SSE connection, including:
|
||||
- client (httpx.Client, optional): Pre-configured HTTP client.
|
||||
If not provided, one will be created with SSRF protection.
|
||||
- method (str, optional): HTTP method to use, defaults to "GET"
|
||||
- headers (dict, optional): HTTP headers to include in the request
|
||||
- timeout (httpx.Timeout, optional): Timeout configuration for the connection
|
||||
|
||||
Returns:
|
||||
EventSource object for SSE streaming
|
||||
AbstractContextManager[httpx_sse.EventSource]: A context manager that yields an EventSource
|
||||
object for SSE streaming. The EventSource provides access to server-sent events.
|
||||
|
||||
Example:
|
||||
```python
|
||||
with ssrf_proxy_sse_connect(url, headers=headers) as event_source:
|
||||
for sse in event_source.iter_sse():
|
||||
print(sse.event, sse.data)
|
||||
```
|
||||
|
||||
Note:
|
||||
If a client is not provided in kwargs, one will be automatically created
|
||||
with SSRF protection based on the application's configuration. If an
|
||||
exception occurs during connection, any automatically created client
|
||||
will be cleaned up automatically.
|
||||
"""
|
||||
from httpx_sse import connect_sse
|
||||
|
||||
# Extract client if provided, otherwise create one
|
||||
client = kwargs.pop("client", None)
|
||||
@ -101,7 +127,9 @@ def ssrf_proxy_sse_connect(url, **kwargs):
|
||||
raise
|
||||
|
||||
|
||||
def create_mcp_error_response(request_id: int | str | None, code: int, message: str, data=None):
|
||||
def create_mcp_error_response(
|
||||
request_id: int | str | None, code: int, message: str, data=None
|
||||
) -> Generator[bytes, None, None]:
|
||||
"""Create MCP error response"""
|
||||
error_data = ErrorData(code=code, message=message, data=data)
|
||||
json_response = JSONRPCError(
|
||||
|
||||
@ -151,12 +151,9 @@ def jsonable_encoder(
|
||||
return format(obj, "f")
|
||||
if isinstance(obj, dict):
|
||||
encoded_dict = {}
|
||||
allowed_keys = set(obj.keys())
|
||||
for key, value in obj.items():
|
||||
if (
|
||||
(not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa")))
|
||||
and (value is not None or not exclude_none)
|
||||
and key in allowed_keys
|
||||
if (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) and (
|
||||
value is not None or not exclude_none
|
||||
):
|
||||
encoded_key = jsonable_encoder(
|
||||
key,
|
||||
|
||||
@ -4,15 +4,15 @@ from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
from opentelemetry.trace import Link, Status, StatusCode
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.ops.aliyun_trace.data_exporter.traceclient import (
|
||||
TraceClient,
|
||||
convert_datetime_to_nanoseconds,
|
||||
convert_string_to_id,
|
||||
convert_to_span_id,
|
||||
convert_to_trace_id,
|
||||
create_link,
|
||||
generate_span_id,
|
||||
)
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
|
||||
@ -103,10 +103,11 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
trace_id = convert_to_trace_id(trace_info.workflow_run_id)
|
||||
links = []
|
||||
if trace_info.trace_id:
|
||||
trace_id = convert_string_to_id(trace_info.trace_id)
|
||||
links.append(create_link(trace_id_str=trace_info.trace_id))
|
||||
workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow")
|
||||
self.add_workflow_span(trace_id, workflow_span_id, trace_info)
|
||||
self.add_workflow_span(trace_id, workflow_span_id, trace_info, links)
|
||||
|
||||
workflow_node_executions = self.get_workflow_node_executions(trace_info)
|
||||
for node_execution in workflow_node_executions:
|
||||
@ -132,8 +133,9 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
|
||||
trace_id = convert_to_trace_id(message_id)
|
||||
links = []
|
||||
if trace_info.trace_id:
|
||||
trace_id = convert_string_to_id(trace_info.trace_id)
|
||||
links.append(create_link(trace_id_str=trace_info.trace_id))
|
||||
|
||||
message_span_id = convert_to_span_id(message_id, "message")
|
||||
message_span = SpanData(
|
||||
@ -152,6 +154,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
OUTPUT_VALUE: str(trace_info.outputs),
|
||||
},
|
||||
status=status,
|
||||
links=links,
|
||||
)
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
@ -192,8 +195,9 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
message_id = trace_info.message_id
|
||||
|
||||
trace_id = convert_to_trace_id(message_id)
|
||||
links = []
|
||||
if trace_info.trace_id:
|
||||
trace_id = convert_string_to_id(trace_info.trace_id)
|
||||
links.append(create_link(trace_id_str=trace_info.trace_id))
|
||||
|
||||
documents_data = extract_retrieval_documents(trace_info.documents)
|
||||
dataset_retrieval_span = SpanData(
|
||||
@ -211,6 +215,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
INPUT_VALUE: str(trace_info.inputs),
|
||||
OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False),
|
||||
},
|
||||
links=links,
|
||||
)
|
||||
self.trace_client.add_span(dataset_retrieval_span)
|
||||
|
||||
@ -224,8 +229,9 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
|
||||
trace_id = convert_to_trace_id(message_id)
|
||||
links = []
|
||||
if trace_info.trace_id:
|
||||
trace_id = convert_string_to_id(trace_info.trace_id)
|
||||
links.append(create_link(trace_id_str=trace_info.trace_id))
|
||||
|
||||
tool_span = SpanData(
|
||||
trace_id=trace_id,
|
||||
@ -244,6 +250,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
OUTPUT_VALUE: str(trace_info.tool_outputs),
|
||||
},
|
||||
status=status,
|
||||
links=links,
|
||||
)
|
||||
self.trace_client.add_span(tool_span)
|
||||
|
||||
@ -413,7 +420,9 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
status=self.get_workflow_node_status(node_execution),
|
||||
)
|
||||
|
||||
def add_workflow_span(self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo):
|
||||
def add_workflow_span(
|
||||
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, links: Sequence[Link]
|
||||
):
|
||||
message_span_id = None
|
||||
if trace_info.message_id:
|
||||
message_span_id = convert_to_span_id(trace_info.message_id, "message")
|
||||
@ -438,6 +447,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
|
||||
},
|
||||
status=status,
|
||||
links=links,
|
||||
)
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
@ -456,6 +466,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
|
||||
},
|
||||
status=status,
|
||||
links=links,
|
||||
)
|
||||
self.trace_client.add_span(workflow_span)
|
||||
|
||||
@ -466,8 +477,9 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
|
||||
trace_id = convert_to_trace_id(message_id)
|
||||
links = []
|
||||
if trace_info.trace_id:
|
||||
trace_id = convert_string_to_id(trace_info.trace_id)
|
||||
links.append(create_link(trace_id_str=trace_info.trace_id))
|
||||
|
||||
suggested_question_span = SpanData(
|
||||
trace_id=trace_id,
|
||||
@ -487,6 +499,7 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
|
||||
},
|
||||
status=status,
|
||||
links=links,
|
||||
)
|
||||
self.trace_client.add_span(suggested_question_span)
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@ from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.sdk.util.instrumentation import InstrumentationScope
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
from opentelemetry.trace import Link, SpanContext, TraceFlags
|
||||
|
||||
from configs import dify_config
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
|
||||
@ -166,6 +167,16 @@ class SpanBuilder:
|
||||
return span
|
||||
|
||||
|
||||
def create_link(trace_id_str: str) -> Link:
|
||||
placeholder_span_id = 0x0000000000000000
|
||||
trace_id = int(trace_id_str, 16)
|
||||
span_context = SpanContext(
|
||||
trace_id=trace_id, span_id=placeholder_span_id, is_remote=False, trace_flags=TraceFlags(TraceFlags.SAMPLED)
|
||||
)
|
||||
|
||||
return Link(span_context)
|
||||
|
||||
|
||||
def generate_span_id() -> int:
|
||||
span_id = random.getrandbits(64)
|
||||
while span_id == INVALID_SPAN_ID:
|
||||
|
||||
@ -523,7 +523,7 @@ class ProviderManager:
|
||||
# Init trial provider records if not exists
|
||||
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
|
||||
try:
|
||||
# FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
|
||||
# FIXME ignore the type error, only TrialHostingQuota has limit need to change the logic
|
||||
new_provider_record = Provider(
|
||||
tenant_id=tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from typing import Any, Optional
|
||||
|
||||
import orjson
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
@ -135,13 +135,13 @@ class Jieba(BaseKeyword):
|
||||
dataset_keyword_table = self.dataset.dataset_keyword_table
|
||||
keyword_data_source_type = dataset_keyword_table.data_source_type
|
||||
if keyword_data_source_type == "database":
|
||||
dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
|
||||
dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict)
|
||||
db.session.commit()
|
||||
else:
|
||||
file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
|
||||
if storage.exists(file_key):
|
||||
storage.delete(file_key)
|
||||
storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode("utf-8"))
|
||||
storage.save(file_key, dumps_with_sets(keyword_table_dict).encode("utf-8"))
|
||||
|
||||
def _get_dataset_keyword_table(self) -> Optional[dict]:
|
||||
dataset_keyword_table = self.dataset.dataset_keyword_table
|
||||
@ -157,12 +157,11 @@ class Jieba(BaseKeyword):
|
||||
data_source_type=keyword_data_source_type,
|
||||
)
|
||||
if keyword_data_source_type == "database":
|
||||
dataset_keyword_table.keyword_table = json.dumps(
|
||||
dataset_keyword_table.keyword_table = dumps_with_sets(
|
||||
{
|
||||
"__type__": "keyword_table",
|
||||
"__data__": {"index_id": self.dataset.id, "summary": None, "table": {}},
|
||||
},
|
||||
cls=SetEncoder,
|
||||
}
|
||||
)
|
||||
db.session.add(dataset_keyword_table)
|
||||
db.session.commit()
|
||||
@ -257,8 +256,13 @@ class Jieba(BaseKeyword):
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
|
||||
class SetEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, set):
|
||||
return list(obj)
|
||||
return super().default(obj)
|
||||
def set_orjson_default(obj: Any) -> Any:
|
||||
"""Default function for orjson serialization of set types"""
|
||||
if isinstance(obj, set):
|
||||
return list(obj)
|
||||
raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
|
||||
|
||||
|
||||
def dumps_with_sets(obj: Any) -> str:
|
||||
"""JSON dumps with set support using orjson"""
|
||||
return orjson.dumps(obj, default=set_orjson_default).decode("utf-8")
|
||||
|
||||
@ -1 +0,0 @@
|
||||
|
||||
@ -108,10 +108,18 @@ class ApiProviderAuthType(Enum):
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
# 'api_key' deprecated in PR #21656
|
||||
# normalize & tiny alias for backward compatibility
|
||||
v = (value or "").strip().lower()
|
||||
if v == "api_key":
|
||||
v = cls.API_KEY_HEADER.value
|
||||
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
if mode.value == v:
|
||||
return mode
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
valid = ", ".join(m.value for m in cls)
|
||||
raise ValueError(f"invalid mode value '{value}', expected one of: {valid}")
|
||||
|
||||
|
||||
class ToolInvokeMessage(BaseModel):
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
import json
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
from .segment_group import SegmentGroup
|
||||
from .segments import ArrayFileSegment, FileSegment, Segment
|
||||
@ -12,15 +14,20 @@ def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[
|
||||
return selectors
|
||||
|
||||
|
||||
class SegmentJSONEncoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
if isinstance(o, ArrayFileSegment):
|
||||
return [v.model_dump() for v in o.value]
|
||||
elif isinstance(o, FileSegment):
|
||||
return o.value.model_dump()
|
||||
elif isinstance(o, SegmentGroup):
|
||||
return [self.default(seg) for seg in o.value]
|
||||
elif isinstance(o, Segment):
|
||||
return o.value
|
||||
else:
|
||||
super().default(o)
|
||||
def segment_orjson_default(o: Any) -> Any:
|
||||
"""Default function for orjson serialization of Segment types"""
|
||||
if isinstance(o, ArrayFileSegment):
|
||||
return [v.model_dump() for v in o.value]
|
||||
elif isinstance(o, FileSegment):
|
||||
return o.value.model_dump()
|
||||
elif isinstance(o, SegmentGroup):
|
||||
return [segment_orjson_default(seg) for seg in o.value]
|
||||
elif isinstance(o, Segment):
|
||||
return o.value
|
||||
raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")
|
||||
|
||||
|
||||
def dumps_with_segments(obj: Any, ensure_ascii: bool = False) -> str:
|
||||
"""JSON dumps with segment support using orjson"""
|
||||
option = orjson.OPT_NON_STR_KEYS
|
||||
return orjson.dumps(obj, default=segment_orjson_default, option=option).decode("utf-8")
|
||||
|
||||
@ -5,7 +5,7 @@ import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file import FileType, file_manager
|
||||
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
@ -194,17 +194,6 @@ class LLMNode(BaseNode):
|
||||
else []
|
||||
)
|
||||
|
||||
# single step run fetch file from sys files
|
||||
if not files and self.invoke_from == InvokeFrom.DEBUGGER and not self.previous_node_id:
|
||||
files = (
|
||||
llm_utils.fetch_files(
|
||||
variable_pool=variable_pool,
|
||||
selector=["sys", "files"],
|
||||
)
|
||||
if self._node_data.vision.enabled
|
||||
else []
|
||||
)
|
||||
|
||||
if files:
|
||||
node_inputs["#files#"] = [file.to_dict() for file in files]
|
||||
|
||||
|
||||
@ -318,33 +318,6 @@ class ToolNode(BaseNode):
|
||||
json.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
if message.meta:
|
||||
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
else:
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": message.message.text,
|
||||
}
|
||||
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
|
||||
|
||||
Reference in New Issue
Block a user