mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 10:28:10 +08:00
chore(api): adapt Graphon 0.2.2 upgrade (#35377)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -42,7 +42,7 @@ from graphon.model_runtime.entities import (
|
||||
)
|
||||
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import Conversation, Message, MessageAgentThought, MessageFile
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
|
||||
|
||||
class ModelConfigConverter:
|
||||
|
||||
@ -18,7 +18,7 @@ from core.moderation.base import ModerationError
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from models.model import App, Conversation, Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -59,7 +59,7 @@ from graphon.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile
|
||||
|
||||
|
||||
@ -12,13 +12,14 @@ from typing import TYPE_CHECKING, Literal
|
||||
from configs import dify_config
|
||||
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.helper.ssrf_proxy import graphon_ssrf_proxy
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.workflow.file_reference import parse_file_reference
|
||||
from extensions.ext_storage import storage
|
||||
from graphon.file import FileTransferMethod
|
||||
from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
|
||||
from graphon.file.protocols import WorkflowFileRuntimeProtocol
|
||||
from graphon.file.runtime import set_workflow_file_runtime
|
||||
from graphon.http.protocols import HttpResponseProtocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphon.file import File
|
||||
@ -43,7 +44,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
return dify_config.MULTIMODAL_SEND_FORMAT
|
||||
|
||||
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
|
||||
return ssrf_proxy.get(url, follow_redirects=follow_redirects)
|
||||
return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
|
||||
|
||||
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
|
||||
return storage.load(path, stream=stream)
|
||||
|
||||
@ -349,7 +349,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
execution.total_tokens = runtime_state.total_tokens
|
||||
execution.total_steps = runtime_state.node_run_steps
|
||||
execution.outputs = execution.outputs or runtime_state.outputs
|
||||
execution.exceptions_count = runtime_state.exceptions_count
|
||||
execution.exceptions_count = max(execution.exceptions_count, runtime_state.exceptions_count)
|
||||
|
||||
def _update_node_execution(
|
||||
self,
|
||||
|
||||
@ -352,11 +352,11 @@ class DatasourceManager:
|
||||
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
|
||||
|
||||
file_info = File(
|
||||
id=upload_file.id,
|
||||
file_id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
type=FileType.CUSTOM,
|
||||
file_type=FileType.CUSTOM,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
reference=build_file_reference(record_id=str(upload_file.id)),
|
||||
|
||||
@ -31,7 +31,7 @@ from graphon.model_runtime.entities.provider_entities import (
|
||||
FormType,
|
||||
ProviderEntity,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from graphon.model_runtime.model_providers.base.ai_model import AIModel
|
||||
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from graphon.model_runtime.runtime import ModelRuntime
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@ -363,7 +363,7 @@ class ProviderConfiguration(BaseModel):
|
||||
)
|
||||
|
||||
for key, value in validated_credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
if key in provider_credential_secret_variables and isinstance(value, str):
|
||||
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
|
||||
return validated_credentials
|
||||
@ -912,7 +912,7 @@ class ProviderConfiguration(BaseModel):
|
||||
)
|
||||
|
||||
for key, value in validated_credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
if key in provider_credential_secret_variables and isinstance(value, str):
|
||||
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
|
||||
return validated_credentials
|
||||
|
||||
@ -102,7 +102,7 @@ class TemplateTransformer(ABC):
|
||||
|
||||
@classmethod
|
||||
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
|
||||
inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode()
|
||||
inputs_json_str = dumps_with_segments(inputs).encode()
|
||||
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
|
||||
return input_base64_encoded
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_
|
||||
from extensions.ext_hosting_provider import hosting_configuration
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
|
||||
from models.provider import ProviderType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -12,6 +12,7 @@ from pydantic import TypeAdapter, ValidationError
|
||||
from configs import dify_config
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
from core.tools.errors import ToolSSRFError
|
||||
from graphon.http.response import HttpResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -267,4 +268,47 @@ class SSRFProxy:
|
||||
return patch(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def _to_graphon_http_response(response: httpx.Response) -> HttpResponse:
|
||||
"""Convert an ``httpx`` response into Graphon's transport-agnostic wrapper."""
|
||||
return HttpResponse(
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
content=response.content,
|
||||
url=str(response.url) if response.url else None,
|
||||
reason_phrase=response.reason_phrase,
|
||||
fallback_text=response.text,
|
||||
)
|
||||
|
||||
|
||||
class GraphonSSRFProxy:
|
||||
"""Adapter exposing SSRF helpers behind Graphon's ``HttpClientProtocol``."""
|
||||
|
||||
@property
|
||||
def max_retries_exceeded_error(self) -> type[Exception]:
|
||||
return max_retries_exceeded_error
|
||||
|
||||
@property
|
||||
def request_error(self) -> type[Exception]:
|
||||
return request_error
|
||||
|
||||
def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(get(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(head(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(post(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(put(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(delete(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(patch(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
|
||||
ssrf_proxy = SSRFProxy()
|
||||
graphon_ssrf_proxy = GraphonSSRFProxy()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
|
||||
from typing import IO, Any, Literal, Optional, Union, cast, overload
|
||||
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities import PluginCredentialType
|
||||
@ -18,15 +18,17 @@ from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFe
|
||||
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
|
||||
from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
|
||||
from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
|
||||
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
|
||||
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
|
||||
from models.provider import ProviderType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class ModelInstance:
|
||||
@ -168,7 +170,7 @@ class ModelInstance:
|
||||
return cast(
|
||||
Union[LLMResult, Generator],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=list(prompt_messages),
|
||||
@ -193,7 +195,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, LargeLanguageModel):
|
||||
raise Exception("Model type instance is not LargeLanguageModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=list(prompt_messages),
|
||||
@ -213,7 +215,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
@ -235,7 +237,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
multimodel_documents=multimodel_documents,
|
||||
@ -252,7 +254,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
@ -277,7 +279,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, RerankModel):
|
||||
raise Exception("Model type instance is not RerankModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
@ -305,7 +307,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, RerankModel):
|
||||
raise Exception("Model type instance is not RerankModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke_multimodal_rerank,
|
||||
self.model_type_instance.invoke_multimodal_rerank,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
@ -324,7 +326,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, ModerationModel):
|
||||
raise Exception("Model type instance is not ModerationModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
text=text,
|
||||
@ -340,7 +342,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, Speech2TextModel):
|
||||
raise Exception("Model type instance is not Speech2TextModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
file=file,
|
||||
@ -357,14 +359,14 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TTSModel):
|
||||
raise Exception("Model type instance is not TTSModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def _round_robin_invoke(self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""
|
||||
Round-robin invoke
|
||||
:param function: function to invoke
|
||||
|
||||
@ -66,15 +66,15 @@ class PluginModelRuntime(ModelRuntime):
|
||||
if not provider_schema.icon_small:
|
||||
raise ValueError(f"Provider {provider} does not have small icon.")
|
||||
file_name = (
|
||||
provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US
|
||||
provider_schema.icon_small.zh_hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_us
|
||||
)
|
||||
elif icon_type.lower() == "icon_small_dark":
|
||||
if not provider_schema.icon_small_dark:
|
||||
raise ValueError(f"Provider {provider} does not have small dark icon.")
|
||||
file_name = (
|
||||
provider_schema.icon_small_dark.zh_Hans
|
||||
provider_schema.icon_small_dark.zh_hans
|
||||
if lang.lower() == "zh_hans"
|
||||
else provider_schema.icon_small_dark.en_US
|
||||
else provider_schema.icon_small_dark.en_us
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported icon type: {icon_type}.")
|
||||
|
||||
@ -10,7 +10,7 @@ from graphon.model_runtime.entities.message_entities import (
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
|
||||
|
||||
class AgentHistoryPromptTransform(PromptTransform):
|
||||
|
||||
@ -14,7 +14,7 @@ from core.rag.embedding.embedding_base import Embeddings
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
|
||||
from libs import helper
|
||||
from models.dataset import Embedding
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
@ -36,8 +37,11 @@ class WordExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
_closed: bool
|
||||
|
||||
def __init__(self, file_path: str, tenant_id: str, user_id: str):
|
||||
"""Initialize with file path."""
|
||||
self._closed = False
|
||||
self.file_path = file_path
|
||||
self.tenant_id = tenant_id
|
||||
self.user_id = user_id
|
||||
@ -65,9 +69,27 @@ class WordExtractor(BaseExtractor):
|
||||
elif not os.path.isfile(self.file_path):
|
||||
raise ValueError(f"File path {self.file_path} is not a valid file or url")
|
||||
|
||||
def close(self) -> None:
|
||||
"""Best-effort cleanup for downloaded temporary files."""
|
||||
if getattr(self, "_closed", False):
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
temp_file = getattr(self, "temp_file", None)
|
||||
if temp_file is None:
|
||||
return
|
||||
|
||||
try:
|
||||
close_result = temp_file.close()
|
||||
if inspect.isawaitable(close_result):
|
||||
close_awaitable = getattr(close_result, "close", None)
|
||||
if callable(close_awaitable):
|
||||
close_awaitable()
|
||||
except Exception:
|
||||
logger.debug("Failed to cleanup downloaded word temp file", exc_info=True)
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, "temp_file"):
|
||||
self.temp_file.close()
|
||||
self.close()
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load given path as single page."""
|
||||
|
||||
@ -609,11 +609,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
try:
|
||||
# Create File object directly (similar to DatasetRetrieval)
|
||||
file_obj = File(
|
||||
id=upload_file.id,
|
||||
file_id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
reference=build_file_reference(
|
||||
|
||||
@ -68,7 +68,7 @@ from graphon.file import File, FileTransferMethod, FileType
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from libs.helper import parse_uuid_str_or_none
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
from models import UploadFile
|
||||
@ -517,11 +517,11 @@ class DatasetRetrieval:
|
||||
if attachments_with_bindings:
|
||||
for _, upload_file in attachments_with_bindings:
|
||||
attachment_info = File(
|
||||
id=upload_file.id,
|
||||
file_id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
reference=build_file_reference(
|
||||
|
||||
@ -9,7 +9,7 @@ from typing import Any, Literal
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter
|
||||
from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
|
||||
from graphon.model_runtime.model_providers.base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
|
||||
|
||||
|
||||
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
||||
|
||||
@ -8,7 +8,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.workflow.human_input_compat import (
|
||||
from core.workflow.human_input_adapter import (
|
||||
BoundRecipient,
|
||||
DeliveryChannelConfig,
|
||||
EmailDeliveryMethod,
|
||||
|
||||
@ -28,7 +28,7 @@ class ToolFileManager:
|
||||
def _build_graph_file_reference(tool_file: ToolFile) -> File:
|
||||
extension = guess_extension(tool_file.mimetype) or ".bin"
|
||||
return File(
|
||||
type=get_file_type_by_mime_type(tool_file.mimetype),
|
||||
file_type=get_file_type_by_mime_type(tool_file.mimetype),
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
remote_url=tool_file.original_url,
|
||||
reference=build_file_reference(record_id=str(tool_file.id)),
|
||||
|
||||
@ -1082,7 +1082,12 @@ class ToolManager:
|
||||
continue
|
||||
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
|
||||
if tool_input.type == "variable":
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
variable_selector = tool_input.value
|
||||
if not isinstance(variable_selector, list) or not all(
|
||||
isinstance(selector_part, str) for selector_part in variable_selector
|
||||
):
|
||||
raise ToolParameterError("Variable tool input must be a variable selector")
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
parameter_value = variable.value
|
||||
|
||||
@ -21,7 +21,7 @@ from graphon.model_runtime.errors.invoke import (
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from models.tools import ToolModelInvoke
|
||||
|
||||
|
||||
@ -357,7 +357,10 @@ class WorkflowTool(Tool):
|
||||
|
||||
def _update_file_mapping(self, file_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
|
||||
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
|
||||
transfer_method_value = file_dict.get("transfer_method")
|
||||
if not isinstance(transfer_method_value, str):
|
||||
raise ValueError("Workflow file mapping is missing a valid transfer_method")
|
||||
transfer_method = FileTransferMethod.value_of(transfer_method_value)
|
||||
match transfer_method:
|
||||
case FileTransferMethod.TOOL_FILE:
|
||||
file_dict["tool_file_id"] = file_id
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
"""Workflow-layer adapters for legacy human-input payload keys.
|
||||
"""Workflow-to-Graphon adapters for persisted node payloads.
|
||||
|
||||
Stored workflow graphs and editor payloads may still use Dify-specific human
|
||||
input recipient keys. Normalize them here before handing configs to
|
||||
`graphon` so graph-owned models only see graph-neutral field names.
|
||||
Stored workflow graphs and editor payloads still contain a small set of
|
||||
Dify-owned field spellings and value shapes. Adapt them here before handing the
|
||||
payload to Graphon so Graphon-owned models only see current contracts.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -185,7 +185,7 @@ def _copy_mapping(value: object) -> dict[str, Any] | None:
|
||||
return None
|
||||
|
||||
|
||||
def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
def adapt_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
normalized = _copy_mapping(node_data)
|
||||
if normalized is None:
|
||||
raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}")
|
||||
@ -215,7 +215,7 @@ def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | Bas
|
||||
|
||||
|
||||
def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]:
|
||||
normalized = normalize_human_input_node_data_for_graph(node_data)
|
||||
normalized = adapt_human_input_node_data_for_graph(node_data)
|
||||
raw_delivery_methods = normalized.get("delivery_methods")
|
||||
if not isinstance(raw_delivery_methods, list):
|
||||
return []
|
||||
@ -229,17 +229,20 @@ def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> b
|
||||
return False
|
||||
|
||||
|
||||
def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
def adapt_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
normalized = _copy_mapping(node_data)
|
||||
if normalized is None:
|
||||
raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}")
|
||||
|
||||
if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT:
|
||||
return normalized
|
||||
return normalize_human_input_node_data_for_graph(normalized)
|
||||
node_type = normalized.get("type")
|
||||
if node_type == BuiltinNodeTypes.HUMAN_INPUT:
|
||||
return adapt_human_input_node_data_for_graph(normalized)
|
||||
if node_type == BuiltinNodeTypes.TOOL:
|
||||
return _adapt_tool_node_data_for_graph(normalized)
|
||||
return normalized
|
||||
|
||||
|
||||
def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
def adapt_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
normalized = _copy_mapping(node_config)
|
||||
if normalized is None:
|
||||
raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}")
|
||||
@ -248,10 +251,65 @@ def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel)
|
||||
if data_mapping is None:
|
||||
return normalized
|
||||
|
||||
normalized["data"] = normalize_node_data_for_graph(data_mapping)
|
||||
normalized["data"] = adapt_node_data_for_graph(data_mapping)
|
||||
return normalized
|
||||
|
||||
|
||||
def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, Any]:
|
||||
normalized = dict(node_data)
|
||||
|
||||
raw_tool_configurations = normalized.get("tool_configurations")
|
||||
if not isinstance(raw_tool_configurations, Mapping):
|
||||
return normalized
|
||||
|
||||
existing_tool_parameters = normalized.get("tool_parameters")
|
||||
normalized_tool_parameters = dict(existing_tool_parameters) if isinstance(existing_tool_parameters, Mapping) else {}
|
||||
normalized_tool_configurations: dict[str, Any] = {}
|
||||
found_legacy_tool_inputs = False
|
||||
|
||||
for name, value in raw_tool_configurations.items():
|
||||
if not isinstance(value, Mapping):
|
||||
normalized_tool_configurations[name] = value
|
||||
continue
|
||||
|
||||
input_type = value.get("type")
|
||||
input_value = value.get("value")
|
||||
if input_type not in {"mixed", "variable", "constant"}:
|
||||
normalized_tool_configurations[name] = value
|
||||
continue
|
||||
|
||||
found_legacy_tool_inputs = True
|
||||
normalized_tool_parameters.setdefault(name, dict(value))
|
||||
|
||||
flattened_value = _flatten_legacy_tool_configuration_value(
|
||||
input_type=input_type,
|
||||
input_value=input_value,
|
||||
)
|
||||
if flattened_value is not None:
|
||||
normalized_tool_configurations[name] = flattened_value
|
||||
|
||||
if not found_legacy_tool_inputs:
|
||||
return normalized
|
||||
|
||||
normalized["tool_parameters"] = normalized_tool_parameters
|
||||
normalized["tool_configurations"] = normalized_tool_configurations
|
||||
return normalized
|
||||
|
||||
|
||||
def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: Any) -> str | int | float | bool | None:
|
||||
if input_type in {"mixed", "constant"} and isinstance(input_value, str | int | float | bool):
|
||||
return input_value
|
||||
|
||||
if (
|
||||
input_type == "variable"
|
||||
and isinstance(input_value, list)
|
||||
and all(isinstance(item, str) for item in input_value)
|
||||
):
|
||||
return "{{#" + ".".join(input_value) + "#}}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]:
|
||||
normalized = dict(recipients)
|
||||
|
||||
@ -291,9 +349,9 @@ __all__ = [
|
||||
"MemberRecipient",
|
||||
"WebAppDeliveryMethod",
|
||||
"_WebAppDeliveryConfig",
|
||||
"adapt_human_input_node_data_for_graph",
|
||||
"adapt_node_config_for_graph",
|
||||
"adapt_node_data_for_graph",
|
||||
"is_human_input_webapp_enabled",
|
||||
"normalize_human_input_node_data_for_graph",
|
||||
"normalize_node_config_for_graph",
|
||||
"normalize_node_data_for_graph",
|
||||
"parse_human_input_delivery_methods",
|
||||
]
|
||||
@ -15,12 +15,12 @@ from core.helper.code_executor.code_executor import (
|
||||
CodeExecutionError,
|
||||
CodeExecutor,
|
||||
)
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.helper.ssrf_proxy import graphon_ssrf_proxy
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.trigger.constants import TRIGGER_NODE_TYPES
|
||||
from core.workflow.human_input_compat import normalize_node_config_for_graph
|
||||
from core.workflow.human_input_adapter import adapt_node_config_for_graph
|
||||
from core.workflow.node_runtime import (
|
||||
DifyFileReferenceFactory,
|
||||
DifyHumanInputNodeRuntime,
|
||||
@ -46,7 +46,7 @@ from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
from graphon.file.file_manager import file_manager
|
||||
from graphon.graph.graph import NodeFactory
|
||||
from graphon.model_runtime.memory import PromptMessageMemory
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from graphon.nodes.base.node import Node
|
||||
from graphon.nodes.code.code_node import WorkflowCodeExecutor
|
||||
from graphon.nodes.code.entities import CodeLanguage
|
||||
@ -121,6 +121,7 @@ def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]
|
||||
|
||||
|
||||
def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
|
||||
"""Resolve the production node class for the requested type/version."""
|
||||
node_mapping = get_node_type_classes_mapping().get(node_type)
|
||||
if not node_mapping:
|
||||
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||
@ -297,7 +298,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
)
|
||||
self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer()
|
||||
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
self._http_request_http_client = ssrf_proxy
|
||||
self._http_request_http_client = graphon_ssrf_proxy
|
||||
self._bound_tool_file_manager_factory = lambda: DifyToolFileManager(
|
||||
self._dify_context,
|
||||
conversation_id_getter=self._conversation_id,
|
||||
@ -364,10 +365,14 @@ class DifyNodeFactory(NodeFactory):
|
||||
(including pydantic ValidationError, which subclasses ValueError),
|
||||
if node type is unknown, or if no implementation exists for the resolved version
|
||||
"""
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
|
||||
node_id = typed_node_config["id"]
|
||||
node_data = typed_node_config["data"]
|
||||
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
|
||||
# Graph configs are initially validated against permissive shared node data.
|
||||
# Re-validate using the resolved node class so workflow-local node schemas
|
||||
# stay explicit and constructors receive the concrete typed payload.
|
||||
resolved_node_data = self._validate_resolved_node_data(node_class, node_data)
|
||||
node_type = node_data.type
|
||||
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
|
||||
BuiltinNodeTypes.CODE: lambda: {
|
||||
@ -391,7 +396,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
},
|
||||
BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
node_data=resolved_node_data,
|
||||
wrap_model_instance=True,
|
||||
include_http_client=True,
|
||||
include_llm_file_saver=True,
|
||||
@ -405,7 +410,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
},
|
||||
BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
node_data=resolved_node_data,
|
||||
wrap_model_instance=True,
|
||||
include_http_client=True,
|
||||
include_llm_file_saver=True,
|
||||
@ -415,7 +420,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
),
|
||||
BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
node_data=resolved_node_data,
|
||||
wrap_model_instance=True,
|
||||
include_http_client=False,
|
||||
include_llm_file_saver=False,
|
||||
@ -436,8 +441,8 @@ class DifyNodeFactory(NodeFactory):
|
||||
}
|
||||
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=typed_node_config,
|
||||
node_id=node_id,
|
||||
config=resolved_node_data,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
**node_init_kwargs,
|
||||
@ -448,7 +453,10 @@ class DifyNodeFactory(NodeFactory):
|
||||
"""
|
||||
Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class.
|
||||
"""
|
||||
return node_class.validate_node_data(node_data)
|
||||
validate_node_data = getattr(node_class, "validate_node_data", None)
|
||||
if callable(validate_node_data):
|
||||
return cast("BaseNodeData", validate_node_data(node_data))
|
||||
return node_data
|
||||
|
||||
@staticmethod
|
||||
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
|
||||
|
||||
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast, overload
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -41,7 +41,7 @@ from graphon.model_runtime.entities.llm_entities import (
|
||||
)
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from graphon.nodes.human_input.entities import HumanInputNodeData
|
||||
from graphon.nodes.llm.runtime_protocols import (
|
||||
PreparedLLMProtocol,
|
||||
@ -64,7 +64,7 @@ from models.dataset import SegmentAttachmentBinding
|
||||
from models.model import UploadFile
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .human_input_compat import (
|
||||
from .human_input_adapter import (
|
||||
BoundRecipient,
|
||||
DeliveryChannelConfig,
|
||||
DeliveryMethodType,
|
||||
@ -173,6 +173,28 @@ class DifyPreparedLLM(PreparedLLMProtocol):
|
||||
def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int:
|
||||
return self._model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
@overload
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: Mapping[str, Any],
|
||||
tools: Sequence[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[False],
|
||||
) -> LLMResult: ...
|
||||
|
||||
@overload
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: Mapping[str, Any],
|
||||
tools: Sequence[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[True],
|
||||
) -> Generator[LLMResultChunk, None, None]: ...
|
||||
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
@ -190,6 +212,28 @@ class DifyPreparedLLM(PreparedLLMProtocol):
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Mapping[str, Any],
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[False],
|
||||
) -> LLMResultWithStructuredOutput: ...
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Mapping[str, Any],
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[True],
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
|
||||
@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
||||
from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
|
||||
from graphon.nodes.base.node import Node
|
||||
@ -35,18 +34,18 @@ class AgentNode(Node[AgentNodeData]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: NodeConfigDict,
|
||||
node_id: str,
|
||||
config: AgentNodeData,
|
||||
*,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
*,
|
||||
strategy_resolver: AgentStrategyResolver,
|
||||
presentation_provider: AgentStrategyPresentationProvider,
|
||||
runtime_support: AgentRuntimeSupport,
|
||||
message_transformer: AgentMessageTransformer,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
@ -7,7 +7,6 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.workflow.file_reference import resolve_file_record_id
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_segment
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from graphon.enums import (
|
||||
BuiltinNodeTypes,
|
||||
NodeExecutionType,
|
||||
@ -36,13 +35,14 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: NodeConfigDict,
|
||||
node_id: str,
|
||||
config: DatasourceNodeData,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
@ -7,7 +7,6 @@ from core.rag.index_processor.index_processor_base import SummaryIndexSettingDic
|
||||
from core.rag.summary_index.summary_index import SummaryIndex
|
||||
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus
|
||||
from graphon.node_events import NodeRunResult
|
||||
from graphon.nodes.base.node import Node
|
||||
@ -32,12 +31,18 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: NodeConfigDict,
|
||||
node_id: str,
|
||||
config: KnowledgeIndexNodeData,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
) -> None:
|
||||
super().__init__(id, config, graph_init_params, graph_runtime_state)
|
||||
super().__init__(
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self.index_processor = IndexProcessor()
|
||||
self.summary_index_service = SummaryIndex()
|
||||
|
||||
|
||||
@ -14,7 +14,6 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict,
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.workflow.file_reference import parse_file_reference
|
||||
from graphon.entities import GraphInitParams
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from graphon.enums import (
|
||||
BuiltinNodeTypes,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
@ -50,6 +49,18 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_metadata_filter_scalar(value: object) -> str | int | float | None:
|
||||
if value is None or isinstance(value, (str, float)):
|
||||
return value
|
||||
if isinstance(value, int) and not isinstance(value, bool):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
|
||||
def _normalize_metadata_filter_sequence_item(value: object) -> str:
|
||||
return value if isinstance(value, str) else str(value)
|
||||
|
||||
|
||||
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
|
||||
node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL
|
||||
|
||||
@ -59,13 +70,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: NodeConfigDict,
|
||||
node_id: str,
|
||||
config: KnowledgeRetrievalNodeData,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
@ -282,18 +294,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
resolved_conditions: list[Condition] = []
|
||||
for cond in conditions.conditions or []:
|
||||
value = cond.value
|
||||
resolved_value: str | Sequence[str] | int | float | None
|
||||
if isinstance(value, str):
|
||||
segment_group = variable_pool.convert_template(value)
|
||||
if len(segment_group.value) == 1:
|
||||
resolved_value = segment_group.value[0].to_object()
|
||||
resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object())
|
||||
else:
|
||||
resolved_value = segment_group.text
|
||||
elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
|
||||
resolved_values = []
|
||||
for v in value: # type: ignore
|
||||
resolved_values: list[str] = []
|
||||
for v in value:
|
||||
segment_group = variable_pool.convert_template(v)
|
||||
if len(segment_group.value) == 1:
|
||||
resolved_values.append(segment_group.value[0].to_object())
|
||||
resolved_values.append(
|
||||
_normalize_metadata_filter_sequence_item(segment_group.value[0].to_object())
|
||||
)
|
||||
else:
|
||||
resolved_values.append(segment_group.text)
|
||||
resolved_value = resolved_values
|
||||
|
||||
Reference in New Issue
Block a user