mirror of
https://github.com/langgenius/dify.git
synced 2026-04-19 10:17:26 +08:00
Compare commits
21 Commits
1.13.3
...
yanli/phas
| Author | SHA1 | Date | |
|---|---|---|---|
| b7a5ed6c0b | |||
| e819a9a5f7 | |||
| bc82676d93 | |||
| 7b76fdc1d3 | |||
| 82acddddb4 | |||
| 710ac3b90a | |||
| 8548498f25 | |||
| d014f0b91a | |||
| cc5aac268a | |||
| 4c1d27431b | |||
| 9a86f280eb | |||
| c5920fb28a | |||
| 2f81d5dfdf | |||
| 7639d8e43f | |||
| 1dce81c604 | |||
| f874ca183e | |||
| 0d805e624e | |||
| 61196180b8 | |||
| 79433b0091 | |||
| c4aeaa35d4 | |||
| 9f0d79b8b0 |
@ -5,7 +5,7 @@ import logging
|
|||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
|
from typing import TYPE_CHECKING, Any, Literal, Union, overload
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
@ -47,7 +47,6 @@ from extensions.ext_database import db
|
|||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
from libs.flask_utils import preserve_flask_contexts
|
from libs.flask_utils import preserve_flask_contexts
|
||||||
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||||
from models.base import Base
|
|
||||||
from models.enums import WorkflowRunTriggeredFrom
|
from models.enums import WorkflowRunTriggeredFrom
|
||||||
from services.conversation_service import ConversationService
|
from services.conversation_service import ConversationService
|
||||||
from services.workflow_draft_variable_service import (
|
from services.workflow_draft_variable_service import (
|
||||||
@ -522,8 +521,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
|
|
||||||
# release database connection, because the following new thread operations may take a long time
|
# release database connection, because the following new thread operations may take a long time
|
||||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||||
workflow = _refresh_model(session, workflow)
|
workflow = _refresh_model(session=session, model=workflow)
|
||||||
message = _refresh_model(session, message)
|
message = _refresh_model(session=session, model=message)
|
||||||
|
if message is None:
|
||||||
|
raise RuntimeError("Failed to refresh Message; _refresh_model returned None.")
|
||||||
# workflow_ = session.get(Workflow, workflow.id)
|
# workflow_ = session.get(Workflow, workflow.id)
|
||||||
# assert workflow_ is not None
|
# assert workflow_ is not None
|
||||||
# workflow = workflow_
|
# workflow = workflow_
|
||||||
@ -690,11 +691,21 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
_T = TypeVar("_T", bound=Base)
|
@overload
|
||||||
|
def _refresh_model(*, session: Session | None = None, model: Workflow) -> Workflow: ...
|
||||||
|
|
||||||
|
|
||||||
def _refresh_model(session, model: _T) -> _T:
|
@overload
|
||||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
def _refresh_model(*, session: Session | None = None, model: Message) -> Message: ...
|
||||||
detach_model = session.get(type(model), model.id)
|
|
||||||
assert detach_model is not None
|
|
||||||
return detach_model
|
def _refresh_model(*, session: Session | None = None, model: Any) -> Any:
|
||||||
|
if session is not None:
|
||||||
|
detached_model = session.get(type(model), model.id)
|
||||||
|
assert detached_model is not None
|
||||||
|
return detached_model
|
||||||
|
|
||||||
|
with Session(bind=db.engine, expire_on_commit=False) as refresh_session:
|
||||||
|
detached_model = refresh_session.get(type(model), model.id)
|
||||||
|
assert detached_model is not None
|
||||||
|
return detached_model
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator, Iterator
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||||
@ -56,8 +56,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_stream_full_response(
|
def convert_stream_full_response(
|
||||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
cls, stream_response: Iterator[AppStreamResponse]
|
||||||
) -> Generator[dict | str, Any, None]:
|
) -> Generator[dict | str, None, None]:
|
||||||
"""
|
"""
|
||||||
Convert stream full response.
|
Convert stream full response.
|
||||||
:param stream_response: stream response
|
:param stream_response: stream response
|
||||||
@ -87,8 +87,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_stream_simple_response(
|
def convert_stream_simple_response(
|
||||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
cls, stream_response: Iterator[AppStreamResponse]
|
||||||
) -> Generator[dict | str, Any, None]:
|
) -> Generator[dict | str, None, None]:
|
||||||
"""
|
"""
|
||||||
Convert stream simple response.
|
Convert stream simple response.
|
||||||
:param stream_response: stream response
|
:param stream_response: stream response
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator, Iterator
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||||
@ -55,7 +55,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_stream_full_response(
|
def convert_stream_full_response(
|
||||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
cls, stream_response: Iterator[AppStreamResponse]
|
||||||
) -> Generator[dict | str, None, None]:
|
) -> Generator[dict | str, None, None]:
|
||||||
"""
|
"""
|
||||||
Convert stream full response.
|
Convert stream full response.
|
||||||
@ -86,7 +86,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_stream_simple_response(
|
def convert_stream_simple_response(
|
||||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
cls, stream_response: Iterator[AppStreamResponse]
|
||||||
) -> Generator[dict | str, None, None]:
|
) -> Generator[dict | str, None, None]:
|
||||||
"""
|
"""
|
||||||
Convert stream simple response.
|
Convert stream simple response.
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Iterator, Mapping
|
||||||
from typing import Any, Union
|
from typing import Any
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
|
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
|
||||||
@ -16,24 +16,26 @@ class AppGenerateResponseConverter(ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert(
|
def convert(
|
||||||
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
|
cls, response: AppBlockingResponse | Iterator[AppStreamResponse], invoke_from: InvokeFrom
|
||||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||||
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
|
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
|
||||||
if isinstance(response, AppBlockingResponse):
|
if isinstance(response, AppBlockingResponse):
|
||||||
return cls.convert_blocking_full_response(response)
|
return cls.convert_blocking_full_response(response)
|
||||||
else:
|
else:
|
||||||
|
stream_response = response
|
||||||
|
|
||||||
def _generate_full_response() -> Generator[dict | str, Any, None]:
|
def _generate_full_response() -> Generator[dict[str, Any] | str, None, None]:
|
||||||
yield from cls.convert_stream_full_response(response)
|
yield from cls.convert_stream_full_response(stream_response)
|
||||||
|
|
||||||
return _generate_full_response()
|
return _generate_full_response()
|
||||||
else:
|
else:
|
||||||
if isinstance(response, AppBlockingResponse):
|
if isinstance(response, AppBlockingResponse):
|
||||||
return cls.convert_blocking_simple_response(response)
|
return cls.convert_blocking_simple_response(response)
|
||||||
else:
|
else:
|
||||||
|
stream_response = response
|
||||||
|
|
||||||
def _generate_simple_response() -> Generator[dict | str, Any, None]:
|
def _generate_simple_response() -> Generator[dict[str, Any] | str, None, None]:
|
||||||
yield from cls.convert_stream_simple_response(response)
|
yield from cls.convert_stream_simple_response(stream_response)
|
||||||
|
|
||||||
return _generate_simple_response()
|
return _generate_simple_response()
|
||||||
|
|
||||||
@ -50,14 +52,14 @@ class AppGenerateResponseConverter(ABC):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def convert_stream_full_response(
|
def convert_stream_full_response(
|
||||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
cls, stream_response: Iterator[AppStreamResponse]
|
||||||
) -> Generator[dict | str, None, None]:
|
) -> Generator[dict | str, None, None]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def convert_stream_simple_response(
|
def convert_stream_simple_response(
|
||||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
cls, stream_response: Iterator[AppStreamResponse]
|
||||||
) -> Generator[dict | str, None, None]:
|
) -> Generator[dict | str, None, None]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|||||||
@ -224,6 +224,7 @@ class BaseAppGenerator:
|
|||||||
def _get_draft_var_saver_factory(invoke_from: InvokeFrom, account: Account | EndUser) -> DraftVariableSaverFactory:
|
def _get_draft_var_saver_factory(invoke_from: InvokeFrom, account: Account | EndUser) -> DraftVariableSaverFactory:
|
||||||
if invoke_from == InvokeFrom.DEBUGGER:
|
if invoke_from == InvokeFrom.DEBUGGER:
|
||||||
assert isinstance(account, Account)
|
assert isinstance(account, Account)
|
||||||
|
debug_account = account
|
||||||
|
|
||||||
def draft_var_saver_factory(
|
def draft_var_saver_factory(
|
||||||
session: Session,
|
session: Session,
|
||||||
@ -240,7 +241,7 @@ class BaseAppGenerator:
|
|||||||
node_type=node_type,
|
node_type=node_type,
|
||||||
node_execution_id=node_execution_id,
|
node_execution_id=node_execution_id,
|
||||||
enclosing_node_id=enclosing_node_id,
|
enclosing_node_id=enclosing_node_id,
|
||||||
user=account,
|
user=debug_account,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
|
|||||||
@ -166,15 +166,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
|
|
||||||
# init generate records
|
# init generate records
|
||||||
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
|
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
|
||||||
|
if conversation is None or message is None:
|
||||||
|
raise RuntimeError("_init_generate_records() returned None for conversation or message")
|
||||||
|
generated_conversation_id = str(conversation.id)
|
||||||
|
generated_message_id = str(message.id)
|
||||||
|
|
||||||
# init queue manager
|
# init queue manager
|
||||||
queue_manager = MessageBasedAppQueueManager(
|
queue_manager = MessageBasedAppQueueManager(
|
||||||
task_id=application_generate_entity.task_id,
|
task_id=application_generate_entity.task_id,
|
||||||
user_id=application_generate_entity.user_id,
|
user_id=application_generate_entity.user_id,
|
||||||
invoke_from=application_generate_entity.invoke_from,
|
invoke_from=application_generate_entity.invoke_from,
|
||||||
conversation_id=conversation.id,
|
conversation_id=generated_conversation_id,
|
||||||
app_mode=conversation.mode,
|
app_mode=conversation.mode,
|
||||||
message_id=message.id,
|
message_id=generated_message_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# new thread with request context
|
# new thread with request context
|
||||||
@ -184,8 +188,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
flask_app=current_app._get_current_object(), # type: ignore
|
flask_app=current_app._get_current_object(), # type: ignore
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
conversation_id=conversation.id,
|
conversation_id=generated_conversation_id,
|
||||||
message_id=message.id,
|
message_id=generated_message_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
worker_thread = threading.Thread(target=worker_with_context)
|
worker_thread = threading.Thread(target=worker_with_context)
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator, Iterator
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||||
@ -55,7 +55,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_stream_full_response(
|
def convert_stream_full_response(
|
||||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
cls, stream_response: Iterator[AppStreamResponse]
|
||||||
) -> Generator[dict | str, None, None]:
|
) -> Generator[dict | str, None, None]:
|
||||||
"""
|
"""
|
||||||
Convert stream full response.
|
Convert stream full response.
|
||||||
@ -86,7 +86,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_stream_simple_response(
|
def convert_stream_simple_response(
|
||||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
cls, stream_response: Iterator[AppStreamResponse]
|
||||||
) -> Generator[dict | str, None, None]:
|
) -> Generator[dict | str, None, None]:
|
||||||
"""
|
"""
|
||||||
Convert stream simple response.
|
Convert stream simple response.
|
||||||
|
|||||||
@ -149,6 +149,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
|
|
||||||
# init generate records
|
# init generate records
|
||||||
(conversation, message) = self._init_generate_records(application_generate_entity)
|
(conversation, message) = self._init_generate_records(application_generate_entity)
|
||||||
|
if conversation is None or message is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"_init_generate_records() returned None for conversation or message, "
|
||||||
|
"which is required to proceed with generation."
|
||||||
|
)
|
||||||
|
|
||||||
# init queue manager
|
# init queue manager
|
||||||
queue_manager = MessageBasedAppQueueManager(
|
queue_manager = MessageBasedAppQueueManager(
|
||||||
@ -312,15 +317,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
|
|
||||||
# init generate records
|
# init generate records
|
||||||
(conversation, message) = self._init_generate_records(application_generate_entity)
|
(conversation, message) = self._init_generate_records(application_generate_entity)
|
||||||
|
assert conversation is not None
|
||||||
|
assert message is not None
|
||||||
|
conversation_id = str(conversation.id)
|
||||||
|
message_id = str(message.id)
|
||||||
|
|
||||||
# init queue manager
|
# init queue manager
|
||||||
queue_manager = MessageBasedAppQueueManager(
|
queue_manager = MessageBasedAppQueueManager(
|
||||||
task_id=application_generate_entity.task_id,
|
task_id=application_generate_entity.task_id,
|
||||||
user_id=application_generate_entity.user_id,
|
user_id=application_generate_entity.user_id,
|
||||||
invoke_from=application_generate_entity.invoke_from,
|
invoke_from=application_generate_entity.invoke_from,
|
||||||
conversation_id=conversation.id,
|
conversation_id=conversation_id,
|
||||||
app_mode=conversation.mode,
|
app_mode=conversation.mode,
|
||||||
message_id=message.id,
|
message_id=message_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# new thread with request context
|
# new thread with request context
|
||||||
@ -330,7 +339,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
flask_app=current_app._get_current_object(), # type: ignore
|
flask_app=current_app._get_current_object(), # type: ignore
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
message_id=message.id,
|
message_id=message_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
worker_thread = threading.Thread(target=worker_with_context)
|
worker_thread = threading.Thread(target=worker_with_context)
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator, Iterator
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||||
@ -54,7 +54,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_stream_full_response(
|
def convert_stream_full_response(
|
||||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
cls, stream_response: Iterator[AppStreamResponse]
|
||||||
) -> Generator[dict | str, None, None]:
|
) -> Generator[dict | str, None, None]:
|
||||||
"""
|
"""
|
||||||
Convert stream full response.
|
Convert stream full response.
|
||||||
@ -84,7 +84,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_stream_simple_response(
|
def convert_stream_simple_response(
|
||||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
cls, stream_response: Iterator[AppStreamResponse]
|
||||||
) -> Generator[dict | str, None, None]:
|
) -> Generator[dict | str, None, None]:
|
||||||
"""
|
"""
|
||||||
Convert stream simple response.
|
Convert stream simple response.
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator, Iterator
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||||
@ -36,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_stream_full_response(
|
def convert_stream_full_response(
|
||||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
cls, stream_response: Iterator[AppStreamResponse]
|
||||||
) -> Generator[dict | str, None, None]:
|
) -> Generator[dict | str, None, None]:
|
||||||
"""
|
"""
|
||||||
Convert stream full response.
|
Convert stream full response.
|
||||||
@ -65,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_stream_simple_response(
|
def convert_stream_simple_response(
|
||||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
cls, stream_response: Iterator[AppStreamResponse]
|
||||||
) -> Generator[dict | str, None, None]:
|
) -> Generator[dict | str, None, None]:
|
||||||
"""
|
"""
|
||||||
Convert stream simple response.
|
Convert stream simple response.
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator, Iterator
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||||
@ -36,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_stream_full_response(
|
def convert_stream_full_response(
|
||||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
cls, stream_response: Iterator[AppStreamResponse]
|
||||||
) -> Generator[dict | str, None, None]:
|
) -> Generator[dict | str, None, None]:
|
||||||
"""
|
"""
|
||||||
Convert stream full response.
|
Convert stream full response.
|
||||||
@ -65,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_stream_simple_response(
|
def convert_stream_simple_response(
|
||||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
cls, stream_response: Iterator[AppStreamResponse]
|
||||||
) -> Generator[dict | str, None, None]:
|
) -> Generator[dict | str, None, None]:
|
||||||
"""
|
"""
|
||||||
Convert stream simple response.
|
Convert stream simple response.
|
||||||
|
|||||||
@ -1,13 +1,17 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any, cast
|
from typing import Protocol, TypeAlias
|
||||||
|
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
|
from core.app.entities.app_invoke_entities import (
|
||||||
|
InvokeFrom,
|
||||||
|
UserFrom,
|
||||||
|
build_dify_run_context,
|
||||||
|
)
|
||||||
from core.app.entities.queue_entities import (
|
from core.app.entities.queue_entities import (
|
||||||
AppQueueEvent,
|
AppQueueEvent,
|
||||||
QueueAgentLogEvent,
|
QueueAgentLogEvent,
|
||||||
@ -36,7 +40,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
|||||||
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
|
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
from dify_graph.entities import GraphInitParams
|
from dify_graph.entities import GraphInitParams
|
||||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||||
from dify_graph.graph import Graph
|
from dify_graph.graph import Graph
|
||||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||||
@ -75,6 +79,14 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
GraphConfigObject: TypeAlias = dict[str, object]
|
||||||
|
GraphConfigMapping: TypeAlias = Mapping[str, object]
|
||||||
|
|
||||||
|
|
||||||
|
class SingleNodeRunEntity(Protocol):
|
||||||
|
node_id: str
|
||||||
|
inputs: Mapping[str, object]
|
||||||
|
|
||||||
|
|
||||||
class WorkflowBasedAppRunner:
|
class WorkflowBasedAppRunner:
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -98,7 +110,7 @@ class WorkflowBasedAppRunner:
|
|||||||
|
|
||||||
def _init_graph(
|
def _init_graph(
|
||||||
self,
|
self,
|
||||||
graph_config: Mapping[str, Any],
|
graph_config: GraphConfigMapping,
|
||||||
graph_runtime_state: GraphRuntimeState,
|
graph_runtime_state: GraphRuntimeState,
|
||||||
user_from: UserFrom,
|
user_from: UserFrom,
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
@ -154,8 +166,8 @@ class WorkflowBasedAppRunner:
|
|||||||
def _prepare_single_node_execution(
|
def _prepare_single_node_execution(
|
||||||
self,
|
self,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
single_iteration_run: Any | None = None,
|
single_iteration_run: SingleNodeRunEntity | None = None,
|
||||||
single_loop_run: Any | None = None,
|
single_loop_run: SingleNodeRunEntity | None = None,
|
||||||
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
|
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
|
||||||
"""
|
"""
|
||||||
Prepare graph, variable pool, and runtime state for single node execution
|
Prepare graph, variable pool, and runtime state for single node execution
|
||||||
@ -208,11 +220,88 @@ class WorkflowBasedAppRunner:
|
|||||||
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
|
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
|
||||||
return graph, variable_pool, graph_runtime_state
|
return graph, variable_pool, graph_runtime_state
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_graph_items(graph_config: GraphConfigMapping) -> tuple[list[GraphConfigMapping], list[GraphConfigMapping]]:
|
||||||
|
nodes = graph_config.get("nodes")
|
||||||
|
edges = graph_config.get("edges")
|
||||||
|
if not isinstance(nodes, list):
|
||||||
|
raise ValueError("nodes in workflow graph must be a list")
|
||||||
|
if not isinstance(edges, list):
|
||||||
|
raise ValueError("edges in workflow graph must be a list")
|
||||||
|
|
||||||
|
validated_nodes: list[GraphConfigMapping] = []
|
||||||
|
for node in nodes:
|
||||||
|
if not isinstance(node, Mapping):
|
||||||
|
raise ValueError("nodes in workflow graph must be mappings")
|
||||||
|
validated_nodes.append(node)
|
||||||
|
|
||||||
|
validated_edges: list[GraphConfigMapping] = []
|
||||||
|
for edge in edges:
|
||||||
|
if not isinstance(edge, Mapping):
|
||||||
|
raise ValueError("edges in workflow graph must be mappings")
|
||||||
|
validated_edges.append(edge)
|
||||||
|
|
||||||
|
return validated_nodes, validated_edges
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_start_node_id(node_config: GraphConfigMapping | None) -> str | None:
|
||||||
|
if node_config is None:
|
||||||
|
return None
|
||||||
|
node_data = node_config.get("data")
|
||||||
|
if not isinstance(node_data, Mapping):
|
||||||
|
return None
|
||||||
|
start_node_id = node_data.get("start_node_id")
|
||||||
|
return start_node_id if isinstance(start_node_id, str) else None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _build_single_node_graph_config(
|
||||||
|
cls,
|
||||||
|
*,
|
||||||
|
graph_config: GraphConfigMapping,
|
||||||
|
node_id: str,
|
||||||
|
node_type_filter_key: str,
|
||||||
|
) -> tuple[GraphConfigObject, NodeConfigDict]:
|
||||||
|
node_configs, edge_configs = cls._get_graph_items(graph_config)
|
||||||
|
main_node_config = next((node for node in node_configs if node.get("id") == node_id), None)
|
||||||
|
start_node_id = cls._extract_start_node_id(main_node_config)
|
||||||
|
|
||||||
|
filtered_node_configs = [
|
||||||
|
dict(node)
|
||||||
|
for node in node_configs
|
||||||
|
if node.get("id") == node_id
|
||||||
|
or (isinstance(node_data := node.get("data"), Mapping) and node_data.get(node_type_filter_key) == node_id)
|
||||||
|
or (start_node_id and node.get("id") == start_node_id)
|
||||||
|
]
|
||||||
|
if not filtered_node_configs:
|
||||||
|
raise ValueError(f"node id {node_id} not found in workflow graph")
|
||||||
|
|
||||||
|
filtered_node_ids = {
|
||||||
|
str(node_id_value) for node in filtered_node_configs if isinstance((node_id_value := node.get("id")), str)
|
||||||
|
}
|
||||||
|
filtered_edge_configs = [
|
||||||
|
dict(edge)
|
||||||
|
for edge in edge_configs
|
||||||
|
if (edge.get("source") is None or edge.get("source") in filtered_node_ids)
|
||||||
|
and (edge.get("target") is None or edge.get("target") in filtered_node_ids)
|
||||||
|
]
|
||||||
|
|
||||||
|
target_node_config = next((node for node in filtered_node_configs if node.get("id") == node_id), None)
|
||||||
|
if target_node_config is None:
|
||||||
|
raise ValueError(f"node id {node_id} not found in workflow graph")
|
||||||
|
|
||||||
|
return (
|
||||||
|
{
|
||||||
|
"nodes": filtered_node_configs,
|
||||||
|
"edges": filtered_edge_configs,
|
||||||
|
},
|
||||||
|
NodeConfigDictAdapter.validate_python(target_node_config),
|
||||||
|
)
|
||||||
|
|
||||||
def _get_graph_and_variable_pool_for_single_node_run(
|
def _get_graph_and_variable_pool_for_single_node_run(
|
||||||
self,
|
self,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
user_inputs: dict[str, Any],
|
user_inputs: Mapping[str, object],
|
||||||
graph_runtime_state: GraphRuntimeState,
|
graph_runtime_state: GraphRuntimeState,
|
||||||
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
|
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
|
||||||
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
|
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
|
||||||
@ -236,41 +325,14 @@ class WorkflowBasedAppRunner:
|
|||||||
if not graph_config:
|
if not graph_config:
|
||||||
raise ValueError("workflow graph not found")
|
raise ValueError("workflow graph not found")
|
||||||
|
|
||||||
graph_config = cast(dict[str, Any], graph_config)
|
|
||||||
|
|
||||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||||
raise ValueError("nodes or edges not found in workflow graph")
|
raise ValueError("nodes or edges not found in workflow graph")
|
||||||
|
|
||||||
if not isinstance(graph_config.get("nodes"), list):
|
graph_config, target_node_config = self._build_single_node_graph_config(
|
||||||
raise ValueError("nodes in workflow graph must be a list")
|
graph_config=graph_config,
|
||||||
|
node_id=node_id,
|
||||||
if not isinstance(graph_config.get("edges"), list):
|
node_type_filter_key=node_type_filter_key,
|
||||||
raise ValueError("edges in workflow graph must be a list")
|
)
|
||||||
|
|
||||||
# filter nodes only in the specified node type (iteration or loop)
|
|
||||||
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
|
|
||||||
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
|
|
||||||
node_configs = [
|
|
||||||
node
|
|
||||||
for node in graph_config.get("nodes", [])
|
|
||||||
if node.get("id") == node_id
|
|
||||||
or node.get("data", {}).get(node_type_filter_key, "") == node_id
|
|
||||||
or (start_node_id and node.get("id") == start_node_id)
|
|
||||||
]
|
|
||||||
|
|
||||||
graph_config["nodes"] = node_configs
|
|
||||||
|
|
||||||
node_ids = [node.get("id") for node in node_configs]
|
|
||||||
|
|
||||||
# filter edges only in the specified node type
|
|
||||||
edge_configs = [
|
|
||||||
edge
|
|
||||||
for edge in graph_config.get("edges", [])
|
|
||||||
if (edge.get("source") is None or edge.get("source") in node_ids)
|
|
||||||
and (edge.get("target") is None or edge.get("target") in node_ids)
|
|
||||||
]
|
|
||||||
|
|
||||||
graph_config["edges"] = edge_configs
|
|
||||||
|
|
||||||
# Create required parameters for Graph.init
|
# Create required parameters for Graph.init
|
||||||
graph_init_params = GraphInitParams(
|
graph_init_params = GraphInitParams(
|
||||||
@ -299,18 +361,6 @@ class WorkflowBasedAppRunner:
|
|||||||
if not graph:
|
if not graph:
|
||||||
raise ValueError("graph not found in workflow")
|
raise ValueError("graph not found in workflow")
|
||||||
|
|
||||||
# fetch node config from node id
|
|
||||||
target_node_config = None
|
|
||||||
for node in node_configs:
|
|
||||||
if node.get("id") == node_id:
|
|
||||||
target_node_config = node
|
|
||||||
break
|
|
||||||
|
|
||||||
if not target_node_config:
|
|
||||||
raise ValueError(f"{node_type_label} node id not found in workflow graph")
|
|
||||||
|
|
||||||
target_node_config = NodeConfigDictAdapter.validate_python(target_node_config)
|
|
||||||
|
|
||||||
# Get node class
|
# Get node class
|
||||||
node_type = target_node_config["data"].type
|
node_type = target_node_config["data"].type
|
||||||
node_version = str(target_node_config["data"].version)
|
node_version = str(target_node_config["data"].version)
|
||||||
|
|||||||
@ -213,7 +213,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
node_id: str
|
node_id: str
|
||||||
inputs: Mapping
|
inputs: Mapping[str, object]
|
||||||
|
|
||||||
single_iteration_run: SingleIterationRunEntity | None = None
|
single_iteration_run: SingleIterationRunEntity | None = None
|
||||||
|
|
||||||
@ -223,7 +223,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
node_id: str
|
node_id: str
|
||||||
inputs: Mapping
|
inputs: Mapping[str, object]
|
||||||
|
|
||||||
single_loop_run: SingleLoopRunEntity | None = None
|
single_loop_run: SingleLoopRunEntity | None = None
|
||||||
|
|
||||||
@ -243,7 +243,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
node_id: str
|
node_id: str
|
||||||
inputs: dict
|
inputs: Mapping[str, object]
|
||||||
|
|
||||||
single_iteration_run: SingleIterationRunEntity | None = None
|
single_iteration_run: SingleIterationRunEntity | None = None
|
||||||
|
|
||||||
@ -253,7 +253,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
node_id: str
|
node_id: str
|
||||||
inputs: dict
|
inputs: Mapping[str, object]
|
||||||
|
|
||||||
single_loop_run: SingleLoopRunEntity | None = None
|
single_loop_run: SingleLoopRunEntity | None = None
|
||||||
|
|
||||||
|
|||||||
@ -1045,9 +1045,10 @@ class ToolManager:
|
|||||||
continue
|
continue
|
||||||
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
|
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
|
||||||
if tool_input.type == "variable":
|
if tool_input.type == "variable":
|
||||||
variable = variable_pool.get(tool_input.value)
|
variable_selector = tool_input.require_variable_selector()
|
||||||
|
variable = variable_pool.get(variable_selector)
|
||||||
if variable is None:
|
if variable is None:
|
||||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
raise ToolParameterError(f"Variable {variable_selector} does not exist")
|
||||||
parameter_value = variable.value
|
parameter_value = variable.value
|
||||||
elif tool_input.type == "constant":
|
elif tool_input.type == "constant":
|
||||||
parameter_value = tool_input.value
|
parameter_value = tool_input.value
|
||||||
|
|||||||
@ -1,13 +1,24 @@
|
|||||||
from enum import IntEnum, StrEnum, auto
|
from __future__ import annotations
|
||||||
from typing import Any, Literal, Union
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from enum import IntEnum, StrEnum, auto
|
||||||
|
from typing import Literal, TypeAlias
|
||||||
|
|
||||||
|
from pydantic import BaseModel, TypeAdapter, field_validator
|
||||||
|
from pydantic_core.core_schema import ValidationInfo
|
||||||
|
|
||||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||||
from core.tools.entities.tool_entities import ToolSelector
|
from core.tools.entities.tool_entities import ToolSelector
|
||||||
from dify_graph.entities.base_node_data import BaseNodeData
|
from dify_graph.entities.base_node_data import BaseNodeData
|
||||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||||
|
|
||||||
|
AgentInputConstantValue: TypeAlias = (
|
||||||
|
list[ToolSelector] | str | int | float | bool | dict[str, object] | list[object] | None
|
||||||
|
)
|
||||||
|
VariableSelector: TypeAlias = list[str]
|
||||||
|
|
||||||
|
_AGENT_INPUT_VALUE_ADAPTER: TypeAdapter[AgentInputConstantValue] = TypeAdapter(AgentInputConstantValue)
|
||||||
|
_AGENT_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
|
||||||
|
|
||||||
|
|
||||||
class AgentNodeData(BaseNodeData):
|
class AgentNodeData(BaseNodeData):
|
||||||
type: NodeType = BuiltinNodeTypes.AGENT
|
type: NodeType = BuiltinNodeTypes.AGENT
|
||||||
@ -21,8 +32,20 @@ class AgentNodeData(BaseNodeData):
|
|||||||
tool_node_version: str | None = None
|
tool_node_version: str | None = None
|
||||||
|
|
||||||
class AgentInput(BaseModel):
|
class AgentInput(BaseModel):
|
||||||
value: Union[list[str], list[ToolSelector], Any]
|
|
||||||
type: Literal["mixed", "variable", "constant"]
|
type: Literal["mixed", "variable", "constant"]
|
||||||
|
value: AgentInputConstantValue | VariableSelector
|
||||||
|
|
||||||
|
@field_validator("value", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_value(
|
||||||
|
cls, value: object, validation_info: ValidationInfo
|
||||||
|
) -> AgentInputConstantValue | VariableSelector:
|
||||||
|
input_type = validation_info.data.get("type")
|
||||||
|
if input_type == "variable":
|
||||||
|
return _AGENT_VARIABLE_SELECTOR_ADAPTER.validate_python(value)
|
||||||
|
if input_type in {"mixed", "constant"}:
|
||||||
|
return _AGENT_INPUT_VALUE_ADAPTER.validate_python(value)
|
||||||
|
raise ValueError(f"Unknown agent input type: {input_type}")
|
||||||
|
|
||||||
agent_parameters: dict[str, AgentInput]
|
agent_parameters: dict[str, AgentInput]
|
||||||
|
|
||||||
|
|||||||
@ -1,16 +1,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from collections.abc import Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any, cast
|
from typing import TypeAlias
|
||||||
|
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
from pydantic import ValidationError
|
from pydantic import TypeAdapter, ValidationError
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from core.agent.entities import AgentToolEntity
|
from core.agent.entities import AgentToolEntity
|
||||||
from core.agent.plugin_entities import AgentStrategyParameter
|
from core.agent.plugin_entities import AgentStrategyParameter
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_manager import ModelInstance, ModelManager
|
from core.model_manager import ModelInstance, ModelManager
|
||||||
from core.plugin.entities.request import InvokeCredentials
|
from core.plugin.entities.request import InvokeCredentials
|
||||||
@ -28,6 +29,14 @@ from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGen
|
|||||||
from .exceptions import AgentInputTypeError, AgentVariableNotFoundError
|
from .exceptions import AgentInputTypeError, AgentVariableNotFoundError
|
||||||
from .strategy_protocols import ResolvedAgentStrategy
|
from .strategy_protocols import ResolvedAgentStrategy
|
||||||
|
|
||||||
|
JsonObject: TypeAlias = dict[str, object]
|
||||||
|
JsonObjectList: TypeAlias = list[JsonObject]
|
||||||
|
VariableSelector: TypeAlias = list[str]
|
||||||
|
|
||||||
|
_JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
|
||||||
|
_JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
|
||||||
|
_VARIABLE_SELECTOR_ADAPTER = TypeAdapter(VariableSelector)
|
||||||
|
|
||||||
|
|
||||||
class AgentRuntimeSupport:
|
class AgentRuntimeSupport:
|
||||||
def build_parameters(
|
def build_parameters(
|
||||||
@ -39,12 +48,12 @@ class AgentRuntimeSupport:
|
|||||||
strategy: ResolvedAgentStrategy,
|
strategy: ResolvedAgentStrategy,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
app_id: str,
|
app_id: str,
|
||||||
invoke_from: Any,
|
invoke_from: InvokeFrom,
|
||||||
for_log: bool = False,
|
for_log: bool = False,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, object]:
|
||||||
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
|
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
|
||||||
|
|
||||||
result: dict[str, Any] = {}
|
result: dict[str, object] = {}
|
||||||
for parameter_name in node_data.agent_parameters:
|
for parameter_name in node_data.agent_parameters:
|
||||||
parameter = agent_parameters_dictionary.get(parameter_name)
|
parameter = agent_parameters_dictionary.get(parameter_name)
|
||||||
if not parameter:
|
if not parameter:
|
||||||
@ -54,9 +63,10 @@ class AgentRuntimeSupport:
|
|||||||
agent_input = node_data.agent_parameters[parameter_name]
|
agent_input = node_data.agent_parameters[parameter_name]
|
||||||
match agent_input.type:
|
match agent_input.type:
|
||||||
case "variable":
|
case "variable":
|
||||||
variable = variable_pool.get(agent_input.value) # type: ignore[arg-type]
|
variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(agent_input.value)
|
||||||
|
variable = variable_pool.get(variable_selector)
|
||||||
if variable is None:
|
if variable is None:
|
||||||
raise AgentVariableNotFoundError(str(agent_input.value))
|
raise AgentVariableNotFoundError(str(variable_selector))
|
||||||
parameter_value = variable.value
|
parameter_value = variable.value
|
||||||
case "mixed" | "constant":
|
case "mixed" | "constant":
|
||||||
try:
|
try:
|
||||||
@ -79,60 +89,38 @@ class AgentRuntimeSupport:
|
|||||||
|
|
||||||
value = parameter_value
|
value = parameter_value
|
||||||
if parameter.type == "array[tools]":
|
if parameter.type == "array[tools]":
|
||||||
value = cast(list[dict[str, Any]], value)
|
tool_payloads = _JSON_OBJECT_LIST_ADAPTER.validate_python(value)
|
||||||
value = [tool for tool in value if tool.get("enabled", False)]
|
value = self._normalize_tool_payloads(
|
||||||
value = self._filter_mcp_type_tool(strategy, value)
|
strategy=strategy,
|
||||||
for tool in value:
|
tools=tool_payloads,
|
||||||
if "schemas" in tool:
|
variable_pool=variable_pool,
|
||||||
tool.pop("schemas")
|
)
|
||||||
parameters = tool.get("parameters", {})
|
|
||||||
if all(isinstance(v, dict) for _, v in parameters.items()):
|
|
||||||
params = {}
|
|
||||||
for key, param in parameters.items():
|
|
||||||
if param.get("auto", ParamsAutoGenerated.OPEN) in (
|
|
||||||
ParamsAutoGenerated.CLOSE,
|
|
||||||
0,
|
|
||||||
):
|
|
||||||
value_param = param.get("value", {})
|
|
||||||
if value_param and value_param.get("type", "") == "variable":
|
|
||||||
variable_selector = value_param.get("value")
|
|
||||||
if not variable_selector:
|
|
||||||
raise ValueError("Variable selector is missing for a variable-type parameter.")
|
|
||||||
|
|
||||||
variable = variable_pool.get(variable_selector)
|
|
||||||
if variable is None:
|
|
||||||
raise AgentVariableNotFoundError(str(variable_selector))
|
|
||||||
|
|
||||||
params[key] = variable.value
|
|
||||||
else:
|
|
||||||
params[key] = value_param.get("value", "") if value_param is not None else None
|
|
||||||
else:
|
|
||||||
params[key] = None
|
|
||||||
parameters = params
|
|
||||||
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
|
|
||||||
tool["parameters"] = parameters
|
|
||||||
|
|
||||||
if not for_log:
|
if not for_log:
|
||||||
if parameter.type == "array[tools]":
|
if parameter.type == "array[tools]":
|
||||||
value = cast(list[dict[str, Any]], value)
|
value = _JSON_OBJECT_LIST_ADAPTER.validate_python(value)
|
||||||
tool_value = []
|
tool_value = []
|
||||||
for tool in value:
|
for tool in value:
|
||||||
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
|
provider_type = self._coerce_tool_provider_type(tool.get("type"))
|
||||||
setting_params = tool.get("settings", {})
|
setting_params = self._coerce_json_object(tool.get("settings")) or {}
|
||||||
parameters = tool.get("parameters", {})
|
parameters = self._coerce_json_object(tool.get("parameters")) or {}
|
||||||
manual_input_params = [key for key, value in parameters.items() if value is not None]
|
manual_input_params = [key for key, value in parameters.items() if value is not None]
|
||||||
|
|
||||||
parameters = {**parameters, **setting_params}
|
parameters = {**parameters, **setting_params}
|
||||||
|
provider_id = self._coerce_optional_string(tool.get("provider_name")) or ""
|
||||||
|
tool_name = self._coerce_optional_string(tool.get("tool_name")) or ""
|
||||||
|
plugin_unique_identifier = self._coerce_optional_string(tool.get("plugin_unique_identifier"))
|
||||||
|
credential_id = self._coerce_optional_string(tool.get("credential_id"))
|
||||||
entity = AgentToolEntity(
|
entity = AgentToolEntity(
|
||||||
provider_id=tool.get("provider_name", ""),
|
provider_id=provider_id,
|
||||||
provider_type=provider_type,
|
provider_type=provider_type,
|
||||||
tool_name=tool.get("tool_name", ""),
|
tool_name=tool_name,
|
||||||
tool_parameters=parameters,
|
tool_parameters=parameters,
|
||||||
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
|
plugin_unique_identifier=plugin_unique_identifier,
|
||||||
credential_id=tool.get("credential_id", None),
|
credential_id=credential_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
extra = tool.get("extra", {})
|
extra = self._coerce_json_object(tool.get("extra")) or {}
|
||||||
|
|
||||||
runtime_variable_pool: VariablePool | None = None
|
runtime_variable_pool: VariablePool | None = None
|
||||||
if node_data.version != "1" or node_data.tool_node_version is not None:
|
if node_data.version != "1" or node_data.tool_node_version is not None:
|
||||||
@ -145,8 +133,9 @@ class AgentRuntimeSupport:
|
|||||||
runtime_variable_pool,
|
runtime_variable_pool,
|
||||||
)
|
)
|
||||||
if tool_runtime.entity.description:
|
if tool_runtime.entity.description:
|
||||||
|
description_override = self._coerce_optional_string(extra.get("description"))
|
||||||
tool_runtime.entity.description.llm = (
|
tool_runtime.entity.description.llm = (
|
||||||
extra.get("description", "") or tool_runtime.entity.description.llm
|
description_override or tool_runtime.entity.description.llm
|
||||||
)
|
)
|
||||||
for tool_runtime_params in tool_runtime.entity.parameters:
|
for tool_runtime_params in tool_runtime.entity.parameters:
|
||||||
tool_runtime_params.form = (
|
tool_runtime_params.form = (
|
||||||
@ -167,13 +156,13 @@ class AgentRuntimeSupport:
|
|||||||
{
|
{
|
||||||
**tool_runtime.entity.model_dump(mode="json"),
|
**tool_runtime.entity.model_dump(mode="json"),
|
||||||
"runtime_parameters": runtime_parameters,
|
"runtime_parameters": runtime_parameters,
|
||||||
"credential_id": tool.get("credential_id", None),
|
"credential_id": credential_id,
|
||||||
"provider_type": provider_type.value,
|
"provider_type": provider_type.value,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
value = tool_value
|
value = tool_value
|
||||||
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
|
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
|
||||||
value = cast(dict[str, Any], value)
|
value = _JSON_OBJECT_ADAPTER.validate_python(value)
|
||||||
model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
|
model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
|
||||||
history_prompt_messages = []
|
history_prompt_messages = []
|
||||||
if node_data.memory:
|
if node_data.memory:
|
||||||
@ -199,17 +188,27 @@ class AgentRuntimeSupport:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def build_credentials(self, *, parameters: dict[str, Any]) -> InvokeCredentials:
|
def build_credentials(self, *, parameters: Mapping[str, object]) -> InvokeCredentials:
|
||||||
credentials = InvokeCredentials()
|
credentials = InvokeCredentials()
|
||||||
credentials.tool_credentials = {}
|
credentials.tool_credentials = {}
|
||||||
for tool in parameters.get("tools", []):
|
tools = parameters.get("tools")
|
||||||
|
if not isinstance(tools, list):
|
||||||
|
return credentials
|
||||||
|
|
||||||
|
for raw_tool in tools:
|
||||||
|
tool = self._coerce_json_object(raw_tool)
|
||||||
|
if tool is None:
|
||||||
|
continue
|
||||||
if not tool.get("credential_id"):
|
if not tool.get("credential_id"):
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
identity = ToolIdentity.model_validate(tool.get("identity", {}))
|
identity = ToolIdentity.model_validate(tool.get("identity", {}))
|
||||||
except ValidationError:
|
except ValidationError:
|
||||||
continue
|
continue
|
||||||
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
|
credential_id = self._coerce_optional_string(tool.get("credential_id"))
|
||||||
|
if credential_id is None:
|
||||||
|
continue
|
||||||
|
credentials.tool_credentials[identity.provider] = credential_id
|
||||||
return credentials
|
return credentials
|
||||||
|
|
||||||
def fetch_memory(
|
def fetch_memory(
|
||||||
@ -232,14 +231,14 @@ class AgentRuntimeSupport:
|
|||||||
|
|
||||||
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||||
|
|
||||||
def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
|
def fetch_model(self, *, tenant_id: str, value: Mapping[str, object]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||||
provider_manager = ProviderManager()
|
provider_manager = ProviderManager()
|
||||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider=value.get("provider", ""),
|
provider=str(value.get("provider", "")),
|
||||||
model_type=ModelType.LLM,
|
model_type=ModelType.LLM,
|
||||||
)
|
)
|
||||||
model_name = value.get("model", "")
|
model_name = str(value.get("model", ""))
|
||||||
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||||
model_type=ModelType.LLM,
|
model_type=ModelType.LLM,
|
||||||
model=model_name,
|
model=model_name,
|
||||||
@ -249,7 +248,7 @@ class AgentRuntimeSupport:
|
|||||||
model_instance = ModelManager().get_model_instance(
|
model_instance = ModelManager().get_model_instance(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider=provider_name,
|
provider=provider_name,
|
||||||
model_type=ModelType(value.get("model_type", "")),
|
model_type=ModelType(str(value.get("model_type", ""))),
|
||||||
model=model_name,
|
model=model_name,
|
||||||
)
|
)
|
||||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||||
@ -268,9 +267,88 @@ class AgentRuntimeSupport:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _filter_mcp_type_tool(
|
def _filter_mcp_type_tool(
|
||||||
strategy: ResolvedAgentStrategy,
|
strategy: ResolvedAgentStrategy,
|
||||||
tools: list[dict[str, Any]],
|
tools: JsonObjectList,
|
||||||
) -> list[dict[str, Any]]:
|
) -> JsonObjectList:
|
||||||
meta_version = strategy.meta_version
|
meta_version = strategy.meta_version
|
||||||
if meta_version and Version(meta_version) > Version("0.0.1"):
|
if meta_version and Version(meta_version) > Version("0.0.1"):
|
||||||
return tools
|
return tools
|
||||||
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
|
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
|
||||||
|
|
||||||
|
def _normalize_tool_payloads(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
strategy: ResolvedAgentStrategy,
|
||||||
|
tools: JsonObjectList,
|
||||||
|
variable_pool: VariablePool,
|
||||||
|
) -> JsonObjectList:
|
||||||
|
enabled_tools = [dict(tool) for tool in tools if bool(tool.get("enabled", False))]
|
||||||
|
normalized_tools = self._filter_mcp_type_tool(strategy, enabled_tools)
|
||||||
|
for tool in normalized_tools:
|
||||||
|
tool.pop("schemas", None)
|
||||||
|
tool["parameters"] = self._resolve_tool_parameters(tool=tool, variable_pool=variable_pool)
|
||||||
|
tool["settings"] = self._resolve_tool_settings(tool)
|
||||||
|
return normalized_tools
|
||||||
|
|
||||||
|
def _resolve_tool_parameters(self, *, tool: Mapping[str, object], variable_pool: VariablePool) -> JsonObject:
|
||||||
|
parameter_configs = self._coerce_named_json_objects(tool.get("parameters"))
|
||||||
|
if parameter_configs is None:
|
||||||
|
raw_parameters = self._coerce_json_object(tool.get("parameters"))
|
||||||
|
return raw_parameters or {}
|
||||||
|
|
||||||
|
resolved_parameters: JsonObject = {}
|
||||||
|
for key, parameter_config in parameter_configs.items():
|
||||||
|
if parameter_config.get("auto", ParamsAutoGenerated.OPEN) in (ParamsAutoGenerated.CLOSE, 0):
|
||||||
|
value_param = self._coerce_json_object(parameter_config.get("value"))
|
||||||
|
if value_param and value_param.get("type") == "variable":
|
||||||
|
variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(value_param.get("value"))
|
||||||
|
variable = variable_pool.get(variable_selector)
|
||||||
|
if variable is None:
|
||||||
|
raise AgentVariableNotFoundError(str(variable_selector))
|
||||||
|
resolved_parameters[key] = variable.value
|
||||||
|
else:
|
||||||
|
resolved_parameters[key] = value_param.get("value", "") if value_param is not None else None
|
||||||
|
else:
|
||||||
|
resolved_parameters[key] = None
|
||||||
|
|
||||||
|
return resolved_parameters
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_tool_settings(tool: Mapping[str, object]) -> JsonObject:
|
||||||
|
settings = AgentRuntimeSupport._coerce_named_json_objects(tool.get("settings"))
|
||||||
|
if settings is None:
|
||||||
|
return {}
|
||||||
|
return {key: setting.get("value") for key, setting in settings.items()}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _coerce_json_object(value: object) -> JsonObject | None:
|
||||||
|
try:
|
||||||
|
return _JSON_OBJECT_ADAPTER.validate_python(value)
|
||||||
|
except ValidationError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _coerce_optional_string(value: object) -> str | None:
|
||||||
|
return value if isinstance(value, str) else None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _coerce_tool_provider_type(value: object) -> ToolProviderType:
|
||||||
|
if isinstance(value, ToolProviderType):
|
||||||
|
return value
|
||||||
|
if isinstance(value, str):
|
||||||
|
return ToolProviderType(value)
|
||||||
|
return ToolProviderType.BUILT_IN
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _coerce_named_json_objects(cls, value: object) -> dict[str, JsonObject] | None:
|
||||||
|
if not isinstance(value, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
coerced: dict[str, JsonObject] = {}
|
||||||
|
for key, item in value.items():
|
||||||
|
if not isinstance(key, str):
|
||||||
|
return None
|
||||||
|
json_object = cls._coerce_json_object(item)
|
||||||
|
if json_object is None:
|
||||||
|
return None
|
||||||
|
coerced[key] = json_object
|
||||||
|
return coerced
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import Any, cast
|
from typing import Any, TypeAlias, cast
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.app.apps.exc import GenerateTaskStoppedError
|
from core.app.apps.exc import GenerateTaskStoppedError
|
||||||
@ -32,6 +32,13 @@ from models.workflow import Workflow
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
SpecialValueScalar: TypeAlias = str | int | float | bool | None
|
||||||
|
SpecialValue: TypeAlias = SpecialValueScalar | File | Mapping[str, "SpecialValue"] | list["SpecialValue"]
|
||||||
|
SerializedSpecialValue: TypeAlias = (
|
||||||
|
SpecialValueScalar | dict[str, "SerializedSpecialValue"] | list["SerializedSpecialValue"]
|
||||||
|
)
|
||||||
|
SingleNodeGraphConfig: TypeAlias = dict[str, list[dict[str, object]]]
|
||||||
|
|
||||||
|
|
||||||
class _WorkflowChildEngineBuilder:
|
class _WorkflowChildEngineBuilder:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -276,10 +283,10 @@ class WorkflowEntry:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_single_node_graph(
|
def _create_single_node_graph(
|
||||||
node_id: str,
|
node_id: str,
|
||||||
node_data: dict[str, Any],
|
node_data: Mapping[str, object],
|
||||||
node_width: int = 114,
|
node_width: int = 114,
|
||||||
node_height: int = 514,
|
node_height: int = 514,
|
||||||
) -> dict[str, Any]:
|
) -> SingleNodeGraphConfig:
|
||||||
"""
|
"""
|
||||||
Create a minimal graph structure for testing a single node in isolation.
|
Create a minimal graph structure for testing a single node in isolation.
|
||||||
|
|
||||||
@ -289,14 +296,14 @@ class WorkflowEntry:
|
|||||||
:param node_height: height for UI layout (default: 100)
|
:param node_height: height for UI layout (default: 100)
|
||||||
:return: graph dictionary with start node and target node
|
:return: graph dictionary with start node and target node
|
||||||
"""
|
"""
|
||||||
node_config = {
|
node_config: dict[str, object] = {
|
||||||
"id": node_id,
|
"id": node_id,
|
||||||
"width": node_width,
|
"width": node_width,
|
||||||
"height": node_height,
|
"height": node_height,
|
||||||
"type": "custom",
|
"type": "custom",
|
||||||
"data": node_data,
|
"data": dict(node_data),
|
||||||
}
|
}
|
||||||
start_node_config = {
|
start_node_config: dict[str, object] = {
|
||||||
"id": "start",
|
"id": "start",
|
||||||
"width": node_width,
|
"width": node_width,
|
||||||
"height": node_height,
|
"height": node_height,
|
||||||
@ -321,7 +328,12 @@ class WorkflowEntry:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def run_free_node(
|
def run_free_node(
|
||||||
cls, node_data: dict[str, Any], node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
|
cls,
|
||||||
|
node_data: Mapping[str, object],
|
||||||
|
node_id: str,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
user_inputs: Mapping[str, object],
|
||||||
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
|
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
|
||||||
"""
|
"""
|
||||||
Run free node
|
Run free node
|
||||||
@ -339,6 +351,8 @@ class WorkflowEntry:
|
|||||||
graph_dict = cls._create_single_node_graph(node_id, node_data)
|
graph_dict = cls._create_single_node_graph(node_id, node_data)
|
||||||
|
|
||||||
node_type = node_data.get("type", "")
|
node_type = node_data.get("type", "")
|
||||||
|
if not isinstance(node_type, str):
|
||||||
|
raise ValueError("Node type must be a string")
|
||||||
if node_type not in {BuiltinNodeTypes.PARAMETER_EXTRACTOR, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
|
if node_type not in {BuiltinNodeTypes.PARAMETER_EXTRACTOR, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
|
||||||
raise ValueError(f"Node type {node_type} not supported")
|
raise ValueError(f"Node type {node_type} not supported")
|
||||||
|
|
||||||
@ -369,7 +383,7 @@ class WorkflowEntry:
|
|||||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||||
|
|
||||||
# init workflow run state
|
# init workflow run state
|
||||||
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
|
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": dict(node_data)})
|
||||||
node_factory = DifyNodeFactory(
|
node_factory = DifyNodeFactory(
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
@ -405,30 +419,34 @@ class WorkflowEntry:
|
|||||||
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
|
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def handle_special_values(value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
|
def handle_special_values(value: Mapping[str, SpecialValue] | None) -> dict[str, SerializedSpecialValue] | None:
|
||||||
# NOTE(QuantumGhost): Avoid using this function in new code.
|
# NOTE(QuantumGhost): Avoid using this function in new code.
|
||||||
# Keep values structured as long as possible and only convert to dict
|
# Keep values structured as long as possible and only convert to dict
|
||||||
# immediately before serialization (e.g., JSON serialization) to maintain
|
# immediately before serialization (e.g., JSON serialization) to maintain
|
||||||
# data integrity and type information.
|
# data integrity and type information.
|
||||||
result = WorkflowEntry._handle_special_values(value)
|
result = WorkflowEntry._handle_special_values(value)
|
||||||
return result if isinstance(result, Mapping) or result is None else dict(result)
|
if result is None:
|
||||||
|
return None
|
||||||
|
if isinstance(result, dict):
|
||||||
|
return result
|
||||||
|
raise TypeError("handle_special_values expects a mapping input")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _handle_special_values(value: Any):
|
def _handle_special_values(value: SpecialValue) -> SerializedSpecialValue:
|
||||||
if value is None:
|
if value is None:
|
||||||
return value
|
return value
|
||||||
if isinstance(value, dict):
|
if isinstance(value, Mapping):
|
||||||
res = {}
|
res: dict[str, SerializedSpecialValue] = {}
|
||||||
for k, v in value.items():
|
for k, v in value.items():
|
||||||
res[k] = WorkflowEntry._handle_special_values(v)
|
res[k] = WorkflowEntry._handle_special_values(v)
|
||||||
return res
|
return res
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
res_list = []
|
res_list: list[SerializedSpecialValue] = []
|
||||||
for item in value:
|
for item in value:
|
||||||
res_list.append(WorkflowEntry._handle_special_values(item))
|
res_list.append(WorkflowEntry._handle_special_values(item))
|
||||||
return res_list
|
return res_list
|
||||||
if isinstance(value, File):
|
if isinstance(value, File):
|
||||||
return value.to_dict()
|
return dict(value.to_dict())
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -112,6 +112,8 @@ def _get_encoded_string(f: File, /) -> str:
|
|||||||
data = _download_file_content(f.storage_key)
|
data = _download_file_content(f.storage_key)
|
||||||
case FileTransferMethod.DATASOURCE_FILE:
|
case FileTransferMethod.DATASOURCE_FILE:
|
||||||
data = _download_file_content(f.storage_key)
|
data = _download_file_content(f.storage_key)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
|
||||||
|
|
||||||
return base64.b64encode(data).decode("utf-8")
|
return base64.b64encode(data).decode("utf-8")
|
||||||
|
|
||||||
|
|||||||
@ -133,6 +133,8 @@ class ExecutionLimitsLayer(GraphEngineLayer):
|
|||||||
elif limit_type == LimitType.TIME_LIMIT:
|
elif limit_type == LimitType.TIME_LIMIT:
|
||||||
elapsed_time = time.time() - self.start_time if self.start_time else 0
|
elapsed_time = time.time() - self.start_time if self.start_time else 0
|
||||||
reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s"
|
reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s"
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
self.logger.warning("Execution limit exceeded: %s", reason)
|
self.logger.warning("Execution limit exceeded: %s", reason)
|
||||||
|
|
||||||
|
|||||||
@ -336,12 +336,7 @@ class Node(Generic[NodeDataT]):
|
|||||||
|
|
||||||
def _restore_execution_id_from_runtime_state(self) -> str | None:
|
def _restore_execution_id_from_runtime_state(self) -> str | None:
|
||||||
graph_execution = self.graph_runtime_state.graph_execution
|
graph_execution = self.graph_runtime_state.graph_execution
|
||||||
try:
|
node_executions = graph_execution.node_executions
|
||||||
node_executions = graph_execution.node_executions
|
|
||||||
except AttributeError:
|
|
||||||
return None
|
|
||||||
if not isinstance(node_executions, dict):
|
|
||||||
return None
|
|
||||||
node_execution = node_executions.get(self._node_id)
|
node_execution = node_executions.get(self._node_id)
|
||||||
if node_execution is None:
|
if node_execution is None:
|
||||||
return None
|
return None
|
||||||
@ -395,8 +390,7 @@ class Node(Generic[NodeDataT]):
|
|||||||
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
|
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
|
||||||
yield self._dispatch(event)
|
yield self._dispatch(event)
|
||||||
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
|
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
|
||||||
event.id = self.execution_id
|
yield event.model_copy(update={"id": self.execution_id})
|
||||||
yield event
|
|
||||||
else:
|
else:
|
||||||
yield event
|
yield event
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@ -443,7 +443,10 @@ def _extract_text_from_docx(file_content: bytes) -> str:
|
|||||||
# Keep track of paragraph and table positions
|
# Keep track of paragraph and table positions
|
||||||
content_items: list[tuple[int, str, Table | Paragraph]] = []
|
content_items: list[tuple[int, str, Table | Paragraph]] = []
|
||||||
|
|
||||||
it = iter(doc.element.body)
|
doc_body = getattr(doc.element, "body", None)
|
||||||
|
if doc_body is None:
|
||||||
|
raise TextExtractionError("DOCX body not found")
|
||||||
|
it = iter(doc_body)
|
||||||
part = next(it, None)
|
part = next(it, None)
|
||||||
i = 0
|
i = 0
|
||||||
while part is not None:
|
while part is not None:
|
||||||
|
|||||||
@ -1,7 +1,8 @@
|
|||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any, Literal
|
from typing import Literal, NotRequired
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||||
from dify_graph.entities.base_node_data import BaseNodeData
|
from dify_graph.entities.base_node_data import BaseNodeData
|
||||||
@ -10,11 +11,17 @@ from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
|||||||
from dify_graph.nodes.base.entities import VariableSelector
|
from dify_graph.nodes.base.entities import VariableSelector
|
||||||
|
|
||||||
|
|
||||||
|
class StructuredOutputConfig(TypedDict):
|
||||||
|
schema: Mapping[str, object]
|
||||||
|
name: NotRequired[str]
|
||||||
|
description: NotRequired[str]
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig(BaseModel):
|
class ModelConfig(BaseModel):
|
||||||
provider: str
|
provider: str
|
||||||
name: str
|
name: str
|
||||||
mode: LLMMode
|
mode: LLMMode
|
||||||
completion_params: dict[str, Any] = Field(default_factory=dict)
|
completion_params: dict[str, object] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class ContextConfig(BaseModel):
|
class ContextConfig(BaseModel):
|
||||||
@ -33,7 +40,7 @@ class VisionConfig(BaseModel):
|
|||||||
|
|
||||||
@field_validator("configs", mode="before")
|
@field_validator("configs", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_none_configs(cls, v: Any):
|
def convert_none_configs(cls, v: object):
|
||||||
if v is None:
|
if v is None:
|
||||||
return VisionConfigOptions()
|
return VisionConfigOptions()
|
||||||
return v
|
return v
|
||||||
@ -44,7 +51,7 @@ class PromptConfig(BaseModel):
|
|||||||
|
|
||||||
@field_validator("jinja2_variables", mode="before")
|
@field_validator("jinja2_variables", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_none_jinja2_variables(cls, v: Any):
|
def convert_none_jinja2_variables(cls, v: object):
|
||||||
if v is None:
|
if v is None:
|
||||||
return []
|
return []
|
||||||
return v
|
return v
|
||||||
@ -67,7 +74,7 @@ class LLMNodeData(BaseNodeData):
|
|||||||
memory: MemoryConfig | None = None
|
memory: MemoryConfig | None = None
|
||||||
context: ContextConfig
|
context: ContextConfig
|
||||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||||
structured_output: Mapping[str, Any] | None = None
|
structured_output: StructuredOutputConfig | None = None
|
||||||
# We used 'structured_output_enabled' in the past, but it's not a good name.
|
# We used 'structured_output_enabled' in the past, but it's not a good name.
|
||||||
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
|
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
|
||||||
reasoning_format: Literal["separated", "tagged"] = Field(
|
reasoning_format: Literal["separated", "tagged"] = Field(
|
||||||
@ -90,11 +97,30 @@ class LLMNodeData(BaseNodeData):
|
|||||||
|
|
||||||
@field_validator("prompt_config", mode="before")
|
@field_validator("prompt_config", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_none_prompt_config(cls, v: Any):
|
def convert_none_prompt_config(cls, v: object):
|
||||||
if v is None:
|
if v is None:
|
||||||
return PromptConfig()
|
return PromptConfig()
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@field_validator("structured_output", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def convert_legacy_structured_output(cls, v: object) -> StructuredOutputConfig | None | object:
|
||||||
|
if not isinstance(v, Mapping):
|
||||||
|
return v
|
||||||
|
|
||||||
|
schema = v.get("schema")
|
||||||
|
if schema is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
normalized: StructuredOutputConfig = {"schema": schema}
|
||||||
|
name = v.get("name")
|
||||||
|
description = v.get("description")
|
||||||
|
if isinstance(name, str):
|
||||||
|
normalized["name"] = name
|
||||||
|
if isinstance(description, str):
|
||||||
|
normalized["description"] = description
|
||||||
|
return normalized
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def structured_output_enabled(self) -> bool:
|
def structured_output_enabled(self) -> bool:
|
||||||
return self.structured_output_switch_on and self.structured_output is not None
|
return self.structured_output_switch_on and self.structured_output is not None
|
||||||
|
|||||||
@ -9,6 +9,7 @@ import time
|
|||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import TYPE_CHECKING, Any, Literal
|
from typing import TYPE_CHECKING, Any, Literal
|
||||||
|
|
||||||
|
from pydantic import TypeAdapter
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from core.llm_generator.output_parser.errors import OutputParserError
|
from core.llm_generator.output_parser.errors import OutputParserError
|
||||||
@ -74,6 +75,7 @@ from .entities import (
|
|||||||
LLMNodeChatModelMessage,
|
LLMNodeChatModelMessage,
|
||||||
LLMNodeCompletionModelPromptTemplate,
|
LLMNodeCompletionModelPromptTemplate,
|
||||||
LLMNodeData,
|
LLMNodeData,
|
||||||
|
StructuredOutputConfig,
|
||||||
)
|
)
|
||||||
from .exc import (
|
from .exc import (
|
||||||
InvalidContextStructureError,
|
InvalidContextStructureError,
|
||||||
@ -88,6 +90,7 @@ if TYPE_CHECKING:
|
|||||||
from dify_graph.runtime import GraphRuntimeState
|
from dify_graph.runtime import GraphRuntimeState
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
_JSON_OBJECT_ADAPTER = TypeAdapter(dict[str, object])
|
||||||
|
|
||||||
|
|
||||||
class LLMNode(Node[LLMNodeData]):
|
class LLMNode(Node[LLMNodeData]):
|
||||||
@ -358,7 +361,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
stop: Sequence[str] | None = None,
|
stop: Sequence[str] | None = None,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
structured_output_enabled: bool,
|
structured_output_enabled: bool,
|
||||||
structured_output: Mapping[str, Any] | None = None,
|
structured_output: StructuredOutputConfig | None = None,
|
||||||
file_saver: LLMFileSaver,
|
file_saver: LLMFileSaver,
|
||||||
file_outputs: list[File],
|
file_outputs: list[File],
|
||||||
node_id: str,
|
node_id: str,
|
||||||
@ -371,8 +374,10 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
|
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
|
||||||
|
|
||||||
if structured_output_enabled:
|
if structured_output_enabled:
|
||||||
|
if structured_output is None:
|
||||||
|
raise LLMNodeError("Please provide a valid structured output schema")
|
||||||
output_schema = LLMNode.fetch_structured_output_schema(
|
output_schema = LLMNode.fetch_structured_output_schema(
|
||||||
structured_output=structured_output or {},
|
structured_output=structured_output,
|
||||||
)
|
)
|
||||||
request_start_time = time.perf_counter()
|
request_start_time = time.perf_counter()
|
||||||
|
|
||||||
@ -924,6 +929,12 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
# Extract clean text and reasoning from <think> tags
|
# Extract clean text and reasoning from <think> tags
|
||||||
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
|
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
|
||||||
|
|
||||||
|
structured_output = (
|
||||||
|
dict(invoke_result.structured_output)
|
||||||
|
if isinstance(invoke_result, LLMResultWithStructuredOutput) and invoke_result.structured_output is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
event = ModelInvokeCompletedEvent(
|
event = ModelInvokeCompletedEvent(
|
||||||
# Use clean_text for separated mode, full_text for tagged mode
|
# Use clean_text for separated mode, full_text for tagged mode
|
||||||
text=clean_text if reasoning_format == "separated" else full_text,
|
text=clean_text if reasoning_format == "separated" else full_text,
|
||||||
@ -932,7 +943,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
# Reasoning content for workflow variables and downstream nodes
|
# Reasoning content for workflow variables and downstream nodes
|
||||||
reasoning_content=reasoning_content,
|
reasoning_content=reasoning_content,
|
||||||
# Pass structured output if enabled
|
# Pass structured output if enabled
|
||||||
structured_output=getattr(invoke_result, "structured_output", None),
|
structured_output=structured_output,
|
||||||
)
|
)
|
||||||
if request_latency is not None:
|
if request_latency is not None:
|
||||||
event.usage.latency = round(request_latency, 3)
|
event.usage.latency = round(request_latency, 3)
|
||||||
@ -966,27 +977,18 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def fetch_structured_output_schema(
|
def fetch_structured_output_schema(
|
||||||
*,
|
*,
|
||||||
structured_output: Mapping[str, Any],
|
structured_output: StructuredOutputConfig,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, object]:
|
||||||
"""
|
"""
|
||||||
Fetch the structured output schema from the node data.
|
Fetch the structured output schema from the node data.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, Any]: The structured output schema
|
dict[str, object]: The structured output schema
|
||||||
"""
|
"""
|
||||||
if not structured_output:
|
schema = structured_output.get("schema")
|
||||||
|
if not schema:
|
||||||
raise LLMNodeError("Please provide a valid structured output schema")
|
raise LLMNodeError("Please provide a valid structured output schema")
|
||||||
structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
|
return _JSON_OBJECT_ADAPTER.validate_python(schema)
|
||||||
if not structured_output_schema:
|
|
||||||
raise LLMNodeError("Please provide a valid structured output schema")
|
|
||||||
|
|
||||||
try:
|
|
||||||
schema = json.loads(structured_output_schema)
|
|
||||||
if not isinstance(schema, dict):
|
|
||||||
raise LLMNodeError("structured_output_schema must be a JSON object")
|
|
||||||
return schema
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
raise LLMNodeError("structured_output_schema is not valid JSON format")
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _save_multimodal_output_and_convert_result_to_markdown(
|
def _save_multimodal_output_and_convert_result_to_markdown(
|
||||||
|
|||||||
@ -1,7 +1,10 @@
|
|||||||
from enum import StrEnum
|
from __future__ import annotations
|
||||||
from typing import Annotated, Any, Literal
|
|
||||||
|
|
||||||
from pydantic import AfterValidator, BaseModel, Field, field_validator
|
from enum import StrEnum
|
||||||
|
from typing import Annotated, Any, Literal, TypeAlias, cast
|
||||||
|
|
||||||
|
from pydantic import AfterValidator, BaseModel, Field, TypeAdapter, field_validator
|
||||||
|
from pydantic_core.core_schema import ValidationInfo
|
||||||
|
|
||||||
from dify_graph.entities.base_node_data import BaseNodeData
|
from dify_graph.entities.base_node_data import BaseNodeData
|
||||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||||
@ -9,6 +12,12 @@ from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState
|
|||||||
from dify_graph.utils.condition.entities import Condition
|
from dify_graph.utils.condition.entities import Condition
|
||||||
from dify_graph.variables.types import SegmentType
|
from dify_graph.variables.types import SegmentType
|
||||||
|
|
||||||
|
LoopValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
|
||||||
|
LoopValueMapping: TypeAlias = dict[str, LoopValue]
|
||||||
|
VariableSelector: TypeAlias = list[str]
|
||||||
|
|
||||||
|
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
|
||||||
|
|
||||||
_VALID_VAR_TYPE = frozenset(
|
_VALID_VAR_TYPE = frozenset(
|
||||||
[
|
[
|
||||||
SegmentType.STRING,
|
SegmentType.STRING,
|
||||||
@ -29,6 +38,36 @@ def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
|
|||||||
return seg_type
|
return seg_type
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_loop_value(value: object) -> LoopValue:
|
||||||
|
if value is None or isinstance(value, (str, int, float, bool)):
|
||||||
|
return cast(LoopValue, value)
|
||||||
|
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [_validate_loop_value(item) for item in value]
|
||||||
|
|
||||||
|
if isinstance(value, dict):
|
||||||
|
normalized: dict[str, LoopValue] = {}
|
||||||
|
for key, item in value.items():
|
||||||
|
if not isinstance(key, str):
|
||||||
|
raise TypeError("Loop values only support string object keys")
|
||||||
|
normalized[key] = _validate_loop_value(item)
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
raise TypeError("Loop values must be JSON-like primitives, arrays, or objects")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_loop_value_mapping(value: object) -> LoopValueMapping:
|
||||||
|
if not isinstance(value, dict):
|
||||||
|
raise TypeError("Loop outputs must be an object")
|
||||||
|
|
||||||
|
normalized: LoopValueMapping = {}
|
||||||
|
for key, item in value.items():
|
||||||
|
if not isinstance(key, str):
|
||||||
|
raise TypeError("Loop output keys must be strings")
|
||||||
|
normalized[key] = _validate_loop_value(item)
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
class LoopVariableData(BaseModel):
|
class LoopVariableData(BaseModel):
|
||||||
"""
|
"""
|
||||||
Loop Variable Data.
|
Loop Variable Data.
|
||||||
@ -37,7 +76,29 @@ class LoopVariableData(BaseModel):
|
|||||||
label: str
|
label: str
|
||||||
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
|
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
|
||||||
value_type: Literal["variable", "constant"]
|
value_type: Literal["variable", "constant"]
|
||||||
value: Any | list[str] | None = None
|
value: LoopValue | VariableSelector | None = None
|
||||||
|
|
||||||
|
@field_validator("value", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_value(cls, value: object, validation_info: ValidationInfo) -> LoopValue | VariableSelector | None:
|
||||||
|
value_type = validation_info.data.get("value_type")
|
||||||
|
if value_type == "variable":
|
||||||
|
if value is None:
|
||||||
|
raise ValueError("Variable loop inputs require a selector")
|
||||||
|
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
|
||||||
|
if value_type == "constant":
|
||||||
|
return _validate_loop_value(value)
|
||||||
|
raise ValueError(f"Unknown loop variable value type: {value_type}")
|
||||||
|
|
||||||
|
def require_variable_selector(self) -> VariableSelector:
|
||||||
|
if self.value_type != "variable":
|
||||||
|
raise ValueError(f"Expected variable loop input, got {self.value_type}")
|
||||||
|
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
|
||||||
|
|
||||||
|
def require_constant_value(self) -> LoopValue:
|
||||||
|
if self.value_type != "constant":
|
||||||
|
raise ValueError(f"Expected constant loop input, got {self.value_type}")
|
||||||
|
return _validate_loop_value(self.value)
|
||||||
|
|
||||||
|
|
||||||
class LoopNodeData(BaseLoopNodeData):
|
class LoopNodeData(BaseLoopNodeData):
|
||||||
@ -46,14 +107,14 @@ class LoopNodeData(BaseLoopNodeData):
|
|||||||
break_conditions: list[Condition] # Conditions to break the loop
|
break_conditions: list[Condition] # Conditions to break the loop
|
||||||
logical_operator: Literal["and", "or"]
|
logical_operator: Literal["and", "or"]
|
||||||
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
|
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
|
||||||
outputs: dict[str, Any] = Field(default_factory=dict)
|
outputs: LoopValueMapping = Field(default_factory=dict)
|
||||||
|
|
||||||
@field_validator("outputs", mode="before")
|
@field_validator("outputs", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_outputs(cls, v):
|
def validate_outputs(cls, value: object) -> LoopValueMapping:
|
||||||
if v is None:
|
if value is None:
|
||||||
return {}
|
return {}
|
||||||
return v
|
return _validate_loop_value_mapping(value)
|
||||||
|
|
||||||
|
|
||||||
class LoopStartNodeData(BaseNodeData):
|
class LoopStartNodeData(BaseNodeData):
|
||||||
@ -77,8 +138,8 @@ class LoopState(BaseLoopState):
|
|||||||
Loop State.
|
Loop State.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
outputs: list[Any] = Field(default_factory=list)
|
outputs: list[LoopValue] = Field(default_factory=list)
|
||||||
current_output: Any = None
|
current_output: LoopValue | None = None
|
||||||
|
|
||||||
class MetaData(BaseLoopState.MetaData):
|
class MetaData(BaseLoopState.MetaData):
|
||||||
"""
|
"""
|
||||||
@ -87,7 +148,7 @@ class LoopState(BaseLoopState):
|
|||||||
|
|
||||||
loop_length: int
|
loop_length: int
|
||||||
|
|
||||||
def get_last_output(self) -> Any:
|
def get_last_output(self) -> LoopValue | None:
|
||||||
"""
|
"""
|
||||||
Get last output.
|
Get last output.
|
||||||
"""
|
"""
|
||||||
@ -95,7 +156,7 @@ class LoopState(BaseLoopState):
|
|||||||
return self.outputs[-1]
|
return self.outputs[-1]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_current_output(self) -> Any:
|
def get_current_output(self) -> LoopValue | None:
|
||||||
"""
|
"""
|
||||||
Get current output.
|
Get current output.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
from typing import TYPE_CHECKING, Literal, cast
|
||||||
|
|
||||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||||
from dify_graph.enums import (
|
from dify_graph.enums import (
|
||||||
@ -29,7 +29,7 @@ from dify_graph.node_events import (
|
|||||||
)
|
)
|
||||||
from dify_graph.nodes.base import LLMUsageTrackingMixin
|
from dify_graph.nodes.base import LLMUsageTrackingMixin
|
||||||
from dify_graph.nodes.base.node import Node
|
from dify_graph.nodes.base.node import Node
|
||||||
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
|
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopValue, LoopVariableData
|
||||||
from dify_graph.utils.condition.processor import ConditionProcessor
|
from dify_graph.utils.condition.processor import ConditionProcessor
|
||||||
from dify_graph.variables import Segment, SegmentType
|
from dify_graph.variables import Segment, SegmentType
|
||||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
|
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
|
||||||
@ -60,7 +60,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||||||
break_conditions = self.node_data.break_conditions
|
break_conditions = self.node_data.break_conditions
|
||||||
logical_operator = self.node_data.logical_operator
|
logical_operator = self.node_data.logical_operator
|
||||||
|
|
||||||
inputs = {"loop_count": loop_count}
|
inputs: dict[str, object] = {"loop_count": loop_count}
|
||||||
|
|
||||||
if not self.node_data.start_node_id:
|
if not self.node_data.start_node_id:
|
||||||
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
|
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
|
||||||
@ -68,12 +68,14 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||||||
root_node_id = self.node_data.start_node_id
|
root_node_id = self.node_data.start_node_id
|
||||||
|
|
||||||
# Initialize loop variables in the original variable pool
|
# Initialize loop variables in the original variable pool
|
||||||
loop_variable_selectors = {}
|
loop_variable_selectors: dict[str, list[str]] = {}
|
||||||
if self.node_data.loop_variables:
|
if self.node_data.loop_variables:
|
||||||
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
|
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
|
||||||
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
|
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.require_constant_value()),
|
||||||
"variable": lambda var: (
|
"variable": lambda var: (
|
||||||
self.graph_runtime_state.variable_pool.get(var.value) if isinstance(var.value, list) else None
|
self.graph_runtime_state.variable_pool.get(var.require_variable_selector())
|
||||||
|
if var.value is not None
|
||||||
|
else None
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
for loop_variable in self.node_data.loop_variables:
|
for loop_variable in self.node_data.loop_variables:
|
||||||
@ -95,7 +97,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||||||
condition_processor = ConditionProcessor()
|
condition_processor = ConditionProcessor()
|
||||||
|
|
||||||
loop_duration_map: dict[str, float] = {}
|
loop_duration_map: dict[str, float] = {}
|
||||||
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
|
single_loop_variable_map: dict[str, dict[str, LoopValue]] = {} # single loop variable output
|
||||||
loop_usage = LLMUsage.empty_usage()
|
loop_usage = LLMUsage.empty_usage()
|
||||||
loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id)
|
loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id)
|
||||||
|
|
||||||
@ -146,7 +148,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||||||
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
|
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
|
||||||
|
|
||||||
# Collect loop variable values after iteration
|
# Collect loop variable values after iteration
|
||||||
single_loop_variable = {}
|
single_loop_variable: dict[str, LoopValue] = {}
|
||||||
for key, selector in loop_variable_selectors.items():
|
for key, selector in loop_variable_selectors.items():
|
||||||
segment = self.graph_runtime_state.variable_pool.get(selector)
|
segment = self.graph_runtime_state.variable_pool.get(selector)
|
||||||
single_loop_variable[key] = segment.value if segment else None
|
single_loop_variable[key] = segment.value if segment else None
|
||||||
@ -297,20 +299,29 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||||||
def _extract_variable_selector_to_variable_mapping(
|
def _extract_variable_selector_to_variable_mapping(
|
||||||
cls,
|
cls,
|
||||||
*,
|
*,
|
||||||
graph_config: Mapping[str, Any],
|
graph_config: Mapping[str, object],
|
||||||
node_id: str,
|
node_id: str,
|
||||||
node_data: LoopNodeData,
|
node_data: LoopNodeData,
|
||||||
) -> Mapping[str, Sequence[str]]:
|
) -> Mapping[str, Sequence[str]]:
|
||||||
variable_mapping = {}
|
variable_mapping: dict[str, Sequence[str]] = {}
|
||||||
|
|
||||||
# Extract loop node IDs statically from graph_config
|
# Extract loop node IDs statically from graph_config
|
||||||
|
|
||||||
loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
|
loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
|
||||||
|
|
||||||
# Get node configs from graph_config
|
# Get node configs from graph_config
|
||||||
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
|
raw_nodes = graph_config.get("nodes")
|
||||||
|
node_configs: dict[str, Mapping[str, object]] = {}
|
||||||
|
if isinstance(raw_nodes, list):
|
||||||
|
for raw_node in raw_nodes:
|
||||||
|
if not isinstance(raw_node, dict):
|
||||||
|
continue
|
||||||
|
raw_node_id = raw_node.get("id")
|
||||||
|
if isinstance(raw_node_id, str):
|
||||||
|
node_configs[raw_node_id] = raw_node
|
||||||
for sub_node_id, sub_node_config in node_configs.items():
|
for sub_node_id, sub_node_config in node_configs.items():
|
||||||
if sub_node_config.get("data", {}).get("loop_id") != node_id:
|
sub_node_data = sub_node_config.get("data")
|
||||||
|
if not isinstance(sub_node_data, dict) or sub_node_data.get("loop_id") != node_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# variable selector to variable mapping
|
# variable selector to variable mapping
|
||||||
@ -341,9 +352,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||||||
|
|
||||||
for loop_variable in node_data.loop_variables or []:
|
for loop_variable in node_data.loop_variables or []:
|
||||||
if loop_variable.value_type == "variable":
|
if loop_variable.value_type == "variable":
|
||||||
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
|
|
||||||
# add loop variable to variable mapping
|
# add loop variable to variable mapping
|
||||||
selector = loop_variable.value
|
selector = loop_variable.require_variable_selector()
|
||||||
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
|
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
|
||||||
|
|
||||||
# remove variable out from loop
|
# remove variable out from loop
|
||||||
@ -352,7 +362,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||||||
return variable_mapping
|
return variable_mapping
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
|
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, object], loop_node_id: str) -> set[str]:
|
||||||
"""
|
"""
|
||||||
Extract node IDs that belong to a specific loop from graph configuration.
|
Extract node IDs that belong to a specific loop from graph configuration.
|
||||||
|
|
||||||
@ -363,12 +373,19 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||||||
:param loop_node_id: the ID of the loop node
|
:param loop_node_id: the ID of the loop node
|
||||||
:return: set of node IDs that belong to the loop
|
:return: set of node IDs that belong to the loop
|
||||||
"""
|
"""
|
||||||
loop_node_ids = set()
|
loop_node_ids: set[str] = set()
|
||||||
|
|
||||||
# Find all nodes that belong to this loop
|
# Find all nodes that belong to this loop
|
||||||
nodes = graph_config.get("nodes", [])
|
raw_nodes = graph_config.get("nodes")
|
||||||
for node in nodes:
|
if not isinstance(raw_nodes, list):
|
||||||
node_data = node.get("data", {})
|
return loop_node_ids
|
||||||
|
|
||||||
|
for node in raw_nodes:
|
||||||
|
if not isinstance(node, dict):
|
||||||
|
continue
|
||||||
|
node_data = node.get("data")
|
||||||
|
if not isinstance(node_data, dict):
|
||||||
|
continue
|
||||||
if node_data.get("loop_id") == loop_node_id:
|
if node_data.get("loop_id") == loop_node_id:
|
||||||
node_id = node.get("id")
|
node_id = node.get("id")
|
||||||
if node_id:
|
if node_id:
|
||||||
@ -377,7 +394,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||||||
return loop_node_ids
|
return loop_node_ids
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
|
def _get_segment_for_constant(var_type: SegmentType, original_value: LoopValue | None) -> Segment:
|
||||||
"""Get the appropriate segment type for a constant value."""
|
"""Get the appropriate segment type for a constant value."""
|
||||||
# TODO: Refactor for maintainability:
|
# TODO: Refactor for maintainability:
|
||||||
# 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py)
|
# 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py)
|
||||||
@ -389,11 +406,15 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||||||
SegmentType.ARRAY_OBJECT,
|
SegmentType.ARRAY_OBJECT,
|
||||||
SegmentType.ARRAY_STRING,
|
SegmentType.ARRAY_STRING,
|
||||||
]:
|
]:
|
||||||
if original_value and isinstance(original_value, str):
|
# New typed payloads may already provide native lists, while legacy
|
||||||
value = json.loads(original_value)
|
# configs still serialize array constants as JSON strings.
|
||||||
else:
|
if isinstance(original_value, str):
|
||||||
logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type)
|
value = json.loads(original_value) if original_value else []
|
||||||
|
elif original_value is None:
|
||||||
|
# Preserve legacy behavior: treat missing/empty array constants as [].
|
||||||
value = []
|
value = []
|
||||||
|
else:
|
||||||
|
value = original_value
|
||||||
else:
|
else:
|
||||||
raise AssertionError("this statement should be unreachable.")
|
raise AssertionError("this statement should be unreachable.")
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Annotated, Any, Literal
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
@ -6,6 +6,7 @@ from pydantic import (
|
|||||||
Field,
|
Field,
|
||||||
field_validator,
|
field_validator,
|
||||||
)
|
)
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||||
from dify_graph.entities.base_node_data import BaseNodeData
|
from dify_graph.entities.base_node_data import BaseNodeData
|
||||||
@ -55,7 +56,7 @@ class ParameterConfig(BaseModel):
|
|||||||
|
|
||||||
@field_validator("name", mode="before")
|
@field_validator("name", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_name(cls, value) -> str:
|
def validate_name(cls, value: object) -> str:
|
||||||
if not value:
|
if not value:
|
||||||
raise ValueError("Parameter name is required")
|
raise ValueError("Parameter name is required")
|
||||||
if value in {"__reason", "__is_success"}:
|
if value in {"__reason", "__is_success"}:
|
||||||
@ -79,6 +80,23 @@ class ParameterConfig(BaseModel):
|
|||||||
return element_type
|
return element_type
|
||||||
|
|
||||||
|
|
||||||
|
class JsonSchemaArrayItems(TypedDict):
|
||||||
|
type: str
|
||||||
|
|
||||||
|
|
||||||
|
class ParameterJsonSchemaProperty(TypedDict, total=False):
|
||||||
|
description: str
|
||||||
|
type: str
|
||||||
|
items: JsonSchemaArrayItems
|
||||||
|
enum: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class ParameterJsonSchema(TypedDict):
|
||||||
|
type: Literal["object"]
|
||||||
|
properties: dict[str, ParameterJsonSchemaProperty]
|
||||||
|
required: list[str]
|
||||||
|
|
||||||
|
|
||||||
class ParameterExtractorNodeData(BaseNodeData):
|
class ParameterExtractorNodeData(BaseNodeData):
|
||||||
"""
|
"""
|
||||||
Parameter Extractor Node Data.
|
Parameter Extractor Node Data.
|
||||||
@ -95,19 +113,19 @@ class ParameterExtractorNodeData(BaseNodeData):
|
|||||||
|
|
||||||
@field_validator("reasoning_mode", mode="before")
|
@field_validator("reasoning_mode", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_reasoning_mode(cls, v) -> str:
|
def set_reasoning_mode(cls, v: object) -> str:
|
||||||
return v or "function_call"
|
return str(v) if v else "function_call"
|
||||||
|
|
||||||
def get_parameter_json_schema(self):
|
def get_parameter_json_schema(self) -> ParameterJsonSchema:
|
||||||
"""
|
"""
|
||||||
Get parameter json schema.
|
Get parameter json schema.
|
||||||
|
|
||||||
:return: parameter json schema
|
:return: parameter json schema
|
||||||
"""
|
"""
|
||||||
parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
|
parameters: ParameterJsonSchema = {"type": "object", "properties": {}, "required": []}
|
||||||
|
|
||||||
for parameter in self.parameters:
|
for parameter in self.parameters:
|
||||||
parameter_schema: dict[str, Any] = {"description": parameter.description}
|
parameter_schema: ParameterJsonSchemaProperty = {"description": parameter.description}
|
||||||
|
|
||||||
if parameter.type == SegmentType.STRING:
|
if parameter.type == SegmentType.STRING:
|
||||||
parameter_schema["type"] = "string"
|
parameter_schema["type"] = "string"
|
||||||
@ -118,7 +136,7 @@ class ParameterExtractorNodeData(BaseNodeData):
|
|||||||
raise AssertionError("element type should not be None.")
|
raise AssertionError("element type should not be None.")
|
||||||
parameter_schema["items"] = {"type": element_type.value}
|
parameter_schema["items"] = {"type": element_type.value}
|
||||||
else:
|
else:
|
||||||
parameter_schema["type"] = parameter.type
|
parameter_schema["type"] = parameter.type.value
|
||||||
|
|
||||||
if parameter.options:
|
if parameter.options:
|
||||||
parameter_schema["enum"] = parameter.options
|
parameter_schema["enum"] = parameter.options
|
||||||
|
|||||||
@ -5,6 +5,8 @@ import uuid
|
|||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||||
@ -63,6 +65,7 @@ from .prompts import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
_JSON_OBJECT_ADAPTER = TypeAdapter(dict[str, object])
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from dify_graph.entities import GraphInitParams
|
from dify_graph.entities import GraphInitParams
|
||||||
@ -70,7 +73,7 @@ if TYPE_CHECKING:
|
|||||||
from dify_graph.runtime import GraphRuntimeState
|
from dify_graph.runtime import GraphRuntimeState
|
||||||
|
|
||||||
|
|
||||||
def extract_json(text):
|
def extract_json(text: str) -> str | None:
|
||||||
"""
|
"""
|
||||||
From a given JSON started from '{' or '[' extract the complete JSON object.
|
From a given JSON started from '{' or '[' extract the complete JSON object.
|
||||||
"""
|
"""
|
||||||
@ -396,10 +399,15 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# generate tool
|
# generate tool
|
||||||
|
parameter_schema = node_data.get_parameter_json_schema()
|
||||||
tool = PromptMessageTool(
|
tool = PromptMessageTool(
|
||||||
name=FUNCTION_CALLING_EXTRACTOR_NAME,
|
name=FUNCTION_CALLING_EXTRACTOR_NAME,
|
||||||
description="Extract parameters from the natural language text",
|
description="Extract parameters from the natural language text",
|
||||||
parameters=node_data.get_parameter_json_schema(),
|
parameters={
|
||||||
|
"type": parameter_schema["type"],
|
||||||
|
"properties": dict(parameter_schema["properties"]),
|
||||||
|
"required": list(parameter_schema["required"]),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return prompt_messages, [tool]
|
return prompt_messages, [tool]
|
||||||
@ -606,19 +614,21 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _transform_result(self, data: ParameterExtractorNodeData, result: dict):
|
def _transform_result(self, data: ParameterExtractorNodeData, result: Mapping[str, object]) -> dict[str, object]:
|
||||||
"""
|
"""
|
||||||
Transform result into standard format.
|
Transform result into standard format.
|
||||||
"""
|
"""
|
||||||
transformed_result: dict[str, Any] = {}
|
transformed_result: dict[str, object] = {}
|
||||||
for parameter in data.parameters:
|
for parameter in data.parameters:
|
||||||
if parameter.name in result:
|
if parameter.name in result:
|
||||||
param_value = result[parameter.name]
|
param_value = result[parameter.name]
|
||||||
# transform value
|
# transform value
|
||||||
if parameter.type == SegmentType.NUMBER:
|
if parameter.type == SegmentType.NUMBER:
|
||||||
transformed = self._transform_number(param_value)
|
if isinstance(param_value, (bool, int, float, str)):
|
||||||
if transformed is not None:
|
numeric_value: bool | int | float | str = param_value
|
||||||
transformed_result[parameter.name] = transformed
|
transformed = self._transform_number(numeric_value)
|
||||||
|
if transformed is not None:
|
||||||
|
transformed_result[parameter.name] = transformed
|
||||||
elif parameter.type == SegmentType.BOOLEAN:
|
elif parameter.type == SegmentType.BOOLEAN:
|
||||||
if isinstance(result[parameter.name], (bool, int)):
|
if isinstance(result[parameter.name], (bool, int)):
|
||||||
transformed_result[parameter.name] = bool(result[parameter.name])
|
transformed_result[parameter.name] = bool(result[parameter.name])
|
||||||
@ -665,7 +675,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
|||||||
|
|
||||||
return transformed_result
|
return transformed_result
|
||||||
|
|
||||||
def _extract_complete_json_response(self, result: str) -> dict | None:
|
def _extract_complete_json_response(self, result: str) -> dict[str, object] | None:
|
||||||
"""
|
"""
|
||||||
Extract complete json response.
|
Extract complete json response.
|
||||||
"""
|
"""
|
||||||
@ -676,11 +686,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
|||||||
json_str = extract_json(result[idx:])
|
json_str = extract_json(result[idx:])
|
||||||
if json_str:
|
if json_str:
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
return cast(dict, json.loads(json_str))
|
return _JSON_OBJECT_ADAPTER.validate_python(json.loads(json_str))
|
||||||
logger.info("extra error: %s", result)
|
logger.info("extra error: %s", result)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None:
|
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict[str, object] | None:
|
||||||
"""
|
"""
|
||||||
Extract json from tool call.
|
Extract json from tool call.
|
||||||
"""
|
"""
|
||||||
@ -694,16 +704,16 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
|||||||
json_str = extract_json(result[idx:])
|
json_str = extract_json(result[idx:])
|
||||||
if json_str:
|
if json_str:
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
return cast(dict, json.loads(json_str))
|
return _JSON_OBJECT_ADAPTER.validate_python(json.loads(json_str))
|
||||||
|
|
||||||
logger.info("extra error: %s", result)
|
logger.info("extra error: %s", result)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _generate_default_result(self, data: ParameterExtractorNodeData):
|
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict[str, object]:
|
||||||
"""
|
"""
|
||||||
Generate default result.
|
Generate default result.
|
||||||
"""
|
"""
|
||||||
result: dict[str, Any] = {}
|
result: dict[str, object] = {}
|
||||||
for parameter in data.parameters:
|
for parameter in data.parameters:
|
||||||
if parameter.type == "number":
|
if parameter.type == "number":
|
||||||
result[parameter.name] = 0
|
result[parameter.name] = 0
|
||||||
|
|||||||
@ -1,12 +1,66 @@
|
|||||||
from typing import Any, Literal, Union
|
from __future__ import annotations
|
||||||
|
|
||||||
from pydantic import BaseModel, field_validator
|
from typing import Literal, TypeAlias, cast
|
||||||
|
|
||||||
|
from pydantic import BaseModel, TypeAdapter, field_validator
|
||||||
from pydantic_core.core_schema import ValidationInfo
|
from pydantic_core.core_schema import ValidationInfo
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from core.tools.entities.tool_entities import ToolProviderType
|
from core.tools.entities.tool_entities import ToolProviderType
|
||||||
from dify_graph.entities.base_node_data import BaseNodeData
|
from dify_graph.entities.base_node_data import BaseNodeData
|
||||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||||
|
|
||||||
|
ToolConfigurationValue: TypeAlias = str | int | float | bool
|
||||||
|
ToolInputConstantValue: TypeAlias = str | int | float | bool | dict[str, object] | list[object] | None
|
||||||
|
VariableSelector: TypeAlias = list[str]
|
||||||
|
|
||||||
|
_TOOL_INPUT_MIXED_ADAPTER: TypeAdapter[str] = TypeAdapter(str)
|
||||||
|
_TOOL_INPUT_CONSTANT_ADAPTER: TypeAdapter[ToolInputConstantValue] = TypeAdapter(ToolInputConstantValue)
|
||||||
|
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowToolInputValue(TypedDict):
|
||||||
|
type: Literal["mixed", "variable", "constant"]
|
||||||
|
value: ToolInputConstantValue | VariableSelector
|
||||||
|
|
||||||
|
|
||||||
|
ToolConfigurationEntry: TypeAlias = ToolConfigurationValue | WorkflowToolInputValue
|
||||||
|
ToolConfigurations: TypeAlias = dict[str, ToolConfigurationEntry]
|
||||||
|
|
||||||
|
|
||||||
|
class ToolInputPayload(BaseModel):
|
||||||
|
type: Literal["mixed", "variable", "constant"]
|
||||||
|
value: ToolInputConstantValue | VariableSelector
|
||||||
|
|
||||||
|
@field_validator("value", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_value(
|
||||||
|
cls, value: object, validation_info: ValidationInfo
|
||||||
|
) -> ToolInputConstantValue | VariableSelector:
|
||||||
|
input_type = validation_info.data.get("type")
|
||||||
|
if input_type == "mixed":
|
||||||
|
return _TOOL_INPUT_MIXED_ADAPTER.validate_python(value)
|
||||||
|
if input_type == "variable":
|
||||||
|
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
|
||||||
|
if input_type == "constant":
|
||||||
|
return _TOOL_INPUT_CONSTANT_ADAPTER.validate_python(value)
|
||||||
|
raise ValueError(f"Unknown tool input type: {input_type}")
|
||||||
|
|
||||||
|
def require_variable_selector(self) -> VariableSelector:
|
||||||
|
if self.type != "variable":
|
||||||
|
raise ValueError(f"Expected variable tool input, got {self.type}")
|
||||||
|
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_tool_configuration_entry(value: object) -> ToolConfigurationEntry:
|
||||||
|
if isinstance(value, (str, int, float, bool)):
|
||||||
|
return cast(ToolConfigurationEntry, value)
|
||||||
|
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return cast(ToolConfigurationEntry, ToolInputPayload.model_validate(value).model_dump())
|
||||||
|
|
||||||
|
raise TypeError("Tool configuration values must be primitives or workflow tool input objects")
|
||||||
|
|
||||||
|
|
||||||
class ToolEntity(BaseModel):
|
class ToolEntity(BaseModel):
|
||||||
provider_id: str
|
provider_id: str
|
||||||
@ -14,52 +68,29 @@ class ToolEntity(BaseModel):
|
|||||||
provider_name: str # redundancy
|
provider_name: str # redundancy
|
||||||
tool_name: str
|
tool_name: str
|
||||||
tool_label: str # redundancy
|
tool_label: str # redundancy
|
||||||
tool_configurations: dict[str, Any]
|
tool_configurations: ToolConfigurations
|
||||||
credential_id: str | None = None
|
credential_id: str | None = None
|
||||||
plugin_unique_identifier: str | None = None # redundancy
|
plugin_unique_identifier: str | None = None # redundancy
|
||||||
|
|
||||||
@field_validator("tool_configurations", mode="before")
|
@field_validator("tool_configurations", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_tool_configurations(cls, value, values: ValidationInfo):
|
def validate_tool_configurations(cls, value: object, _validation_info: ValidationInfo) -> ToolConfigurations:
|
||||||
if not isinstance(value, dict):
|
if not isinstance(value, dict):
|
||||||
raise ValueError("tool_configurations must be a dictionary")
|
raise TypeError("tool_configurations must be a dictionary")
|
||||||
|
|
||||||
for key in values.data.get("tool_configurations", {}):
|
normalized: ToolConfigurations = {}
|
||||||
value = values.data.get("tool_configurations", {}).get(key)
|
for key, item in value.items():
|
||||||
if not isinstance(value, str | int | float | bool):
|
if not isinstance(key, str):
|
||||||
raise ValueError(f"{key} must be a string")
|
raise TypeError("tool_configurations keys must be strings")
|
||||||
|
normalized[key] = _validate_tool_configuration_entry(item)
|
||||||
return value
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
class ToolNodeData(BaseNodeData, ToolEntity):
|
class ToolNodeData(BaseNodeData, ToolEntity):
|
||||||
type: NodeType = BuiltinNodeTypes.TOOL
|
type: NodeType = BuiltinNodeTypes.TOOL
|
||||||
|
|
||||||
class ToolInput(BaseModel):
|
class ToolInput(ToolInputPayload):
|
||||||
# TODO: check this type
|
pass
|
||||||
value: Union[Any, list[str]]
|
|
||||||
type: Literal["mixed", "variable", "constant"]
|
|
||||||
|
|
||||||
@field_validator("type", mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_type(cls, value, validation_info: ValidationInfo):
|
|
||||||
typ = value
|
|
||||||
value = validation_info.data.get("value")
|
|
||||||
|
|
||||||
if value is None:
|
|
||||||
return typ
|
|
||||||
|
|
||||||
if typ == "mixed" and not isinstance(value, str):
|
|
||||||
raise ValueError("value must be a string")
|
|
||||||
elif typ == "variable":
|
|
||||||
if not isinstance(value, list):
|
|
||||||
raise ValueError("value must be a list")
|
|
||||||
for val in value:
|
|
||||||
if not isinstance(val, str):
|
|
||||||
raise ValueError("value must be a list of strings")
|
|
||||||
elif typ == "constant" and not isinstance(value, (allowed_types := (str, int, float, bool, dict, list))):
|
|
||||||
raise ValueError(f"value must be one of: {', '.join(t.__name__ for t in allowed_types)}")
|
|
||||||
return typ
|
|
||||||
|
|
||||||
tool_parameters: dict[str, ToolInput]
|
tool_parameters: dict[str, ToolInput]
|
||||||
# The version of the tool parameter.
|
# The version of the tool parameter.
|
||||||
@ -69,7 +100,7 @@ class ToolNodeData(BaseNodeData, ToolEntity):
|
|||||||
|
|
||||||
@field_validator("tool_parameters", mode="before")
|
@field_validator("tool_parameters", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def filter_none_tool_inputs(cls, value):
|
def filter_none_tool_inputs(cls, value: object) -> object:
|
||||||
if not isinstance(value, dict):
|
if not isinstance(value, dict):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@ -80,8 +111,10 @@ class ToolNodeData(BaseNodeData, ToolEntity):
|
|||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _has_valid_value(tool_input):
|
def _has_valid_value(tool_input: object) -> bool:
|
||||||
"""Check if the value is valid"""
|
"""Check if the value is valid"""
|
||||||
if isinstance(tool_input, dict):
|
if isinstance(tool_input, dict):
|
||||||
return tool_input.get("value") is not None
|
return tool_input.get("value") is not None
|
||||||
return getattr(tool_input, "value", None) is not None
|
if isinstance(tool_input, ToolNodeData.ToolInput):
|
||||||
|
return tool_input.value is not None
|
||||||
|
return False
|
||||||
|
|||||||
@ -225,10 +225,11 @@ class ToolNode(Node[ToolNodeData]):
|
|||||||
continue
|
continue
|
||||||
tool_input = node_data.tool_parameters[parameter_name]
|
tool_input = node_data.tool_parameters[parameter_name]
|
||||||
if tool_input.type == "variable":
|
if tool_input.type == "variable":
|
||||||
variable = variable_pool.get(tool_input.value)
|
variable_selector = tool_input.require_variable_selector()
|
||||||
|
variable = variable_pool.get(variable_selector)
|
||||||
if variable is None:
|
if variable is None:
|
||||||
if parameter.required:
|
if parameter.required:
|
||||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
raise ToolParameterError(f"Variable {variable_selector} does not exist")
|
||||||
continue
|
continue
|
||||||
parameter_value = variable.value
|
parameter_value = variable.value
|
||||||
elif tool_input.type in {"mixed", "constant"}:
|
elif tool_input.type in {"mixed", "constant"}:
|
||||||
@ -510,8 +511,9 @@ class ToolNode(Node[ToolNodeData]):
|
|||||||
for selector in selectors:
|
for selector in selectors:
|
||||||
result[selector.variable] = selector.value_selector
|
result[selector.variable] = selector.value_selector
|
||||||
case "variable":
|
case "variable":
|
||||||
selector_key = ".".join(input.value)
|
variable_selector = input.require_variable_selector()
|
||||||
result[f"#{selector_key}#"] = input.value
|
selector_key = ".".join(variable_selector)
|
||||||
|
result[f"#{selector_key}#"] = variable_selector
|
||||||
case "constant":
|
case "constant":
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from dify_graph.node_events import NodeRunResult
|
|||||||
from dify_graph.nodes.base.node import Node
|
from dify_graph.nodes.base.node import Node
|
||||||
from dify_graph.nodes.variable_assigner.common import helpers as common_helpers
|
from dify_graph.nodes.variable_assigner.common import helpers as common_helpers
|
||||||
from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||||
from dify_graph.variables import SegmentType, VariableBase
|
from dify_graph.variables import Segment, SegmentType, VariableBase
|
||||||
|
|
||||||
from .node_data import VariableAssignerData, WriteMode
|
from .node_data import VariableAssignerData, WriteMode
|
||||||
|
|
||||||
@ -74,23 +74,29 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
|||||||
if not isinstance(original_variable, VariableBase):
|
if not isinstance(original_variable, VariableBase):
|
||||||
raise VariableOperatorNodeError("assigned variable not found")
|
raise VariableOperatorNodeError("assigned variable not found")
|
||||||
|
|
||||||
|
income_value: Segment
|
||||||
|
updated_variable: VariableBase
|
||||||
match self.node_data.write_mode:
|
match self.node_data.write_mode:
|
||||||
case WriteMode.OVER_WRITE:
|
case WriteMode.OVER_WRITE:
|
||||||
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
input_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||||
if not income_value:
|
if input_value is None:
|
||||||
raise VariableOperatorNodeError("input value not found")
|
raise VariableOperatorNodeError("input value not found")
|
||||||
|
income_value = input_value
|
||||||
updated_variable = original_variable.model_copy(update={"value": income_value.value})
|
updated_variable = original_variable.model_copy(update={"value": income_value.value})
|
||||||
|
|
||||||
case WriteMode.APPEND:
|
case WriteMode.APPEND:
|
||||||
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
input_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||||
if not income_value:
|
if input_value is None:
|
||||||
raise VariableOperatorNodeError("input value not found")
|
raise VariableOperatorNodeError("input value not found")
|
||||||
|
income_value = input_value
|
||||||
updated_value = original_variable.value + [income_value.value]
|
updated_value = original_variable.value + [income_value.value]
|
||||||
updated_variable = original_variable.model_copy(update={"value": updated_value})
|
updated_variable = original_variable.model_copy(update={"value": updated_value})
|
||||||
|
|
||||||
case WriteMode.CLEAR:
|
case WriteMode.CLEAR:
|
||||||
income_value = SegmentType.get_zero_value(original_variable.value_type)
|
income_value = SegmentType.get_zero_value(original_variable.value_type)
|
||||||
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
|
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
|
||||||
|
case _:
|
||||||
|
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
|
||||||
|
|
||||||
# Over write the variable.
|
# Over write the variable.
|
||||||
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
|
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
|
||||||
|
|||||||
@ -66,6 +66,11 @@ class GraphExecutionProtocol(Protocol):
|
|||||||
exceptions_count: int
|
exceptions_count: int
|
||||||
pause_reasons: list[PauseReason]
|
pause_reasons: list[PauseReason]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def node_executions(self) -> Mapping[str, NodeExecutionProtocol]:
|
||||||
|
"""Return node execution state keyed by node id for resume support."""
|
||||||
|
...
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
"""Transition execution into the running state."""
|
"""Transition execution into the running state."""
|
||||||
...
|
...
|
||||||
@ -91,6 +96,12 @@ class GraphExecutionProtocol(Protocol):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class NodeExecutionProtocol(Protocol):
|
||||||
|
"""Structural interface for per-node execution state used during resume."""
|
||||||
|
|
||||||
|
execution_id: str | None
|
||||||
|
|
||||||
|
|
||||||
class ResponseStreamCoordinatorProtocol(Protocol):
|
class ResponseStreamCoordinatorProtocol(Protocol):
|
||||||
"""Structural interface for response stream coordinator."""
|
"""Structural interface for response stream coordinator."""
|
||||||
|
|
||||||
|
|||||||
@ -13,21 +13,6 @@ controllers/console/workspace/trigger_providers.py
|
|||||||
controllers/service_api/app/annotation.py
|
controllers/service_api/app/annotation.py
|
||||||
controllers/web/workflow_events.py
|
controllers/web/workflow_events.py
|
||||||
core/agent/fc_agent_runner.py
|
core/agent/fc_agent_runner.py
|
||||||
core/app/apps/advanced_chat/app_generator.py
|
|
||||||
core/app/apps/advanced_chat/app_runner.py
|
|
||||||
core/app/apps/advanced_chat/generate_task_pipeline.py
|
|
||||||
core/app/apps/agent_chat/app_generator.py
|
|
||||||
core/app/apps/base_app_generate_response_converter.py
|
|
||||||
core/app/apps/base_app_generator.py
|
|
||||||
core/app/apps/chat/app_generator.py
|
|
||||||
core/app/apps/common/workflow_response_converter.py
|
|
||||||
core/app/apps/completion/app_generator.py
|
|
||||||
core/app/apps/pipeline/pipeline_generator.py
|
|
||||||
core/app/apps/pipeline/pipeline_runner.py
|
|
||||||
core/app/apps/workflow/app_generator.py
|
|
||||||
core/app/apps/workflow/app_runner.py
|
|
||||||
core/app/apps/workflow/generate_task_pipeline.py
|
|
||||||
core/app/apps/workflow_app_runner.py
|
|
||||||
core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
|
core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
|
||||||
core/datasource/datasource_manager.py
|
core/datasource/datasource_manager.py
|
||||||
core/external_data_tool/api/api.py
|
core/external_data_tool/api/api.py
|
||||||
@ -108,35 +93,6 @@ core/tools/workflow_as_tool/provider.py
|
|||||||
core/trigger/debug/event_selectors.py
|
core/trigger/debug/event_selectors.py
|
||||||
core/trigger/entities/entities.py
|
core/trigger/entities/entities.py
|
||||||
core/trigger/provider.py
|
core/trigger/provider.py
|
||||||
core/workflow/workflow_entry.py
|
|
||||||
dify_graph/entities/workflow_execution.py
|
|
||||||
dify_graph/file/file_manager.py
|
|
||||||
dify_graph/graph_engine/error_handler.py
|
|
||||||
dify_graph/graph_engine/layers/execution_limits.py
|
|
||||||
dify_graph/nodes/agent/agent_node.py
|
|
||||||
dify_graph/nodes/base/node.py
|
|
||||||
dify_graph/nodes/code/code_node.py
|
|
||||||
dify_graph/nodes/datasource/datasource_node.py
|
|
||||||
dify_graph/nodes/document_extractor/node.py
|
|
||||||
dify_graph/nodes/human_input/human_input_node.py
|
|
||||||
dify_graph/nodes/if_else/if_else_node.py
|
|
||||||
dify_graph/nodes/iteration/iteration_node.py
|
|
||||||
dify_graph/nodes/knowledge_index/knowledge_index_node.py
|
|
||||||
core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
|
|
||||||
dify_graph/nodes/list_operator/node.py
|
|
||||||
dify_graph/nodes/llm/node.py
|
|
||||||
dify_graph/nodes/loop/loop_node.py
|
|
||||||
dify_graph/nodes/parameter_extractor/parameter_extractor_node.py
|
|
||||||
dify_graph/nodes/question_classifier/question_classifier_node.py
|
|
||||||
dify_graph/nodes/start/start_node.py
|
|
||||||
dify_graph/nodes/template_transform/template_transform_node.py
|
|
||||||
dify_graph/nodes/tool/tool_node.py
|
|
||||||
dify_graph/nodes/trigger_plugin/trigger_event_node.py
|
|
||||||
dify_graph/nodes/trigger_schedule/trigger_schedule_node.py
|
|
||||||
dify_graph/nodes/trigger_webhook/node.py
|
|
||||||
dify_graph/nodes/variable_aggregator/variable_aggregator_node.py
|
|
||||||
dify_graph/nodes/variable_assigner/v1/node.py
|
|
||||||
dify_graph/nodes/variable_assigner/v2/node.py
|
|
||||||
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
|
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
|
||||||
extensions/otel/instrumentation.py
|
extensions/otel/instrumentation.py
|
||||||
extensions/otel/runtime.py
|
extensions/otel/runtime.py
|
||||||
|
|||||||
@ -1013,7 +1013,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
|||||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session)
|
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session)
|
||||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object()))
|
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object()))
|
||||||
|
|
||||||
refreshed = _refresh_model(session=SimpleNamespace(), model=source_model)
|
refreshed = _refresh_model(session=None, model=source_model)
|
||||||
|
|
||||||
assert refreshed is detached_model
|
assert refreshed is detached_model
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,110 @@
|
|||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
|
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.app.entities.task_entities import AppBlockingResponse
|
||||||
|
from core.errors.error import QuotaExceededError
|
||||||
|
|
||||||
|
|
||||||
|
class DummyResponseConverter(AppGenerateResponseConverter):
|
||||||
|
_blocking_response_type = AppBlockingResponse
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, str]:
|
||||||
|
return {"mode": "blocking-full", "task_id": blocking_response.task_id}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, str]:
|
||||||
|
return {"mode": "blocking-simple", "task_id": blocking_response.task_id}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_stream_full_response(cls, stream_response: Iterator[object]):
|
||||||
|
for _ in stream_response:
|
||||||
|
yield {"mode": "stream-full"}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_stream_simple_response(cls, stream_response: Iterator[object]):
|
||||||
|
for _ in stream_response:
|
||||||
|
yield {"mode": "stream-simple"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_routes_to_full_or_simple_modes() -> None:
|
||||||
|
blocking = AppBlockingResponse(task_id="task-1")
|
||||||
|
|
||||||
|
assert DummyResponseConverter.convert(blocking, InvokeFrom.DEBUGGER) == {
|
||||||
|
"mode": "blocking-full",
|
||||||
|
"task_id": "task-1",
|
||||||
|
}
|
||||||
|
assert DummyResponseConverter.convert(blocking, InvokeFrom.WEB_APP) == {
|
||||||
|
"mode": "blocking-simple",
|
||||||
|
"task_id": "task-1",
|
||||||
|
}
|
||||||
|
assert list(DummyResponseConverter.convert(iter([object()]), InvokeFrom.SERVICE_API)) == [{"mode": "stream-full"}]
|
||||||
|
assert list(DummyResponseConverter.convert(iter([object()]), InvokeFrom.WEB_APP)) == [{"mode": "stream-simple"}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_simple_metadata_preserves_new_retriever_fields() -> None:
|
||||||
|
metadata = {
|
||||||
|
"retriever_resources": [
|
||||||
|
{
|
||||||
|
"dataset_id": "dataset-1",
|
||||||
|
"dataset_name": "Dataset",
|
||||||
|
"document_id": "document-1",
|
||||||
|
"segment_id": "segment-1",
|
||||||
|
"position": 1,
|
||||||
|
"data_source_type": "upload_file",
|
||||||
|
"document_name": "Document",
|
||||||
|
"score": 0.9,
|
||||||
|
"hit_count": 2,
|
||||||
|
"word_count": 128,
|
||||||
|
"segment_position": 3,
|
||||||
|
"index_node_hash": "hash",
|
||||||
|
"content": "content",
|
||||||
|
"page": 5,
|
||||||
|
"title": "Title",
|
||||||
|
"files": [{"id": "file-1"}],
|
||||||
|
"summary": "summary",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"annotation_reply": "hidden",
|
||||||
|
"usage": {"latency": 0.1},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = DummyResponseConverter._get_simple_metadata(metadata)
|
||||||
|
|
||||||
|
assert result == {
|
||||||
|
"retriever_resources": [
|
||||||
|
{
|
||||||
|
"dataset_id": "dataset-1",
|
||||||
|
"dataset_name": "Dataset",
|
||||||
|
"document_id": "document-1",
|
||||||
|
"segment_id": "segment-1",
|
||||||
|
"position": 1,
|
||||||
|
"data_source_type": "upload_file",
|
||||||
|
"document_name": "Document",
|
||||||
|
"score": 0.9,
|
||||||
|
"hit_count": 2,
|
||||||
|
"word_count": 128,
|
||||||
|
"segment_position": 3,
|
||||||
|
"index_node_hash": "hash",
|
||||||
|
"content": "content",
|
||||||
|
"page": 5,
|
||||||
|
"title": "Title",
|
||||||
|
"files": [{"id": "file-1"}],
|
||||||
|
"summary": "summary",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_error_to_stream_response_uses_specific_and_fallback_mappings() -> None:
|
||||||
|
quota_response = DummyResponseConverter._error_to_stream_response(QuotaExceededError())
|
||||||
|
fallback_response = DummyResponseConverter._error_to_stream_response(RuntimeError("boom"))
|
||||||
|
|
||||||
|
assert quota_response["code"] == "provider_quota_exceeded"
|
||||||
|
assert quota_response["status"] == 400
|
||||||
|
assert fallback_response == {
|
||||||
|
"code": "internal_server_error",
|
||||||
|
"message": "Internal Server Error, please contact support.",
|
||||||
|
"status": 500,
|
||||||
|
}
|
||||||
@ -33,6 +33,79 @@ from dify_graph.system_variable import SystemVariable
|
|||||||
|
|
||||||
|
|
||||||
class TestWorkflowBasedAppRunner:
|
class TestWorkflowBasedAppRunner:
|
||||||
|
def test_get_graph_items_rejects_non_mapping_entries(self):
|
||||||
|
with pytest.raises(ValueError, match="nodes in workflow graph must be mappings"):
|
||||||
|
WorkflowBasedAppRunner._get_graph_items({"nodes": ["bad"], "edges": []})
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="edges in workflow graph must be mappings"):
|
||||||
|
WorkflowBasedAppRunner._get_graph_items({"nodes": [], "edges": ["bad"]})
|
||||||
|
|
||||||
|
def test_extract_start_node_id_handles_missing_and_invalid_values(self):
|
||||||
|
assert WorkflowBasedAppRunner._extract_start_node_id(None) is None
|
||||||
|
assert WorkflowBasedAppRunner._extract_start_node_id({"data": "invalid"}) is None
|
||||||
|
assert WorkflowBasedAppRunner._extract_start_node_id({"data": {"start_node_id": 123}}) is None
|
||||||
|
assert WorkflowBasedAppRunner._extract_start_node_id({"data": {"start_node_id": "start-node"}}) == "start-node"
|
||||||
|
|
||||||
|
def test_build_single_node_graph_config_keeps_target_related_and_start_nodes(self):
|
||||||
|
graph_config, target_node_config = WorkflowBasedAppRunner._build_single_node_graph_config(
|
||||||
|
graph_config={
|
||||||
|
"nodes": [
|
||||||
|
{"id": "start-node", "data": {"type": "start", "version": "1"}},
|
||||||
|
{
|
||||||
|
"id": "loop-node",
|
||||||
|
"data": {"type": "loop", "version": "1", "start_node_id": "start-node"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "loop-child",
|
||||||
|
"data": {"type": "answer", "version": "1", "loop_id": "loop-node"},
|
||||||
|
},
|
||||||
|
{"id": "outside-node", "data": {"type": "answer", "version": "1"}},
|
||||||
|
],
|
||||||
|
"edges": [
|
||||||
|
{"source": "start-node", "target": "loop-node"},
|
||||||
|
{"source": "loop-node", "target": "loop-child"},
|
||||||
|
{"source": "loop-node", "target": "outside-node"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
node_id="loop-node",
|
||||||
|
node_type_filter_key="loop_id",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert [node["id"] for node in graph_config["nodes"]] == ["start-node", "loop-node", "loop-child"]
|
||||||
|
assert graph_config["edges"] == [
|
||||||
|
{"source": "start-node", "target": "loop-node"},
|
||||||
|
{"source": "loop-node", "target": "loop-child"},
|
||||||
|
]
|
||||||
|
assert target_node_config["id"] == "loop-node"
|
||||||
|
|
||||||
|
def test_build_agent_strategy_info_validates_payload(self):
|
||||||
|
event = NodeRunStartedEvent(
|
||||||
|
id="exec",
|
||||||
|
node_id="node",
|
||||||
|
node_type=BuiltinNodeTypes.START,
|
||||||
|
node_title="Start",
|
||||||
|
start_at=datetime.utcnow(),
|
||||||
|
extras={"agent_strategy": {"name": "planner", "icon": "robot"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
strategy = WorkflowBasedAppRunner._build_agent_strategy_info(event)
|
||||||
|
|
||||||
|
assert strategy is not None
|
||||||
|
assert strategy.name == "planner"
|
||||||
|
assert strategy.icon == "robot"
|
||||||
|
|
||||||
|
def test_build_agent_strategy_info_returns_none_for_invalid_payload(self):
|
||||||
|
event = NodeRunStartedEvent(
|
||||||
|
id="exec",
|
||||||
|
node_id="node",
|
||||||
|
node_type=BuiltinNodeTypes.START,
|
||||||
|
node_title="Start",
|
||||||
|
start_at=datetime.utcnow(),
|
||||||
|
extras={"agent_strategy": {"name": "planner", "extra": "ignored"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert WorkflowBasedAppRunner._build_agent_strategy_info(event) is None
|
||||||
|
|
||||||
def test_resolve_user_from(self):
|
def test_resolve_user_from(self):
|
||||||
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
|
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
|
||||||
|
|
||||||
@ -174,6 +247,34 @@ class TestWorkflowBasedAppRunner:
|
|||||||
assert paused_event.paused_nodes == ["node-1"]
|
assert paused_event.paused_nodes == ["node-1"]
|
||||||
assert emails
|
assert emails
|
||||||
|
|
||||||
|
def test_enqueue_human_input_notifications_skips_invalid_reasons_and_logs_failures(self, monkeypatch):
|
||||||
|
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
|
||||||
|
|
||||||
|
seen_calls: list[tuple[dict[str, object], str]] = []
|
||||||
|
|
||||||
|
class _Dispatch:
|
||||||
|
def apply_async(self, *, kwargs, queue):
|
||||||
|
seen_calls.append((kwargs, queue))
|
||||||
|
raise RuntimeError("boom")
|
||||||
|
|
||||||
|
logged: list[str] = []
|
||||||
|
monkeypatch.setattr("core.app.apps.workflow_app_runner.dispatch_human_input_email_task", _Dispatch())
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"core.app.apps.workflow_app_runner.logger",
|
||||||
|
SimpleNamespace(exception=lambda message, form_id: logged.append(f"{message}:{form_id}")),
|
||||||
|
)
|
||||||
|
|
||||||
|
runner._enqueue_human_input_notifications(
|
||||||
|
[
|
||||||
|
object(),
|
||||||
|
HumanInputRequired(form_id="", form_content="content", node_id="node", node_title="Node"),
|
||||||
|
HumanInputRequired(form_id="form-1", form_content="content", node_id="node", node_title="Node"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert seen_calls == [({"form_id": "form-1", "node_title": "Node"}, "mail")]
|
||||||
|
assert logged == ["Failed to enqueue human input email task for form %s:form-1"]
|
||||||
|
|
||||||
def test_handle_node_events_publishes_queue_events(self):
|
def test_handle_node_events_publishes_queue_events(self):
|
||||||
published: list[object] = []
|
published: list[object] = []
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
from copy import deepcopy
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -33,8 +33,8 @@ def _make_graph_state():
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_run_uses_single_node_execution_branch(
|
def test_run_uses_single_node_execution_branch(
|
||||||
single_iteration_run: Any,
|
single_iteration_run: WorkflowAppGenerateEntity.SingleIterationRunEntity | None,
|
||||||
single_loop_run: Any,
|
single_loop_run: WorkflowAppGenerateEntity.SingleLoopRunEntity | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
app_config = MagicMock()
|
app_config = MagicMock()
|
||||||
app_config.app_id = "app"
|
app_config.app_id = "app"
|
||||||
@ -130,10 +130,23 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
|
|||||||
"break_conditions": [],
|
"break_conditions": [],
|
||||||
"logical_operator": "and",
|
"logical_operator": "and",
|
||||||
},
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "other-node",
|
||||||
|
"data": {
|
||||||
|
"type": "answer",
|
||||||
|
"title": "Answer",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"edges": [
|
||||||
|
{
|
||||||
|
"source": "other-node",
|
||||||
|
"target": "loop-node",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"edges": [],
|
|
||||||
}
|
}
|
||||||
|
original_graph_dict = deepcopy(workflow.graph_dict)
|
||||||
|
|
||||||
_, _, graph_runtime_state = _make_graph_state()
|
_, _, graph_runtime_state = _make_graph_state()
|
||||||
seen_configs: list[object] = []
|
seen_configs: list[object] = []
|
||||||
@ -143,13 +156,19 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
|
|||||||
seen_configs.append(value)
|
seen_configs.append(value)
|
||||||
return original_validate_python(value)
|
return original_validate_python(value)
|
||||||
|
|
||||||
|
class FakeNodeClass:
|
||||||
|
@staticmethod
|
||||||
|
def extract_variable_selector_to_variable_mapping(**_kwargs):
|
||||||
|
return {}
|
||||||
|
|
||||||
monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python)
|
monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("core.app.apps.workflow_app_runner.DifyNodeFactory"),
|
patch("core.app.apps.workflow_app_runner.DifyNodeFactory"),
|
||||||
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()),
|
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()) as graph_init,
|
||||||
patch("core.app.apps.workflow_app_runner.load_into_variable_pool"),
|
patch("core.app.apps.workflow_app_runner.load_into_variable_pool"),
|
||||||
patch("core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool"),
|
patch("core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool"),
|
||||||
|
patch("core.app.apps.workflow_app_runner.resolve_workflow_node_class", return_value=FakeNodeClass),
|
||||||
):
|
):
|
||||||
runner._get_graph_and_variable_pool_for_single_node_run(
|
runner._get_graph_and_variable_pool_for_single_node_run(
|
||||||
workflow=workflow,
|
workflow=workflow,
|
||||||
@ -161,3 +180,8 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert seen_configs == [workflow.graph_dict["nodes"][0]]
|
assert seen_configs == [workflow.graph_dict["nodes"][0]]
|
||||||
|
assert workflow.graph_dict == original_graph_dict
|
||||||
|
graph_config = graph_init.call_args.kwargs["graph_config"]
|
||||||
|
assert graph_config is not workflow.graph_dict
|
||||||
|
assert graph_config["nodes"] == [workflow.graph_dict["nodes"][0]]
|
||||||
|
assert graph_config["edges"] == []
|
||||||
|
|||||||
@ -0,0 +1,46 @@
|
|||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from core.workflow.nodes.agent.entities import AgentNodeData
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_input_accepts_variable_selector_and_mixed_values() -> None:
|
||||||
|
node_data = AgentNodeData.model_validate(
|
||||||
|
{
|
||||||
|
"title": "Agent",
|
||||||
|
"agent_strategy_provider_name": "provider",
|
||||||
|
"agent_strategy_name": "strategy",
|
||||||
|
"agent_strategy_label": "Strategy",
|
||||||
|
"agent_parameters": {
|
||||||
|
"query": {"type": "variable", "value": ["start", "query"]},
|
||||||
|
"tools": {"type": "mixed", "value": [{"provider": "builtin", "name": "search"}]},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert node_data.agent_parameters["query"].value == ["start", "query"]
|
||||||
|
assert node_data.agent_parameters["tools"].value == [{"provider": "builtin", "name": "search"}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_input_rejects_invalid_variable_selector_and_unknown_type() -> None:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
AgentNodeData.model_validate(
|
||||||
|
{
|
||||||
|
"title": "Agent",
|
||||||
|
"agent_strategy_provider_name": "provider",
|
||||||
|
"agent_strategy_name": "strategy",
|
||||||
|
"agent_strategy_label": "Strategy",
|
||||||
|
"agent_parameters": {"query": {"type": "variable", "value": "start.query"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError, match="Unknown agent input type"):
|
||||||
|
AgentNodeData.model_validate(
|
||||||
|
{
|
||||||
|
"title": "Agent",
|
||||||
|
"agent_strategy_provider_name": "provider",
|
||||||
|
"agent_strategy_name": "strategy",
|
||||||
|
"agent_strategy_label": "Strategy",
|
||||||
|
"agent_parameters": {"query": {"type": "unsupported", "value": "hello"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
@ -0,0 +1,125 @@
|
|||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.tools.entities.tool_entities import ToolProviderType
|
||||||
|
from core.workflow.nodes.agent.exceptions import AgentVariableNotFoundError
|
||||||
|
from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_mcp_type_tool_depends_on_strategy_meta_version() -> None:
|
||||||
|
runtime_support = AgentRuntimeSupport()
|
||||||
|
tools = [
|
||||||
|
{"type": ToolProviderType.BUILT_IN, "tool_name": "search"},
|
||||||
|
{"type": ToolProviderType.MCP, "tool_name": "mcp-tool"},
|
||||||
|
]
|
||||||
|
|
||||||
|
filtered_tools = runtime_support._filter_mcp_type_tool(SimpleNamespace(meta_version="0.0.1"), tools)
|
||||||
|
preserved_tools = runtime_support._filter_mcp_type_tool(SimpleNamespace(meta_version="0.0.2"), tools)
|
||||||
|
|
||||||
|
assert filtered_tools == [{"type": ToolProviderType.BUILT_IN, "tool_name": "search"}]
|
||||||
|
assert preserved_tools == tools
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_tool_payloads_keeps_enabled_tools_and_resolves_values() -> None:
|
||||||
|
runtime_support = AgentRuntimeSupport()
|
||||||
|
variable_pool = SimpleNamespace(get=lambda selector: SimpleNamespace(value=f"resolved:{'.'.join(selector)}"))
|
||||||
|
|
||||||
|
normalized_tools = runtime_support._normalize_tool_payloads(
|
||||||
|
strategy=SimpleNamespace(meta_version="0.0.2"),
|
||||||
|
tools=[
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"tool_name": "search",
|
||||||
|
"schemas": {"ignored": True},
|
||||||
|
"parameters": {
|
||||||
|
"query": {
|
||||||
|
"auto": 0,
|
||||||
|
"value": {"type": "variable", "value": ["start", "query"]},
|
||||||
|
},
|
||||||
|
"top_k": {
|
||||||
|
"auto": 0,
|
||||||
|
"value": {"type": "constant", "value": 3},
|
||||||
|
},
|
||||||
|
"optional": {"auto": 1, "value": {"type": "constant", "value": "skip"}},
|
||||||
|
},
|
||||||
|
"settings": {
|
||||||
|
"region": {"value": "us"},
|
||||||
|
"safe": {"value": True},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{"enabled": False, "tool_name": "disabled"},
|
||||||
|
],
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert normalized_tools == [
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"tool_name": "search",
|
||||||
|
"parameters": {"query": "resolved:start.query", "top_k": 3, "optional": None},
|
||||||
|
"settings": {"region": "us", "safe": True},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_tool_parameters_raises_for_missing_variable() -> None:
|
||||||
|
runtime_support = AgentRuntimeSupport()
|
||||||
|
variable_pool = SimpleNamespace(get=lambda _selector: None)
|
||||||
|
|
||||||
|
with pytest.raises(AgentVariableNotFoundError, match=r"\['start', 'query'\]"):
|
||||||
|
runtime_support._resolve_tool_parameters(
|
||||||
|
tool={
|
||||||
|
"parameters": {
|
||||||
|
"query": {
|
||||||
|
"auto": 0,
|
||||||
|
"value": {"type": "variable", "value": ["start", "query"]},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_credentials_collects_valid_tool_credentials_only() -> None:
|
||||||
|
runtime_support = AgentRuntimeSupport()
|
||||||
|
|
||||||
|
credentials = runtime_support.build_credentials(
|
||||||
|
parameters={
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"credential_id": "cred-1",
|
||||||
|
"identity": {
|
||||||
|
"author": "author",
|
||||||
|
"name": "tool",
|
||||||
|
"label": {"en_US": "Tool"},
|
||||||
|
"provider": "provider-a",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"credential_id": "cred-2",
|
||||||
|
"identity": {"author": "author"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"credential_id": None,
|
||||||
|
"identity": {
|
||||||
|
"author": "author",
|
||||||
|
"name": "tool",
|
||||||
|
"label": {"en_US": "Tool"},
|
||||||
|
"provider": "provider-b",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"invalid",
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert credentials.tool_credentials == {"provider-a": "cred-1"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_coerce_named_json_objects_requires_string_keys_and_json_object_values() -> None:
|
||||||
|
runtime_support = AgentRuntimeSupport()
|
||||||
|
|
||||||
|
assert runtime_support._coerce_named_json_objects({"valid": {"value": 1}}) == {"valid": {"value": 1}}
|
||||||
|
assert runtime_support._coerce_named_json_objects({1: {"value": 1}}) is None
|
||||||
|
assert runtime_support._coerce_named_json_objects({"invalid": object()}) is None
|
||||||
@ -13,7 +13,9 @@ from core.model_manager import ModelInstance
|
|||||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||||
from dify_graph.entities import GraphInitParams
|
from dify_graph.entities import GraphInitParams
|
||||||
from dify_graph.file import File, FileTransferMethod, FileType
|
from dify_graph.file import File, FileTransferMethod, FileType
|
||||||
|
from dify_graph.model_runtime.entities import LLMMode
|
||||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultWithStructuredOutput, LLMUsage
|
||||||
from dify_graph.model_runtime.entities.message_entities import (
|
from dify_graph.model_runtime.entities.message_entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
ImagePromptMessageContent,
|
ImagePromptMessageContent,
|
||||||
@ -55,6 +57,118 @@ class MockTokenBufferMemory:
|
|||||||
return self.history_messages
|
return self.history_messages
|
||||||
|
|
||||||
|
|
||||||
|
def test_llm_node_data_normalizes_optional_configs_and_legacy_structured_output() -> None:
|
||||||
|
node_data = LLMNodeData.model_validate(
|
||||||
|
{
|
||||||
|
"title": "Test LLM",
|
||||||
|
"model": {"provider": "openai", "name": "gpt-4o-mini", "mode": LLMMode.CHAT, "completion_params": {}},
|
||||||
|
"prompt_template": [],
|
||||||
|
"prompt_config": None,
|
||||||
|
"memory": None,
|
||||||
|
"context": {"enabled": False},
|
||||||
|
"vision": {"enabled": True, "configs": None},
|
||||||
|
"structured_output": {
|
||||||
|
"schema": {"type": "object"},
|
||||||
|
"name": "Response",
|
||||||
|
"description": "Structured",
|
||||||
|
},
|
||||||
|
"structured_output_enabled": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert node_data.prompt_config.jinja2_variables == []
|
||||||
|
assert node_data.vision.configs.variable_selector == ["sys", "files"]
|
||||||
|
assert node_data.structured_output == {
|
||||||
|
"schema": {"type": "object"},
|
||||||
|
"name": "Response",
|
||||||
|
"description": "Structured",
|
||||||
|
}
|
||||||
|
assert node_data.structured_output_enabled is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_llm_node_data_discards_legacy_structured_output_without_schema() -> None:
|
||||||
|
node_data = LLMNodeData.model_validate(
|
||||||
|
{
|
||||||
|
"title": "Test LLM",
|
||||||
|
"model": {"provider": "openai", "name": "gpt-4o-mini", "mode": LLMMode.CHAT, "completion_params": {}},
|
||||||
|
"prompt_template": [],
|
||||||
|
"memory": None,
|
||||||
|
"context": {"enabled": False},
|
||||||
|
"vision": {"enabled": False},
|
||||||
|
"structured_output": {"name": "Missing schema"},
|
||||||
|
"structured_output_enabled": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert node_data.structured_output is None
|
||||||
|
assert node_data.structured_output_enabled is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_config_converts_none_jinja_variables() -> None:
|
||||||
|
prompt_config = LLMNodeData.model_validate(
|
||||||
|
{
|
||||||
|
"title": "Test LLM",
|
||||||
|
"model": {"provider": "openai", "name": "gpt-4o-mini", "mode": LLMMode.CHAT, "completion_params": {}},
|
||||||
|
"prompt_template": [],
|
||||||
|
"prompt_config": None,
|
||||||
|
"memory": None,
|
||||||
|
"context": {"enabled": False},
|
||||||
|
"vision": {"enabled": False},
|
||||||
|
"structured_output_enabled": False,
|
||||||
|
}
|
||||||
|
).prompt_config
|
||||||
|
|
||||||
|
assert prompt_config.jinja2_variables == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_structured_output_schema_validates_required_object_shape() -> None:
|
||||||
|
assert LLMNode.fetch_structured_output_schema(structured_output={"schema": {"type": "object", "a": 1}}) == {
|
||||||
|
"type": "object",
|
||||||
|
"a": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(Exception, match="valid structured output schema"):
|
||||||
|
LLMNode.fetch_structured_output_schema(structured_output={"schema": None})
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_blocking_result_separates_reasoning_and_structured_output() -> None:
|
||||||
|
saver = mock.MagicMock(spec=LLMFileSaver)
|
||||||
|
event = LLMNode.handle_blocking_result(
|
||||||
|
invoke_result=LLMResultWithStructuredOutput(
|
||||||
|
model="gpt",
|
||||||
|
message=AssistantPromptMessage(content="<think>reasoning</think>answer"),
|
||||||
|
usage=LLMUsage.empty_usage(),
|
||||||
|
structured_output={"answer": "done"},
|
||||||
|
),
|
||||||
|
saver=saver,
|
||||||
|
file_outputs=[],
|
||||||
|
reasoning_format="separated",
|
||||||
|
request_latency=1.2345,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.text == "answer"
|
||||||
|
assert event.reasoning_content == "reasoning"
|
||||||
|
assert event.structured_output == {"answer": "done"}
|
||||||
|
assert event.usage.latency == 1.234
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_blocking_result_keeps_tagged_text_without_structured_output() -> None:
|
||||||
|
saver = mock.MagicMock(spec=LLMFileSaver)
|
||||||
|
event = LLMNode.handle_blocking_result(
|
||||||
|
invoke_result=LLMResult(
|
||||||
|
model="gpt",
|
||||||
|
message=AssistantPromptMessage(content="plain text"),
|
||||||
|
usage=LLMUsage.empty_usage(),
|
||||||
|
),
|
||||||
|
saver=saver,
|
||||||
|
file_outputs=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.text == "plain text"
|
||||||
|
assert event.reasoning_content == ""
|
||||||
|
assert event.structured_output is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def llm_node_data() -> LLMNodeData:
|
def llm_node_data() -> LLMNodeData:
|
||||||
return LLMNodeData(
|
return LLMNodeData(
|
||||||
|
|||||||
@ -1,6 +1,12 @@
|
|||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||||
from dify_graph.nodes.loop.entities import LoopNodeData
|
from dify_graph.nodes.loop.entities import LoopNodeData, LoopValue
|
||||||
from dify_graph.nodes.loop.loop_node import LoopNode
|
from dify_graph.nodes.loop.loop_node import LoopNode
|
||||||
|
from dify_graph.variables.types import SegmentType
|
||||||
|
|
||||||
|
|
||||||
def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None:
|
def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None:
|
||||||
@ -50,3 +56,104 @@ def test_extract_variable_selector_to_variable_mapping_validates_child_node_conf
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert seen_configs == [child_node_config]
|
assert seen_configs == [child_node_config]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("var_type", "original_value", "expected_value"),
|
||||||
|
[
|
||||||
|
(SegmentType.ARRAY_STRING, ["alpha", "beta"], ["alpha", "beta"]),
|
||||||
|
(SegmentType.ARRAY_NUMBER, [1, 2.5], [1, 2.5]),
|
||||||
|
(SegmentType.ARRAY_OBJECT, [{"name": "item"}], [{"name": "item"}]),
|
||||||
|
(SegmentType.ARRAY_STRING, '["legacy", "json"]', ["legacy", "json"]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_get_segment_for_constant_accepts_native_array_values(
|
||||||
|
var_type: SegmentType, original_value: LoopValue, expected_value: LoopValue
|
||||||
|
) -> None:
|
||||||
|
segment = LoopNode._get_segment_for_constant(var_type, original_value)
|
||||||
|
|
||||||
|
assert segment.value_type == var_type
|
||||||
|
assert segment.value == expected_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_loop_variable_data_validates_variable_selector_and_constant_value() -> None:
|
||||||
|
variable_input = LoopNodeData(
|
||||||
|
title="Loop",
|
||||||
|
loop_count=1,
|
||||||
|
break_conditions=[],
|
||||||
|
logical_operator="and",
|
||||||
|
loop_variables=[
|
||||||
|
{
|
||||||
|
"label": "question",
|
||||||
|
"var_type": SegmentType.STRING,
|
||||||
|
"value_type": "variable",
|
||||||
|
"value": ["start", "question"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"label": "payload",
|
||||||
|
"var_type": SegmentType.OBJECT,
|
||||||
|
"value_type": "constant",
|
||||||
|
"value": {"count": 1, "items": ["a", 2]},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert variable_input.loop_variables[0].require_variable_selector() == ["start", "question"]
|
||||||
|
assert variable_input.loop_variables[1].require_constant_value() == {"count": 1, "items": ["a", 2]}
|
||||||
|
|
||||||
|
|
||||||
|
def test_loop_variable_data_rejects_missing_variable_selector() -> None:
|
||||||
|
with pytest.raises(ValidationError, match="Variable loop inputs require a selector"):
|
||||||
|
LoopNodeData(
|
||||||
|
title="Loop",
|
||||||
|
loop_count=1,
|
||||||
|
break_conditions=[],
|
||||||
|
logical_operator="and",
|
||||||
|
loop_variables=[
|
||||||
|
{
|
||||||
|
"label": "question",
|
||||||
|
"var_type": SegmentType.STRING,
|
||||||
|
"value_type": "variable",
|
||||||
|
"value": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_loop_node_data_outputs_default_to_empty_mapping_for_none() -> None:
|
||||||
|
node_data = LoopNodeData(
|
||||||
|
title="Loop",
|
||||||
|
loop_count=1,
|
||||||
|
break_conditions=[],
|
||||||
|
logical_operator="and",
|
||||||
|
outputs=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert node_data.outputs == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_append_loop_info_to_event_preserves_existing_loop_metadata() -> None:
|
||||||
|
node = object.__new__(LoopNode)
|
||||||
|
node._node_id = "loop-node"
|
||||||
|
|
||||||
|
event = SimpleNamespace(
|
||||||
|
node_run_result=SimpleNamespace(metadata={"loop_id": "existing-loop", "other": "value"}),
|
||||||
|
in_loop_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
node._append_loop_info_to_event(event=event, loop_run_index=2)
|
||||||
|
|
||||||
|
assert event.in_loop_id == "loop-node"
|
||||||
|
assert event.node_run_result.metadata == {"loop_id": "existing-loop", "other": "value"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear_loop_subgraph_variables_removes_each_loop_node() -> None:
|
||||||
|
node = object.__new__(LoopNode)
|
||||||
|
remove_calls: list[list[str]] = []
|
||||||
|
node.graph_runtime_state = SimpleNamespace(
|
||||||
|
variable_pool=SimpleNamespace(remove=lambda selector: remove_calls.append(selector))
|
||||||
|
)
|
||||||
|
|
||||||
|
node._clear_loop_subgraph_variables({"child-a", "child-b"})
|
||||||
|
|
||||||
|
assert sorted(remove_calls) == [["child-a"], ["child-b"]]
|
||||||
|
|||||||
@ -8,11 +8,13 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||||
from dify_graph.file import File, FileTransferMethod, FileType
|
from dify_graph.file import File, FileTransferMethod, FileType
|
||||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent
|
from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent
|
||||||
|
from dify_graph.nodes.tool.entities import ToolEntity as WorkflowToolEntity
|
||||||
|
from dify_graph.nodes.tool.entities import ToolNodeData
|
||||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||||
from dify_graph.system_variable import SystemVariable
|
from dify_graph.system_variable import SystemVariable
|
||||||
from dify_graph.variables.segments import ArrayFileSegment
|
from dify_graph.variables.segments import ArrayFileSegment
|
||||||
@ -167,3 +169,119 @@ def test_plain_link_messages_remain_links(tool_node: ToolNode):
|
|||||||
files_segment = completed_events[0].node_run_result.outputs["files"]
|
files_segment = completed_events[0].node_run_result.outputs["files"]
|
||||||
assert isinstance(files_segment, ArrayFileSegment)
|
assert isinstance(files_segment, ArrayFileSegment)
|
||||||
assert files_segment.value == []
|
assert files_segment.value == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_workflow_tool_entity_accepts_primitives_and_tool_input_payloads() -> None:
|
||||||
|
entity = WorkflowToolEntity(
|
||||||
|
provider_id="provider",
|
||||||
|
provider_type="builtin",
|
||||||
|
provider_name="provider",
|
||||||
|
tool_name="search",
|
||||||
|
tool_label="Search",
|
||||||
|
tool_configurations={
|
||||||
|
"timeout": 30,
|
||||||
|
"query": {"type": "mixed", "value": "hello {{name}}"},
|
||||||
|
"selector": {"type": "variable", "value": ["start", "question"]},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert entity.tool_configurations == {
|
||||||
|
"timeout": 30,
|
||||||
|
"query": {"type": "mixed", "value": "hello {{name}}"},
|
||||||
|
"selector": {"type": "variable", "value": ["start", "question"]},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_workflow_tool_entity_rejects_invalid_configuration_entries() -> None:
|
||||||
|
with pytest.raises(TypeError, match="Tool configuration values must be primitives"):
|
||||||
|
WorkflowToolEntity(
|
||||||
|
provider_id="provider",
|
||||||
|
provider_type="builtin",
|
||||||
|
provider_name="provider",
|
||||||
|
tool_name="search",
|
||||||
|
tool_label="Search",
|
||||||
|
tool_configurations={"bad": [object()]},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_node_data_filters_missing_tool_parameter_values() -> None:
|
||||||
|
node_data = ToolNodeData(
|
||||||
|
title="Tool",
|
||||||
|
provider_id="provider",
|
||||||
|
provider_type="builtin",
|
||||||
|
provider_name="provider",
|
||||||
|
tool_name="search",
|
||||||
|
tool_label="Search",
|
||||||
|
tool_configurations={},
|
||||||
|
tool_parameters={
|
||||||
|
"query": {"type": "mixed", "value": "hello"},
|
||||||
|
"skip_none": None,
|
||||||
|
"skip_empty": {"type": "constant", "value": None},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert set(node_data.tool_parameters.keys()) == {"query"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_parameters_reads_variables_and_optional_missing_inputs(tool_node: ToolNode) -> None:
|
||||||
|
variable_pool = MagicMock()
|
||||||
|
variable_pool.get.side_effect = [MagicMock(value="from-variable"), None]
|
||||||
|
node_data = ToolNodeData.model_validate(
|
||||||
|
{
|
||||||
|
"title": "Tool",
|
||||||
|
"provider_id": "provider",
|
||||||
|
"provider_type": "builtin",
|
||||||
|
"provider_name": "provider",
|
||||||
|
"tool_name": "tool",
|
||||||
|
"tool_label": "tool",
|
||||||
|
"tool_configurations": {},
|
||||||
|
"tool_parameters": {
|
||||||
|
"query": {"type": "variable", "value": ["start", "query"]},
|
||||||
|
"optional": {"type": "variable", "value": ["start", "optional"]},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tool_parameters = [
|
||||||
|
ToolParameter.get_simple_instance("query", "query", ToolParameter.ToolParameterType.STRING, True),
|
||||||
|
ToolParameter.get_simple_instance("optional", "optional", ToolParameter.ToolParameterType.STRING, False),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = tool_node._generate_parameters(
|
||||||
|
tool_parameters=tool_parameters,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
node_data=node_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {"query": "from-variable"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_parameters_formats_logs_and_unknown_parameters(tool_node: ToolNode) -> None:
|
||||||
|
variable_pool = MagicMock()
|
||||||
|
variable_pool.convert_template.return_value = MagicMock(text="rendered", log="masked")
|
||||||
|
node_data = ToolNodeData.model_validate(
|
||||||
|
{
|
||||||
|
"title": "Tool",
|
||||||
|
"provider_id": "provider",
|
||||||
|
"provider_type": "builtin",
|
||||||
|
"provider_name": "provider",
|
||||||
|
"tool_name": "tool",
|
||||||
|
"tool_label": "tool",
|
||||||
|
"tool_configurations": {},
|
||||||
|
"tool_parameters": {
|
||||||
|
"query": {"type": "mixed", "value": "{{ question }}"},
|
||||||
|
"missing": {"type": "constant", "value": "literal"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tool_parameters = [
|
||||||
|
ToolParameter.get_simple_instance("query", "query", ToolParameter.ToolParameterType.STRING, True),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = tool_node._generate_parameters(
|
||||||
|
tool_parameters=tool_parameters,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
node_data=node_data,
|
||||||
|
for_log=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {"query": "masked", "missing": None}
|
||||||
|
|||||||
@ -97,6 +97,22 @@ class TestWorkflowChildEngineBuilder:
|
|||||||
((sentinel.layer_two,), {}),
|
((sentinel.layer_two,), {}),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def test_build_child_engine_tolerates_invalid_graph_shape_until_graph_init(self):
|
||||||
|
builder = workflow_entry._WorkflowChildEngineBuilder()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(workflow_entry, "DifyNodeFactory", return_value=sentinel.factory),
|
||||||
|
patch.object(workflow_entry.Graph, "init", side_effect=ValueError("invalid graph")),
|
||||||
|
):
|
||||||
|
with pytest.raises(ValueError, match="invalid graph"):
|
||||||
|
builder.build_child_engine(
|
||||||
|
workflow_id="workflow-id",
|
||||||
|
graph_init_params=sentinel.graph_init_params,
|
||||||
|
graph_runtime_state=sentinel.graph_runtime_state,
|
||||||
|
graph_config={"nodes": "invalid"},
|
||||||
|
root_node_id="root",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestWorkflowEntryInit:
|
class TestWorkflowEntryInit:
|
||||||
def test_rejects_call_depth_above_limit(self):
|
def test_rejects_call_depth_above_limit(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user