mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 09:58:04 +08:00
Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts: # api/core/app/apps/advanced_chat/app_generator.py # api/core/app/apps/advanced_chat/app_runner.py # api/core/app/apps/advanced_chat/generate_task_pipeline.py # api/core/app/apps/base_app_runner.py # api/core/app/apps/workflow/app_runner.py # api/core/app/apps/workflow/generate_task_pipeline.py # api/core/app/task_pipeline/workflow_cycle_state_manager.py # api/core/workflow/entities/node_entities.py # api/core/workflow/nodes/llm/llm_node.py # api/core/workflow/workflow_engine_manager.py # api/tests/integration_tests/workflow/nodes/test_llm.py # api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py # api/tests/unit_tests/core/workflow/nodes/test_answer.py # api/tests/unit_tests/core/workflow/nodes/test_if_else.py # api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
This commit is contained in:
@ -93,6 +93,7 @@ class DatasetConfigManager:
|
||||
reranking_model=dataset_configs.get('reranking_model'),
|
||||
weights=dataset_configs.get('weights'),
|
||||
reranking_enabled=dataset_configs.get('reranking_enabled', True),
|
||||
rerank_mode=dataset_configs["reranking_mode"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
def generate(
|
||||
self,
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
@ -121,7 +121,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation=conversation,
|
||||
stream=stream
|
||||
)
|
||||
|
||||
|
||||
def single_iteration_generate(self, app_model: App,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
@ -141,10 +141,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
"""
|
||||
if not node_id:
|
||||
raise ValueError('node_id is required')
|
||||
|
||||
|
||||
if args.get('inputs') is None:
|
||||
raise ValueError('inputs is required')
|
||||
|
||||
|
||||
# convert to app config
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
@ -191,7 +191,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
-> dict[str, Any] | Generator[str, Any, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
:param workflow: Workflow
|
||||
:param user: account or end user
|
||||
:param invoke_from: invoke from source
|
||||
@ -232,8 +232,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
'queue_manager': queue_manager,
|
||||
'conversation_id': conversation.id,
|
||||
'message_id': message.id,
|
||||
'user': user,
|
||||
'context': contextvars.copy_context()
|
||||
'context': contextvars.copy_context(),
|
||||
})
|
||||
|
||||
worker_thread.start()
|
||||
@ -246,7 +245,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=stream
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
return AdvancedChatAppGenerateResponseConverter.convert(
|
||||
@ -259,7 +258,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
queue_manager: AppQueueManager,
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
user: Account,
|
||||
context: contextvars.Context) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
@ -307,14 +305,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool = False) \
|
||||
-> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
def _handle_advanced_chat_response(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool = False,
|
||||
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
@ -334,7 +335,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=stream
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@ -3,9 +3,6 @@ import os
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
@ -94,7 +91,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
inputs = self.application_generate_entity.inputs
|
||||
query = self.application_generate_entity.query
|
||||
files = self.application_generate_entity.files
|
||||
|
||||
|
||||
# moderation
|
||||
if self.handle_input_moderation(
|
||||
app_record=app_record,
|
||||
|
||||
@ -44,7 +44,7 @@ from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
|
||||
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
@ -69,14 +69,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
_workflow_system_variables: dict[SystemVariable, Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
self,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool
|
||||
stream: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize AdvancedChatAppGenerateTaskPipeline.
|
||||
@ -102,7 +102,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
SystemVariable.QUERY: message.query,
|
||||
SystemVariable.FILES: application_generate_entity.files,
|
||||
SystemVariable.CONVERSATION_ID: conversation.id,
|
||||
SystemVariable.USER_ID: user_id
|
||||
SystemVariable.USER_ID: user_id,
|
||||
}
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
@ -127,7 +127,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
generator = self._wrapper_process_stream_response(
|
||||
trace_manager=self._application_generate_entity.trace_manager
|
||||
)
|
||||
|
||||
|
||||
if self._stream:
|
||||
return self._to_stream_response(generator)
|
||||
else:
|
||||
@ -239,7 +239,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
# init fake graph runtime state
|
||||
graph_runtime_state = None
|
||||
workflow_run = None
|
||||
|
||||
|
||||
for queue_message in self._queue_manager.listen():
|
||||
event = queue_message.event
|
||||
|
||||
@ -270,9 +270,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
elif isinstance(event, QueueNodeStartedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
|
||||
workflow_node_execution = self._handle_node_execution_start(
|
||||
workflow_run=workflow_run,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
|
||||
@ -307,7 +307,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
elif isinstance(event, QueueIterationStartEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
|
||||
yield self._workflow_iteration_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@ -316,7 +316,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
elif isinstance(event, QueueIterationNextEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
|
||||
yield self._workflow_iteration_next_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@ -325,7 +325,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
elif isinstance(event, QueueIterationCompletedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
|
||||
yield self._workflow_iteration_completed_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@ -334,10 +334,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
elif isinstance(event, QueueWorkflowSucceededEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
|
||||
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@ -360,10 +360,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
elif isinstance(event, QueueWorkflowFailedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
|
||||
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@ -400,7 +400,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run
|
||||
)
|
||||
|
||||
|
||||
# Save message
|
||||
self._save_message(graph_runtime_state=graph_runtime_state)
|
||||
|
||||
@ -413,7 +413,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
|
||||
if self._task_state.metadata else None
|
||||
|
||||
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||
@ -421,7 +421,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
|
||||
if self._task_state.metadata else None
|
||||
|
||||
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
elif isinstance(event, QueueTextChunkEvent):
|
||||
@ -446,7 +446,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
|
||||
|
||||
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
|
||||
if output_moderation_answer:
|
||||
self._task_state.answer = output_moderation_answer
|
||||
@ -458,11 +458,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
yield self._message_end_to_stream_response()
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
# publish None when task finished
|
||||
if tts_publisher:
|
||||
tts_publisher.publish(None)
|
||||
|
||||
|
||||
if self._conversation_name_generate_thread:
|
||||
self._conversation_name_generate_thread.join()
|
||||
|
||||
@ -507,7 +507,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
extras = {}
|
||||
if self._task_state.metadata:
|
||||
extras['metadata'] = self._task_state.metadata.copy()
|
||||
|
||||
|
||||
if 'annotation_reply' in extras['metadata']:
|
||||
del extras['metadata']['annotation_reply']
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import time
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
@ -14,7 +14,6 @@ from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChu
|
||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
|
||||
from core.external_data_tool.external_data_fetch import ExternalDataFetch
|
||||
from core.file.file_obj import FileVar
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
@ -27,13 +26,16 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
|
||||
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
|
||||
from models.model import App, AppMode, Message, MessageAnnotation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.file_obj import FileVar
|
||||
|
||||
|
||||
class AppRunner:
|
||||
def get_pre_calculate_rest_tokens(self, app_record: App,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict[str, str],
|
||||
files: list[FileVar],
|
||||
files: list["FileVar"],
|
||||
query: Optional[str] = None) -> int:
|
||||
"""
|
||||
Get pre calculate rest tokens
|
||||
@ -126,7 +128,7 @@ class AppRunner:
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict[str, str],
|
||||
files: list[FileVar],
|
||||
files: list["FileVar"],
|
||||
query: Optional[str] = None,
|
||||
context: Optional[str] = None,
|
||||
memory: Optional[TokenBufferMemory] = None) \
|
||||
@ -366,7 +368,7 @@ class AppRunner:
|
||||
message_id=message_id,
|
||||
trace_manager=app_generate_entity.trace_manager
|
||||
)
|
||||
|
||||
|
||||
def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
prompt_messages: list[PromptMessage]) -> bool:
|
||||
@ -418,7 +420,7 @@ class AppRunner:
|
||||
inputs=inputs,
|
||||
query=query
|
||||
)
|
||||
|
||||
|
||||
def query_app_annotations_to_reply(self, app_record: App,
|
||||
message: Message,
|
||||
query: str,
|
||||
|
||||
@ -258,7 +258,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
|
||||
return introduction
|
||||
|
||||
def _get_conversation(self, conversation_id: str) -> Conversation:
|
||||
def _get_conversation(self, conversation_id: str):
|
||||
"""
|
||||
Get conversation by conversation id
|
||||
:param conversation_id: conversation id
|
||||
@ -270,6 +270,9 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise ConversationNotExistsError()
|
||||
|
||||
return conversation
|
||||
|
||||
def _get_message(self, message_id: str) -> Message:
|
||||
|
||||
@ -11,7 +11,8 @@ from core.app.entities.app_invoke_entities import (
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.node_entities import SystemVariable, UserFrom
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
|
||||
@ -41,7 +41,9 @@ from core.app.entities.task_entities import (
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.enums import SystemVariable
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
@ -179,7 +181,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
|
||||
'text_to_speech'].get('autoPlay') == 'enabled':
|
||||
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
|
||||
|
||||
|
||||
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
|
||||
@ -246,7 +248,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
workflow_node_execution = self._handle_node_execution_start(
|
||||
workflow_run=workflow_run,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
|
||||
@ -281,7 +283,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
elif isinstance(event, QueueIterationStartEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
|
||||
yield self._workflow_iteration_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@ -290,7 +292,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
elif isinstance(event, QueueIterationNextEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
|
||||
yield self._workflow_iteration_next_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@ -299,7 +301,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
elif isinstance(event, QueueIterationCompletedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
|
||||
yield self._workflow_iteration_completed_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@ -308,10 +310,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
elif isinstance(event, QueueWorkflowSucceededEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
|
||||
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@ -332,10 +334,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
|
||||
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
|
||||
@ -166,4 +166,4 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
node_id: str
|
||||
inputs: dict
|
||||
|
||||
single_iteration_run: Optional[SingleIterationRunEntity] = None
|
||||
single_iteration_run: Optional[SingleIterationRunEntity] = None
|
||||
|
||||
@ -2,7 +2,6 @@ from .segment_group import SegmentGroup
|
||||
from .segments import (
|
||||
ArrayAnySegment,
|
||||
ArraySegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
@ -13,11 +12,9 @@ from .segments import (
|
||||
from .types import SegmentType
|
||||
from .variables import (
|
||||
ArrayAnyVariable,
|
||||
ArrayFileVariable,
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
ArrayStringVariable,
|
||||
FileVariable,
|
||||
FloatVariable,
|
||||
IntegerVariable,
|
||||
NoneVariable,
|
||||
@ -32,7 +29,6 @@ __all__ = [
|
||||
'FloatVariable',
|
||||
'ObjectVariable',
|
||||
'SecretVariable',
|
||||
'FileVariable',
|
||||
'StringVariable',
|
||||
'ArrayAnyVariable',
|
||||
'Variable',
|
||||
@ -45,11 +41,9 @@ __all__ = [
|
||||
'FloatSegment',
|
||||
'ObjectSegment',
|
||||
'ArrayAnySegment',
|
||||
'FileSegment',
|
||||
'StringSegment',
|
||||
'ArrayStringVariable',
|
||||
'ArrayNumberVariable',
|
||||
'ArrayObjectVariable',
|
||||
'ArrayFileVariable',
|
||||
'ArraySegment',
|
||||
]
|
||||
|
||||
@ -2,12 +2,10 @@ from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.file.file_obj import FileVar
|
||||
|
||||
from .exc import VariableError
|
||||
from .segments import (
|
||||
ArrayAnySegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
@ -17,11 +15,9 @@ from .segments import (
|
||||
)
|
||||
from .types import SegmentType
|
||||
from .variables import (
|
||||
ArrayFileVariable,
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
ArrayStringVariable,
|
||||
FileVariable,
|
||||
FloatVariable,
|
||||
IntegerVariable,
|
||||
ObjectVariable,
|
||||
@ -49,8 +45,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
||||
result = FloatVariable.model_validate(mapping)
|
||||
case SegmentType.NUMBER if not isinstance(value, float | int):
|
||||
raise VariableError(f'invalid number value {value}')
|
||||
case SegmentType.FILE:
|
||||
result = FileVariable.model_validate(mapping)
|
||||
case SegmentType.OBJECT if isinstance(value, dict):
|
||||
result = ObjectVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_STRING if isinstance(value, list):
|
||||
@ -59,10 +53,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
||||
result = ArrayNumberVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
|
||||
result = ArrayObjectVariable.model_validate(mapping)
|
||||
case SegmentType.ARRAY_FILE if isinstance(value, list):
|
||||
mapping = dict(mapping)
|
||||
mapping['value'] = [{'value': v} for v in value]
|
||||
result = ArrayFileVariable.model_validate(mapping)
|
||||
case _:
|
||||
raise VariableError(f'not supported value type {value_type}')
|
||||
if result.size > dify_config.MAX_VARIABLE_SIZE:
|
||||
@ -83,6 +73,4 @@ def build_segment(value: Any, /) -> Segment:
|
||||
return ObjectSegment(value=value)
|
||||
if isinstance(value, list):
|
||||
return ArrayAnySegment(value=value)
|
||||
if isinstance(value, FileVar):
|
||||
return FileSegment(value=value)
|
||||
raise ValueError(f'not supported value {value}')
|
||||
|
||||
@ -5,8 +5,6 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
from core.file.file_obj import FileVar
|
||||
|
||||
from .types import SegmentType
|
||||
|
||||
|
||||
@ -78,14 +76,7 @@ class IntegerSegment(Segment):
|
||||
value: int
|
||||
|
||||
|
||||
class FileSegment(Segment):
|
||||
value_type: SegmentType = SegmentType.FILE
|
||||
# TODO: embed FileVar in this model.
|
||||
value: FileVar
|
||||
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
return self.value.to_markdown()
|
||||
|
||||
|
||||
class ObjectSegment(Segment):
|
||||
@ -108,7 +99,13 @@ class ObjectSegment(Segment):
|
||||
class ArraySegment(Segment):
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
return '\n'.join(['- ' + item.markdown for item in self.value])
|
||||
items = []
|
||||
for item in self.value:
|
||||
if hasattr(item, 'to_markdown'):
|
||||
items.append(item.to_markdown())
|
||||
else:
|
||||
items.append(str(item))
|
||||
return '\n'.join(items)
|
||||
|
||||
|
||||
class ArrayAnySegment(ArraySegment):
|
||||
@ -130,7 +127,3 @@ class ArrayObjectSegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_OBJECT
|
||||
value: Sequence[Mapping[str, Any]]
|
||||
|
||||
|
||||
class ArrayFileSegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_FILE
|
||||
value: Sequence[FileSegment]
|
||||
|
||||
@ -10,8 +10,6 @@ class SegmentType(str, Enum):
|
||||
ARRAY_STRING = 'array[string]'
|
||||
ARRAY_NUMBER = 'array[number]'
|
||||
ARRAY_OBJECT = 'array[object]'
|
||||
ARRAY_FILE = 'array[file]'
|
||||
OBJECT = 'object'
|
||||
FILE = 'file'
|
||||
|
||||
GROUP = 'group'
|
||||
|
||||
@ -4,11 +4,9 @@ from core.helper import encrypter
|
||||
|
||||
from .segments import (
|
||||
ArrayAnySegment,
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArrayStringSegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
@ -44,10 +42,6 @@ class IntegerVariable(IntegerSegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class FileVariable(FileSegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class ObjectVariable(ObjectSegment, Variable):
|
||||
pass
|
||||
|
||||
@ -68,9 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class ArrayFileVariable(ArrayFileSegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class SecretVariable(StringVariable):
|
||||
value_type: SegmentType = SegmentType.SECRET
|
||||
|
||||
Reference in New Issue
Block a user